[export] Remove torch._export.export by angelayi · Pull Request #119095 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[export] Remove torch._export.export #119095

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
6a0cb712f6335d6b5996e686ddec4a541e4b6ba5
fba464b199559f61faa720de8bf64cf955cfdce7
2 changes: 1 addition & 1 deletion docs/source/export.ir_spec.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ Example::
def forward(self, x, y):
return x + y

mod = torch._export.export(MyModule())
mod = torch.export.export(MyModule())
print(mod.graph)

The above is the textual representation of a Graph, with each line being a node.
Expand Down
6 changes: 3 additions & 3 deletions test/distributed/_tensor/experimental/test_tp_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_tp_transform_with_uncovered_op(self):
inputs = (torch.randn(7, 3, requires_grad=False).to(device=self.device_type),)
with torch.no_grad():
res = model(*inputs)
exported_program = torch._export.export(
exported_program = torch.export.export(
model,
inputs,
constraints=None,
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_tp_transform_e2e(self):

with torch.inference_mode():
res = model(*inputs)
exported_program = torch._export.export(
exported_program = torch.export.export(
model,
inputs,
constraints=None,
Expand Down Expand Up @@ -148,7 +148,7 @@ def test_tp_transform_no_bias(self):

with torch.inference_mode():
res = model(*inputs)
exported_program = torch._export.export(
exported_program = torch.export.export(
model,
inputs,
constraints=None,
Expand Down
2 changes: 1 addition & 1 deletion test/dynamo/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from torch._dynamo import config
from torch._dynamo.exc import UserError
from torch._dynamo.testing import normalize_gm
from torch._export import dynamic_dim
from torch._higher_order_ops.out_dtype import out_dtype
from torch._subclasses import fake_tensor
from torch.export import dynamic_dim
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
Expand Down
2 changes: 1 addition & 1 deletion test/export/test_pass_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def forward(self, x1, x2):
input_tensor1 = torch.tensor(5.0)
input_tensor2 = torch.tensor(6.0)

ep_before = torch._export.export(my_module, (input_tensor1, input_tensor2))
ep_before = torch.export.export(my_module, (input_tensor1, input_tensor2))
from torch.fx.passes.infra.pass_base import PassResult

def modify_input_output_pass(gm):
Expand Down
4 changes: 2 additions & 2 deletions test/quantization/pt2e/test_duplicate_dq.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Dict

import torch
import torch._export as export
from torch._export import capture_pre_autograd_graph

from torch.ao.quantization.observer import (
HistogramObserver,
Expand Down Expand Up @@ -102,7 +102,7 @@ def _test_duplicate_dq(

# program capture
m = copy.deepcopy(m_eager)
m = export.capture_pre_autograd_graph(
m = capture_pre_autograd_graph(
m,
example_inputs,
)
Expand Down
6 changes: 3 additions & 3 deletions test/quantization/pt2e/test_metadata_porting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import List

import torch
import torch._export as export
import torch._export
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
Expand Down Expand Up @@ -64,7 +64,7 @@ class TestMetaDataPorting(QuantizationTestCase):
def _test_quant_tag_preservation_through_decomp(
self, model, example_inputs, from_node_to_tags
):
ep = export.export(model, example_inputs)
ep = torch.export.export(model, example_inputs)
found_tags = True
not_found_nodes = ""
for from_node, tag in from_node_to_tags.items():
Expand Down Expand Up @@ -102,7 +102,7 @@ def _test_metadata_porting(

# program capture
m = copy.deepcopy(m_eager)
m = export.capture_pre_autograd_graph(
m = torch._export.capture_pre_autograd_graph(
m,
example_inputs,
)
Expand Down
16 changes: 8 additions & 8 deletions test/test_out_dtype_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch._dynamo
import torch._inductor
import torch._inductor.decomposition
import torch._export
from torch._higher_order_ops.out_dtype import out_dtype
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
Expand Down Expand Up @@ -62,7 +61,7 @@ def forward(self, x):
weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
m = M(weight)
x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
ep = torch._export.export(
ep = torch.export.export(
m,
(x,),
)
Expand Down Expand Up @@ -121,14 +120,15 @@ def f(x, y):
self.assertTrue(torch.allclose(numerical_res, gm(*inp)))

def test_out_dtype_non_functional(self):
def f(x, y):
return out_dtype(
torch.ops.aten.add_.Tensor, torch.int32, x, y
)
class M(torch.nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

function no longer supported?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not since #117528

def forward(self, x, y):
return out_dtype(
torch.ops.aten.add_.Tensor, torch.int32, x, y
)

        with self.assertRaisesRegex(ValueError, "out_dtype's first argument needs to be a functional operator"):
_ = torch._export.export(
f, (torch.randint(-128, 127, (5, 5), dtype=torch.int8), torch.randint(-128, 127, (5, 5), dtype=torch.int8)),
_ = torch.export.export(
M(), (torch.randint(-128, 127, (5, 5), dtype=torch.int8), torch.randint(-128, 127, (5, 5), dtype=torch.int8)),
)

def test_out_dtype_non_op_overload(self):
Expand Down
30 changes: 0 additions & 30 deletions torch/_export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,36 +223,6 @@ def _eval(self, mode: bool = True):
return module


def export(
f: Callable,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
constraints: Optional[List[Constraint]] = None,
*,
strict: bool = True,
preserve_module_call_signature: Tuple[str, ...] = (),
) -> ExportedProgram:
from torch.export._trace import _export
warnings.warn("This function is deprecated. Please use torch.export.export instead.")

if constraints is not None:
warnings.warn(
"Using `constraints` to specify dynamic shapes for export is DEPRECATED "
"and will not be supported in the future. "
"Please use `dynamic_shapes` instead (see docs on `torch.export.export`).",
DeprecationWarning,
stacklevel=2,
)
return _export(
f,
args,
kwargs,
constraints,
strict=strict,
preserve_module_call_signature=preserve_module_call_signature,
)


def save(
ep: ExportedProgram,
f: Union[str, os.PathLike, io.BytesIO],
Expand Down
2 changes: 1 addition & 1 deletion torch/_export/serde/upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Tuple, Dict, Optional, List

import torch
from torch._export import export
from torch.export import export
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue
Expand Down
4 changes: 2 additions & 2 deletions torch/_export/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type

import torch

from torch._export import ExportedProgram
from torch._subclasses.fake_tensor import FakeTensor

from torch.export import ExportedProgram
from torch.utils._pytree import (
_register_pytree_node,
Context,
Expand Down
0