[ONNX] Use onnxscript apis for 2.7 by justinchuby · Pull Request #148453 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[ONNX] Use onnxscript apis for 2.7 #148453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions torch/onnx/_internal/_exporter_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from torch.onnx._internal import io_adapter
from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.exporter import _onnx_program
from torch.onnx._internal.exporter import _constants, _onnx_program
from torch.onnx._internal.fx import (
decomposition_table,
patcher as patcher,
Expand Down Expand Up @@ -105,7 +105,7 @@ def __init__(self) -> None:
defaultdict(list)
)

self._opset_version = onnxscript_apis.torchlib_opset_version()
self._opset_version = _constants.TORCHLIB_OPSET
warnings.warn(
f"torch.onnx.dynamo_export only implements opset version {self._opset_version} for now. If you need to use a "
"different opset version, please register them with register_custom_op."
Expand Down
4 changes: 2 additions & 2 deletions torch/onnx/_internal/_lazy_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ def __getattr__(self, attr: str) -> object:
if TYPE_CHECKING:
import onnx
import onnxscript
import onnxscript._framework_apis.torch_2_6 as onnxscript_apis
import onnxscript._framework_apis.torch_2_7 as onnxscript_apis

onnxscript_ir = onnxscript.ir

else:
onnx = _LazyModule("onnx")
onnxscript = _LazyModule("onnxscript")
onnxscript_ir = _LazyModule("onnxscript.ir")
onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_6")
onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_7")
3 changes: 2 additions & 1 deletion torch/onnx/_internal/exporter/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch
from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
from torch.onnx._internal.exporter import (
_constants,
_core,
_dynamic_shapes,
_onnx_program,
Expand Down Expand Up @@ -67,7 +68,7 @@ def export_compat(
fallback: bool = False,
) -> _onnx_program.ONNXProgram:
if opset_version is None:
opset_version = onnxscript_apis.torchlib_opset_version()
opset_version = _constants.TORCHLIB_OPSET

if isinstance(model, torch.export.ExportedProgram):
# We know the model is already exported program, so the args, kwargs, and dynamic_shapes
Expand Down
4 changes: 2 additions & 2 deletions torch/onnx/_internal/exporter/_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch
import torch._ops
from torch.onnx._internal._lazy_import import onnxscript, onnxscript_apis
from torch.onnx._internal.exporter import _schemas
from torch.onnx._internal.exporter import _constants, _schemas
from torch.onnx._internal.exporter._torchlib import _torchlib_registry


Expand Down Expand Up @@ -141,7 +141,7 @@ class ONNXRegistry:

def __init__(self) -> None:
"""Initializes the registry"""
self._opset_version = onnxscript_apis.torchlib_opset_version()
self._opset_version = _constants.TORCHLIB_OPSET
self.functions: dict[TorchOp | str, list[OnnxDecompMeta]] = {}

@property
Expand Down
8 changes: 3 additions & 5 deletions torch/onnx/_internal/fx/fx_onnx_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import torch
import torch.fx
from torch.onnx import _type_utils as jit_type_utils
from torch.onnx._internal._lazy_import import onnxscript_apis
from torch.onnx._internal.fx import (
_pass,
diagnostics,
Expand Down Expand Up @@ -101,6 +100,7 @@ def _retrieve_or_adapt_input_to_graph_set(
When creating TorchScript graph from FX graph, we need a mapping from FX variable
to TorchScript variable. This function maps FX variable, fx_node_arg, to torch.jit.Value.
"""
from onnxscript import opset18 as op

onnx_tensor = fx_node_arg
if isinstance(onnx_tensor, torch.fx.Node):
Expand Down Expand Up @@ -149,7 +149,7 @@ def _retrieve_or_adapt_input_to_graph_set(
# Since tensors with rank=0 (i.e., scalar) cannot be concated, all
# scalars are promoted to tensors with shape (1,).
with onnxscript.evaluator.default_as(tracer):
element_value = onnxscript_apis.torchlib_opset().Reshape(
element_value = op.Reshape(
element_value, # type: ignore[arg-type, type-var]
[1], # type: ignore[arg-type, type-var]
)
Expand All @@ -173,9 +173,7 @@ def _retrieve_or_adapt_input_to_graph_set(
# onnx-script auto wraps python number with op.Constants,
# so we don't need to specifically process them.
with onnxscript.evaluator.default_as(tracer):
output = onnxscript_apis.torchlib_opset().Concat(
*sequence_mixed_elements, axis=0
) # type: ignore[type-var]
output = op.Concat(*sequence_mixed_elements, axis=0) # type: ignore[type-var]
output.dtype = torch.int64 # type: ignore[union-attr]
output.shape = [len(sequence_mixed_elements)] # type: ignore[union-attr]
return output
Expand Down
Loading