10000 [ONNX] Use onnxscript apis for 2.7 (#148453) · pytorch/pytorch@c6a05df · GitHub
[go: up one dir, main page]

Skip to content

Commit c6a05df

Browse files
justinchubypytorchmergebot
authored andcommitted
[ONNX] Use onnxscript apis for 2.7 (#148453)
Use onnxscript apis for 2.7. Remove reference to `torchlib_opset()` and `torchlib_opset_version()` which were removed in the onnxscript 2.7 apis. These apis were removed because torchlib in onnxscript will always stay on opset 18. Future opset version bumps will happen in pytorch core after the migration of torchlib. Pull Request resolved: #148453 Approved by: https://github.com/titaiwangms, https://github.com/shubhambhokare1
1 parent c9edd37 commit c6a05df

File tree

5 files changed

+11
-12
lines changed

5 files changed

+11
-12
lines changed

torch/onnx/_internal/_exporter_legacy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from torch.onnx._internal import io_adapter
3131
from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
3232
from torch.onnx._internal.diagnostics import infra
33-
from torch.onnx._internal.exporter import _onnx_program
33+
from torch.onnx._internal.exporter import _constants, _onnx_program
3434
from torch.onnx._internal.fx import (
3535
decomposition_table,
3636
patcher as patcher,
@@ -105,7 +105,7 @@ def __init__(self) -> None:
105105
defaultdict(list)
106106
)
107107

108-
self._opset_version = onnxscript_apis.torchlib_opset_version()
108+
self._opset_version = _constants.TORCHLIB_OPSET
109109
warnings.warn(
110110
f"torch.onnx.dynamo_export only implements opset version {self._opset_version} for now. If you need to use a "
111111
"different opset version, please register them with register_custom_op."

torch/onnx/_internal/_lazy_import.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ def __getattr__(self, attr: str) -> object:
2929
if TYPE_CHECKING:
3030
import onnx
3131
import onnxscript
32-
import onnxscript._framework_apis.torch_2_6 as onnxscript_apis
32+
import onnxscript._framework_apis.torch_2_7 as onnxscript_apis
3333

3434
onnxscript_ir = onnxscript.ir
3535

3636
else:
3737
onnx = _LazyModule("onnx")
3838
onnxscript = _LazyModule("onnxscript")
3939
onnxscript_ir = _LazyModule("onnxscript.ir")
40-
onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_6")
40+
onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_7")

torch/onnx/_internal/exporter/_compat.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
1414
from torch.onnx._internal.exporter import (
15+
_constants,
1516
_core,
1617
_dynamic_shapes,
1718
_onnx_program,
@@ -67,7 +68,7 @@ def export_compat(
6768
fallback: bool = False,
6869
) -> _onnx_program.ONNXProgram:
6970
if opset_version is None:
70-
opset_version = onnxscript_apis.torchlib_opset_version()
71+
opset_version = _constants.TORCHLIB_OPSET
7172

7273
if isinstance(model, torch.export.ExportedProgram):
7374
# We know the model is already exported program, so the args, kwargs, and dynamic_shapes

torch/onnx/_internal/exporter/_registration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import torch
2525
import torch._ops
2626
from torch.onnx._internal._lazy_import import onnxscript, onnxscript_apis
27-
from torch.onnx._internal.exporter import _schemas
27+
from torch.onnx._internal.exporter import _constants, _schemas
2828
from torch.onnx._internal.exporter._torchlib import _torchlib_registry
2929

3030

@@ -141,7 +141,7 @@ class ONNXRegistry:
141141

142142
def __init__(self) -> None:
143143
"""Initializes the registry"""
144-
self._opset_version = onnxscript_apis.torchlib_opset_version()
144+
self._opset_version = _constants.TORCHLIB_OPSET
145145
self.functions: dict[TorchOp | str, list[OnnxDecompMeta]] = {}
146146

147147
@property

torch/onnx/_internal/fx/fx_onnx_interpreter.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import torch
1616
import torch.fx
1717
from torch.onnx import _type_utils as jit_type_utils
18-
from torch.onnx._internal._lazy_import import onnxscript_apis
1918
from torch.onnx._internal.fx import (
2019
_pass,
2120
diagnostics,
@@ -101,6 +100,7 @@ def _retrieve_or_adapt_input_to_graph_set(
101100
When creating TorchScript graph from FX graph, we need a mapping from FX variable
102101
to TorchScript variable. This function maps FX variable, fx_node_arg, to torch.jit.Value.
103102
"""
103+
from onnxscript import opset18 as op
104104

105105
onnx_tensor = fx_node_arg
106106
if isinstance(onnx_tensor, torch.fx.Node):
@@ -149,7 +149,7 @@ def _retrieve_or_adapt_input_to_graph_set(
149149
# Since tensors with rank=0 (i.e., scalar) cannot be concated, all
150150
# scalars are promoted to tensors with shape (1,).
151151
with onnxscript.evaluator.default_as(tracer):
152-
element_value = onnxscript_apis.torchlib_opset().Reshape(
152+
element_value = op.Reshape(
153153
element_value, # type: ignore[arg-type, type-var]
154154
[1], # type: ignore[arg-type, type-var]
155155
)
@@ -173,9 +173,7 @@ def _retrieve_or_adapt_input_to_graph_set(
173173
# onnx-script auto wraps python number with op.Constants,
174174
# so we don't need to specifically process them.
175175
with onnxscript.evaluator.default_as(tracer):
176-
output = onnxscript_apis.torchlib_opset().Concat(
177-
*sequence_mixed_elements, axis=0
178-
) # type: ignore[type-var]
176+
output = op.Concat(*sequence_mixed_elements, axis=0) # type: ignore[type-var]
179177
output.dtype = torch.int64 # type: ignore[union-attr]
180178
output.shape = [len(sequence_mixed_elements)] # type: ignore[union-attr]
181179
return output

0 commit comments

Comments
 (0)
0