8000 [ONNX] Refactor dispatcher and registry (#147396) · pytorch/pytorch@c46f86e · GitHub
[go: up one dir, main page]

Skip to content

Commit c46f86e

Browse files
justinchubyRyo-not-rio
authored andcommitted
[ONNX] Refactor dispatcher and registry (#147396)
This PR sets up the registry to accept onnx decomp functions to be moved into PyTorch (#139301). The ops from onnx script are currently appended to the registry. When the ops are moved into PyTorch, the moved ops takes precedence because they appear first in the registry list. After the migration hooks for loading ops from onnx script will be removed. 1. Use a private field `_pt_onnx_signature` to store function signatures to avoid conflicts 2. Update the registry to record the signature in OnnxDecompMeta and update the dispatcher to leverage the data structure 3. Update registry to prepare for onnx op registration, and update the the onnx_impl decorator to support a no_compile option Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Pull Request resolved: #147396 Approved by: https://github.com/titaiwangms
1 parent 00bb540 commit c46f86e

File tree

6 files changed

+144
-91
lines changed

6 files changed

+144
-91
lines changed

torch/onnx/_internal/exporter/_dispatching.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import logging
55
from collections.abc import Sequence
6-
from typing import Callable
6+
from typing import Any, Callable
77

88
from onnxscript import ir
99

@@ -188,11 +188,11 @@ def _get_type_from_tensor(
188188

189189

190190
def _get_first_tensor_in_node_list(
191-
nodes: Sequence[torch.fx.Node | None],
191+
nodes: Sequence[torch.fx.Node | Any],
192192
) -> torch.Tensor | None:
193193
for node in nodes:
194194
if (
195-
node is not None
195+
isinstance(node, torch.fx.Node)
196196
and "val" in node.meta
197197
and isinstance(node.meta["val"], torch.Tensor)
198198
):
@@ -213,13 +213,13 @@ def _get_named_fx_node_args(node: torch.fx.Node) -> dict[str, torch.fx.node.Argu
213213

214214
def get_matching_overload(
215215
node: torch.fx.Node,
216-
overloads: Sequence[Callable],
216+
overloads: Sequence[_registration.OnnxDecompMeta],
217217
) -> tuple[Callable | None, str]:
218218
"""Get the overload that matches the node's arguments.
219219
220220
Args:
221221
node: The node to match.
222-
overloads: The overloads to match against.
222+
overloads: The OnnxDecompMeta with overloads and their signatures to match against.
223223
224224
Returns:
225225
A tuple containing the matched overload and a string describing the reason for failure or success.
@@ -230,7 +230,7 @@ def get_matching_overload(
230230
# now we assume all inputs are named.
231231
return overloads[
232232
0
233-
], "The node target does not have a schema. Return the first one."
233+
].onnx_function, "The node target does not have a schema. Return the first one."
234234
named_args = _get_named_fx_node_args(node)
235235
# FIXME: Handle when we don't know the names of the arguments
236236
schema_args: dict[str, torch.Argument] = {
@@ -241,10 +241,10 @@ def get_matching_overload(
241241
for overload in overloads:
242242
assigned_types: dict[str, ir.TypeProtocol] = {}
243243
fail_reason = ""
244-
if not hasattr(overload, "signature"):
244+
if overload.signature is None:
245245
# When an overload does not have a signature, we assume it is a custom op and should be matched
246246
return (
247-
overload,
247+
overload.onnx_function,
248248
"The overload does not have a signature. Assuming it is a custom op and matching it.",
249249
)
250250
for param in overload.signature:
@@ -266,7 +266,7 @@ def get_matching_overload(
266266
arg = schema_args[param.name].default_value
267267
elif param.has_default():
268268
# Provided in the ONNX op definition
269-
arg = param.default
269+
arg = param.default # type: ignore[assignment]
270270
else:
271271
fail_reason = "Parameter not provided"
272272
break
@@ -297,8 +297,10 @@ def get_matching_overload(
297297
if not _attribute_type_compatible_with_arg(param, arg): # type: ignore[arg-type]
298298
fail_reason = f"Attribute type not compatible with argument: param=`{param}`, arg=`{arg}`"
299299
break
300+
else:
301+
raise TypeError(f"Unknown parameter type: {type(param)}")
300302
if not fail_reason:
301-
return overload, "Successfully matched overload"
303+
return overload.onnx_function, "Successfully matched overload"
302304
else:
303305
failure_messages.append(
304306
f"- Failed to match overload `{overload}`: {fail_reason}"
@@ -357,7 +359,5 @@ def dispatch(
357359
"Fast path: Only one decomposition is defined",
358360
)
359361

360-
overload, message = get_matching_overload(
361-
node, [decomp.onnx_function for decomp in decomp_metas]
362-
)
362+
overload, message = get_matching_overload(node, decomp_metas)
363363
return overload, message

torch/onnx/_internal/exporter/_ir_passes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import re
66
from typing import TYPE_CHECKING
77

8-
from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
8+
from torch.onnx._internal._lazy_import import onnxscript_ir as ir
9+
from torch.onnx._internal.exporter import _constants
910

1011

1112
if TYPE_CHECKING:
@@ -115,8 +116,7 @@ def _maybe_set_opset_version(
115116
# Already set
116117
return
117118
if domain == _ONNX_DOMAIN:
118-
# Set the default opset version for ONNX operators
119-
opset_imports[domain] = onnxscript_apis.torchlib_opset_version()
119+
opset_imports[domain] = _constants.TORCHLIB_OPSET
120120
return
121121
if version is None:
122122
# We don't know the opset version, so set it to 1

torch/onnx/_internal/exporter/_registration.py

Lines changed: 63 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,59 @@
3333
logger = logging.getLogger(__name__)
3434

3535

36-
@dataclasses.dataclass(frozen=True)
36+
@dataclasses.dataclass
3737
class OnnxDecompMeta:
3838
"""A wrapper of onnx-script function with additional metadata.
3939
4040
onnx_function: The onnx-script function from torchlib.
4141
fx_target: The PyTorch node callable target.
42+
signature: The ONNX signature of the function. When None, the signature is inferred.
4243
is_custom: Whether the function is a custom function.
4344
is_complex: Whether the function is a function that handles complex valued inputs.
4445
device: The device the function is registered to. If None, it is registered to all devices.
46+
skip_signature_inference: Whether to skip signature inference for the function.
4547
"""
4648

4749
onnx_function: Callable
4850
fx_target: TorchOp
51+
signature: _schemas.OpSignature | None
4952
is_custom: bool = False
5053
is_complex: bool = False
5154
device: Literal["cuda", "cpu"] | str | None = None # noqa: PYI051
55+
skip_signature_inference: bool = False
56+
57+
def __post_init__(self) -> None:
58+
if self.signature is None and not self.skip_signature_inference:
59+
try:
60+
if isinstance(self.onnx_function, onnxscript.OnnxFunction):
61+
signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
62+
self.onnx_function,
63+
self.onnx_function.function_ir.domain,
64+
self.onnx_function.name,
65+
opset_version=self.onnx_function.opset.version,
66+
)
67+
else:
68+
signature = _schemas.OpSignature.from_function(
69+
self.onnx_function, "__traced", self.onnx_function.__name__
70+
)
71+
except Exception as e:
72+
# Log an warning if the op is custom. Raise exception for builtin ops.
73+
if not self.is_custom:
74+
raise
75+
else:
76+
# When the function is targeting an HOP, for example, it will accept
77+
# functions as arguments and fail to generate an ONNX signature.
78+
# In this case we set signature to None and dispatch to this function always.
79+
logger.warning(
80+
"Failed to infer the signature for function '%s' because '%s'"
81+
"All nodes targeting `%s` will be dispatched to this function",
82+
self.onnx_function,
83+
e,
84+
self.fx_target,
85+
)
86+
else:
87+
self.signature = signature
88+
self.onnx_function._pt_onnx_signature = signature # type: ignore[attr-defined]
5289

5390

5491
def _get_overload(qualified_name: str) -> torch._ops.OpOverload | None:
@@ -120,57 +157,34 @@ def from_torchlib(cls) -> ONNXRegistry:
120157
torchlib_registry: The torchlib registry to use for populating the registry.
121158
"""
122159
registry = cls()
160+
for meta in _torchlib_registry.get_torchlib_ops():
161+
registry._register(meta.fx_target, meta)
123162

163+
# TODO(justinchuby): Remove this once torchlib is migrated to PyTorch
124164
torchlib_ops = onnxscript_apis.get_torchlib_ops()
125165

126-
for meta in torchlib_ops:
127-
qualified_name = meta.qualified_name
128-
overload_func = meta.function
129-
domain = meta.domain
130-
name = meta.name
166+
for torchlib_meta in torchlib_ops:
167+
qualified_name = torchlib_meta.qualified_name
168+
overload_func = torchlib_meta.function
131169
try:
132170
# NOTE: This is heavily guarded with try-except because we don't want
133171
# to fail the entire registry population if one function fails.
134172
target = _get_overload(qualified_name)
135173
if target is None:
136174
continue
137175

138-
if isinstance(overload_func, onnxscript.OnnxFunction):
139-
opset_version = overload_func.opset.version
140-
else:
141-
opset_version = 1
142-
143-
overload_func.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
144-
overload_func,
145-
domain,
146-
name,
147-
opset_version=opset_version,
148-
)
149-
onnx_decomposition = OnnxDecompMeta(
176+
meta = OnnxDecompMeta(
150177
onnx_function=overload_func,
151178
fx_target=target,
179+
signature=None,
152180
is_custom=False,
153-
is_complex=meta.is_complex,
181+
is_complex=torchlib_meta.is_complex,
154182
)
155-
registry._register(target, onnx_decomposition)
183+
registry._register(target, meta)
156184
except Exception:
157185
logger.exception("Failed to register '%s'. Skipped", qualified_name)
158186
continue
159187

160-
# Gather ops from the internal torchlib registry
161-
# TODO(justinchuby): Make this the main registry after torchlib is migrated to PyTorch
162-
# Trigger registration
163-
from torch.onnx._internal.exporter._torchlib import ops
164-
165-
del ops
166-
for target, implementations in _torchlib_registry.registry.items(): # type: ignore[assignment]
167-
for impl in implementations:
168-
onnx_decomposition = OnnxDecompMeta(
169-
onnx_function=impl,
170-
fx_target=target, # type: ignore[arg-type]
171-
)
172-
registry._register(target, onnx_decomposition) # type: ignore[arg-type]
173-
174188
return registry
175189

176190
def _register(
@@ -209,32 +223,23 @@ def register_op(
209223
function: The onnx-script function to register.
210224
is_complex: Whether the function is a function that handles complex valued inputs.
211225
"""
212-
if not hasattr(function, "signature"):
213-
try:
214-
# TODO(justinchuby): Use the op_signature attribute when onnxscript is updated in CI
215-
if isinstance(function, onnxscript.OnnxFunction):
216-
function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
217-
function,
218-
function.function_ir.domain,
219-
function.name,
220-
opset_version=function.opset.version,
221-
)
222-
else:
223-
function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
224-
function, "__custom", function.__name__
225-
)
226-
except Exception:
227-
logger.exception(
228-
"Failed to infer the signature for function '%s'", function
229-
)
226+
if isinstance(target, torch._ops.OpOverloadPacket):
227+
raise TypeError(
228+
f"Target '{target}' should be provided as an OpOverload instead of an "
229+
"OpOverloadPacket. You can get the default overload with "
230+
"<op>.default"
231+
)
230232

231-
onnx_decomposition = OnnxDecompMeta(
232-
onnx_function=function,
233-
fx_target=target,
234-
is_custom=True,
235-
is_complex=is_complex,
233+
self._register(
234+
target,
235+
OnnxDecompMeta(
236+
onnx_function=function,
237+
fx_target=target,
238+
signature=None,
239+
is_custom=True,
240+
is_complex=is_complex,
241+
),
236242
)
237-
self._register(target, onnx_decomposition)
238243

239244
def get_decomps(self, target: TorchOp) -> list[OnnxDecompMeta]:
240245
"""Returns a list of OnnxDecompMeta for the given op: torch.ops.<namespace>.<op_name>.<overload>.
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +0,0 @@
1-

torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,86 @@
55
from __future__ import annotations
66

77

8-
__all__ = ["registry", "onnx_impl"]
8+
__all__ = ["onnx_impl", "get_torchlib_ops"]
99

10-
import collections
11-
from typing import Callable, TypeVar
10+
import logging
11+
from typing import Any, Callable, Sequence, TypeVar
1212

13+
import onnxscript
1314

14-
_T = TypeVar("_T", bound=Callable)
15-
15+
import torch
16+
from torch.onnx._internal.exporter import _constants, _registration
1617

17-
class Registry(collections.UserDict[Callable, list[Callable]]):
18-
"""Registry for aten functions."""
1918

20-
def register(self, target: Callable, impl: Callable) -> None:
21-
"""Register a function."""
19+
_T = TypeVar("_T", bound=Callable)
2220

23-
self.data.setdefault(target, []).append(impl)
21+
logger = logging.getLogger("__name__")
2422

2523

26-
# Default registry
27-
registry = Registry()
24+
_registry: list[_registration.OnnxDecompMeta] = []
2825

2926

3027
def onnx_impl(
31-
target: Callable,
28+
target: _registration.TorchOp | tuple[_registration.TorchOp, ...],
29+
*,
30+
trace_only: bool = False,
31+
complex: bool = False,
32+
no_compile: bool = False,
33+
private: bool = False,
3234
) -> Callable[[_T], _T]:
3335
"""Register an ONNX implementation of a torch op."""
3436

37+
if isinstance(target, torch._ops.OpOverloadPacket):
38+
raise TypeError(
39+
f"Target '{target}' should be provided as an OpOverload instead of an "
40+
"OpOverloadPacket. You can get the default overload with "
41+
"<op>.default"
42+
)
43+
3544
def wrapper(
3645
func: _T,
3746
) -> _T:
38-
registry.register(target, func)
39-
return func
47+
processed_func: Any
48+
if no_compile:
49+
processed_func = func
50+
else:
51+
torchlib_opset = onnxscript.values.Opset(
52+
domain=_constants.TORCHLIB_DOMAIN, version=1
53+
)
54+
55+
if not trace_only:
56+
# Compile the function
57+
processed_func = onnxscript.script(opset=torchlib_opset)(func)
58+
else:
59+
processed_func = onnxscript.TracedOnnxFunction(torchlib_opset, func)
60+
61+
if not private:
62+
# TODO(justinchuby): Simplify the logic and remove the private attribute
63+
# Skip registration if private
64+
if not isinstance(target, Sequence):
65+
targets = (target,)
66+
else:
67+
targets = target # type: ignore[assignment]
68+
69+
for t in targets:
70+
_registry.append(
71+
_registration.OnnxDecompMeta(
72+
onnx_function=processed_func,
73+
fx_target=t,
74+
signature=None,
75+
is_complex=complex,
76+
skip_signature_inference=no_compile,
77+
)
78+
)
79+
return processed_func # type: ignore[return-value]
4080

4181
return wrapper
82+
83+
84+
def get_torchlib_ops() -> tuple[_registration.OnnxDecompMeta, ...]:
85+
# Trigger op registration
86+
from torch.onnx._internal.exporter._torchlib import ops
87+
88+
del ops
89+
assert len(_registry) != 0
90+
return tuple(_registry)

0 commit comments

Comments
 (0)
0