From 594b3eddc9fabd065ca9273c32a0cbd02465f755 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Sun, 16 Mar 2025 20:16:58 -0700 Subject: [PATCH 1/3] add support for numpy Differential Revision: [D71294355](https://our.internmc.facebook.com/intern/diff/D71294355/) [ghstack-poisoned] --- test/export/test_export.py | 21 +++++++++++++++++++++ torch/_export/non_strict_utils.py | 3 +++ torch/_export/passes/lift_constants_pass.py | 2 +- torch/export/_unlift.py | 4 ++-- 4 files changed, 27 insertions(+), 3 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 07fcbb7d9b337..1309f40ccccc1 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1267,6 +1267,27 @@ def false_fn(x): ) torch.export.export(M(), args) + def test_numpy(self): + class Foo(torch.nn.Module): + def forward(self, x): + a = x.numpy() + return x + x.numpy().sum() + + foo = Foo() + foo(torch.randn(10, 10)) + + non_strict_graph = ( + export(foo, (torch.randn(10, 10),), strict=False) + .run_decompositions({}) + .graph + ) + strict_graph = ( + export(foo, (torch.randn(10, 10),), strict=True) + .run_decompositions({}) + .graph + ) + self.assertEqual(str(non_strict_graph), str(strict_graph)) + def test_cond_int_closure(self): class M(torch.nn.Module): def __init__(self): diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index dc61de00439a6..c21a0f5b70cf0 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -660,6 +660,9 @@ def _override(self, func, args, kwargs): # Redirect to torch.select for indexing with symint. if isinstance(args[1], torch.SymInt): return torch.select, [args[0], 0, args[1]], {} + + if func.__name__ == "numpy" and isinstance(args[0], torch.Tensor): + return torch._refs.view_as, [args[0], args[0]], {} return func, args, kwargs def __torch_function__(self, func, types, args=(), kwargs=None): diff --git a/torch/_export/passes/lift_constants_pass.py b/torch/_export/passes/lift_constants_pass.py index 5373e52877493..823b4d5777272 100644 --- a/torch/_export/passes/lift_constants_pass.py +++ b/torch/_export/passes/lift_constants_pass.py @@ -1,7 +1,6 @@ # mypy: allow-untyped-defs import collections import logging -import warnings from typing import Any, Union import torch @@ -19,6 +18,7 @@ ) from torch.fx.graph_module import _get_attr + log = logging.getLogger(__name__) diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index a4c99db1a2c9f..9c4a1ebb466a7 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -79,8 +79,8 @@ def _unlift_inputs_as_getattr( else: with gm.graph.inserting_after(input_node): # It is fine to ignore this warning because - # it is guaranteed that we will populate this - # attr later. + # it is guaranteed that we will populate this + # attr later. with warnings.catch_warnings(): warnings.simplefilter("ignore") getattr_node = gm.graph.get_attr(lifted_node) From cb6dbd93df9ee07838bb3d97b62b3f53542510a9 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Mon, 17 Mar 2025 08:51:47 -0700 Subject: [PATCH 2/3] Update on "add support for numpy" Differential Revision: [D71294355](https://our.internmc.facebook.com/intern/diff/D71294355/) [ghstack-poisoned] --- test/export/test_export.py | 17 ++++++++++------- torch/_export/non_strict_utils.py | 6 ++++++ torch/_export/passes/lift_constants_pass.py | 2 +- torch/export/_unlift.py | 4 ++-- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 1309f40ccccc1..a33b66a332dac 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -17,6 +17,7 @@ import torch._dynamo as torchdynamo import torch.nn.functional as F from functorch.experimental.control_flow import cond, map +import numpy as np from torch import Tensor from torch._decomp import decomposition_table from torch._dynamo.test_case import TestCase @@ -1271,7 +1272,7 @@ def test_numpy(self): class Foo(torch.nn.Module): def forward(self, x): a = x.numpy() - return x + x.numpy().sum() + return x + np.sin(x.numpy().sum()) foo = Foo() foo(torch.randn(10, 10)) @@ -1281,12 +1282,14 @@ def forward(self, x): .run_decompositions({}) .graph ) - strict_graph = ( - export(foo, (torch.randn(10, 10),), strict=True) - .run_decompositions({}) - .graph - ) - self.assertEqual(str(non_strict_graph), str(strict_graph)) + # strict_graph = ( + # export(foo, (torch.randn(10, 10),), strict=True) + # .run_decompositions({}) + # .graph + # ) + # g = torch.export._trace._export_to_torch_ir(foo, (torch.randn(10, 10),)) + # print("GRAPH", g.graph) + # self.assertEqual(str(non_strict_graph), str(strict_graph)) def test_cond_int_closure(self): class M(torch.nn.Module): diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index c21a0f5b70cf0..3e4a3e8d6dd43 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -625,6 +625,8 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode): """ def _override(self, func, args, kwargs): + if "__array__" in str(func): + breakpoint() if torch.distributed.is_available(): from torch.distributed._functional_collectives import ( REDUCE_OP_TO_STR, @@ -663,6 +665,10 @@ def _override(self, func, args, kwargs): if func.__name__ == "numpy" and isinstance(args[0], torch.Tensor): return torch._refs.view_as, [args[0], args[0]], {} + + if func.__name__ == "__array__" and isinstance(args[0], torch.Tensor): + return torch._numpy.ndarray, [args[0]], {} + return func, args, kwargs def __torch_function__(self, func, types, args=(), kwargs=None): diff --git a/torch/_export/passes/lift_constants_pass.py b/torch/_export/passes/lift_constants_pass.py index 823b4d5777272..5373e52877493 100644 --- a/torch/_export/passes/lift_constants_pass.py +++ b/torch/_export/passes/lift_constants_pass.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import collections import logging +import warnings from typing import Any, Union import torch @@ -18,7 +19,6 @@ ) from torch.fx.graph_module import _get_attr - log = logging.getLogger(__name__) diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 9c4a1ebb466a7..a4c99db1a2c9f 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -79,8 +79,8 @@ def _unlift_inputs_as_getattr( else: with gm.graph.inserting_after(input_node): # It is fine to ignore this warning because - # it is guaranteed that we will populate this - # attr later. + # it is guaranteed that we will populate this + # attr later. with warnings.catch_warnings(): warnings.simplefilter("ignore") getattr_node = gm.graph.get_attr(lifted_node) From 4846e37eea57b1b6e077c6675bb8fa11a43c78ee Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Mon, 17 Mar 2025 08:57:35 -0700 Subject: [PATCH 3/3] Update on "add support for numpy" Differential Revision: [D71294355](https://our.internmc.facebook.com/intern/diff/D71294355/) [ghstack-poisoned]