8000 [export] Remove torch._export.export (#119095) · pytorch/pytorch@3369881 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3369881

Browse files
angelayi and facebook-github-bot
authored and committed
[export] Remove torch._export.export (#119095)
Summary: Pull Request resolved: #119095 Test Plan: CI Reviewed By: avikchaudhuri, tugsbayasgalan, ydwu4 Differential Revision: D53316196
1 parent 4f2bf7f commit 3369881

File tree

11 files changed

+23
-53
lines changed

11 files changed

+23
-53
lines changed

.github/ci_commit_pins/xla.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
6a0cb712f6335d6b5996e686ddec4a541e4b6ba5
1+
fba464b199559f61faa720de8bf64cf955cfdce7

docs/source/export.ir_spec.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ Example::
110110
def forward(self, x, y):
111111
return x + y
112112

113-
mod = torch._export.export(MyModule())
113+
mod = torch.export.export(MyModule())
114114
print(mod.graph)
115115

116116
The above is the textual representation of a Graph, with each line being a node.

test/distributed/_tensor/experimental/test_tp_transform.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def test_tp_transform_with_uncovered_op(self):
7272
inputs = (torch.randn(7, 3, requires_grad=False).to(device=self.device_type),)
7373
with torch.no_grad():
7474
res = model(*inputs)
75-
exported_program = torch._export.export(
75+
exported_program = torch.export.export(
7676
model,
7777
inputs,
7878
constraints=None,
@@ -111,7 +111,7 @@ def test_tp_transform_e2e(self):
111111

112112
with torch.inference_mode():
113113
res = model(*inputs)
114-
exported_program = torch._export.export(
114+
exported_program = torch.export.export(
115115
model,
116116
inputs,
117117
constraints=None,
@@ -148,7 +148,7 @@ def test_tp_transform_no_bias(self):
148148

149149
with torch.inference_mode():
150150
res = model(*inputs)
151-
exported_program = torch._export.export(
151+
exported_program = torch.export.export(
152152
model,
153153
inputs,
154154
constraints=None,

test/dynamo/test_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
from torch._dynamo import config
2222
from torch._dynamo.exc import UserError
2323
from torch._dynamo.testing import normalize_gm
24-
from torch._export import dynamic_dim
2524
from torch._higher_order_ops.out_dtype import out_dtype
2625
from torch._subclasses import fake_tensor
26+
from torch.export import dynamic_dim
2727
from torch.fx.experimental.proxy_tensor import make_fx
2828
from torch.fx.experimental.symbolic_shapes import (
2929
ConstraintViolationError,

test/export/test_pass_infra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def forward(self, x1, x2):
124124
input_tensor1 = torch.tensor(5.0)
125125
input_tensor2 = torch.tensor(6.0)
126126

127-
ep_before = torch._export.export(my_module, (input_tensor1, input_tensor2))
127+
ep_before = torch.export.export(my_module, (input_tensor1, input_tensor2))
128128
from torch.fx.passes.infra.pass_base import PassResult
129129

130130
def modify_input_output_pass(gm):

test/quantization/pt2e/test_duplicate_dq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any, Dict
55

66
import torch
7-
import torch._export as export
7+
from torch._export import capture_pre_autograd_graph
88

99
from torch.ao.quantization.observer import (
1010
HistogramObserver,
@@ -102,7 +102,7 @@ def _test_duplicate_dq(
102102

103103
# program capture
104104
m = copy.deepcopy(m_eager)
105-
m = export.capture_pre_autograd_graph(
105+
m = capture_pre_autograd_graph(
106106
m,
107107
example_inputs,
108108
)

test/quantization/pt2e/test_metadata_porting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import List
66

77
import torch
8-
import torch._export as export
8+
import torch._export
99
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
1010
from torch.ao.quantization.quantizer import Quantizer
1111
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
@@ -64,7 +64,7 @@ class TestMetaDataPorting(QuantizationTestCase):
6464
def _test_quant_tag_preservation_through_decomp(
6565
self, model, example_inputs, from_node_to_tags
6666
):
67-
ep = export.export(model, example_inputs)
67+
ep = torch.export.export(model, example_inputs)
6868
found_tags = True
6969
not_found_nodes = ""
7070
for from_node, tag in from_node_to_tags.items():
@@ -102,7 +102,7 @@ def _test_metadata_porting(
102102

103103
# program capture
104104
m = copy.deepcopy(m_eager)
105-
m = export.capture_pre_autograd_graph(
105+
m = torch._export.capture_pre_autograd_graph(
106106
m,
107107
example_inputs,
108108
)

test/test_out_dtype_op.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import torch._dynamo
66
import torch._inductor
77
import torch._inductor.decomposition
8-
import torch._export
98
from torch._higher_order_ops.out_dtype import out_dtype
109
from torch.fx.experimental.proxy_tensor import make_fx
1110
from torch.testing._internal.common_utils import (
@@ -62,7 +61,7 @@ def forward(self, x):
6261
weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
6362
m = M(weight)
6463
x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
65-
ep = torch._export.export(
64+
ep = torch.export.export(
6665
m,
6766
(x,),
6867
)
@@ -121,14 +120,15 @@ def f(x, y):
121120
self.assertTrue(torch.allclose(numerical_res, gm(*inp)))
122121

123122
def test_out_dtype_non_functional(self):
124-
def f(x, y):
125-
return out_dtype(
126-
torch.ops.aten.add_.Tensor, torch.int32, x, y
127-
)
123+
class M(torch.nn.Module):
124+
def forward(self, x, y):
125+
return out_dtype(
126+
torch.ops.aten.add_.Tensor, torch.int32, x, y
127+
)
128128

129129
with self.assertRaisesRegex(ValueError, "out_dtype's first argument needs to be a functional operator"):
130-
_ = torch._export.export(
131-
f, (torch.randint(-128, 127, (5, 5), dtype=torch.int8), torch.randint(-128, 127, (5, 5), dtype=torch.int8)),
130+
_ = torch.export.export(
131+
M(), (torch.randint(-128, 127, (5, 5), dtype=torch.int8), torch.randint(-128, 127, (5, 5), dtype=torch.int8)),
132132
)
133133

134134
def test_out_dtype_non_op_overload(self):

torch/_export/__init__.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -223,36 +223,6 @@ def _eval(self, mode: bool = True):
223223
return module
224224

225225

226-
def export(
227-
f: Callable,
228-
args: Tuple[Any, ...],
229-
kwargs: Optional[Dict[str, Any]] = None,
230-
constraints: Optional[List[Constraint]] = None,
231-
*,
232-
strict: bool = True,
233-
preserve_module_call_signature: Tuple[str, ...] = (),
234-
) -> ExportedProgram:
235-
from torch.export._trace import _export
236-
warnings.warn("This function is deprecated. Please use torch.export.export instead.")
237-
238-
if constraints is not None:
239-
warnings.warn(
240-
"Using `constraints` to specify dynamic shapes for export is DEPRECATED "
241-
"and will not be supported in the future. "
242-
"Please use `dynamic_shapes` instead (see docs on `torch.export.export`).",
243-
DeprecationWarning,
244-
stacklevel=2,
245-
)
246-
return _export(
247-
f,
248-
args,
249-
kwargs,
250-
constraints,
251-
strict=strict,
252-
preserve_module_call_signature=preserve_module_call_signature,
253-
)
254-
255-
256226
def save(
257227
ep: ExportedProgram,
258228
f: Union[str, os.PathLike, io.BytesIO],

torch/_export/serde/upgrade.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Tuple, Dict, Optional, List
44

55
import torch
6-
from torch._export import export
6+
from torch.export import export
77
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
88
from torch._export.pass_infra.node_metadata import NodeMetadata
99
from torch._export.pass_infra.proxy_value import ProxyValue

0 commit comments

Comments (0)