[ONNX] Refactor dispatcher and registry by justinchuby · Pull Request #147396 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[ONNX] Refactor dispatcher and registry #147396

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions torch/onnx/_internal/exporter/_dispatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import logging
from collections.abc import Sequence
from typing import Callable
from typing import Any, Callable

from onnxscript import ir

Expand Down Expand Up @@ -188,11 +188,11 @@ def _get_type_from_tensor(


def _get_first_tensor_in_node_list(
nodes: Sequence[torch.fx.Node | None],
nodes: Sequence[torch.fx.Node | Any],
) -> torch.Tensor | None:
for node in nodes:
if (
node is not None
isinstance(node, torch.fx.Node)
and "val" in node.meta
and isinstance(node.meta["val"], torch.Tensor)
):
Expand All @@ -213,13 +213,13 @@ def _get_named_fx_node_args(node: torch.fx.Node) -> dict[str, torch.fx.node.Argu

def get_matching_overload(
node: torch.fx.Node,
overloads: Sequence[Callable],
overloads: Sequence[_registration.OnnxDecompMeta],
) -> tuple[Callable | None, str]:
"""Get the overload that matches the node's arguments.

Args:
node: The node to match.
overloads: The overloads to match against.
overloads: The OnnxDecompMeta with overloads and their signatures to match against.

Returns:
A tuple containing the matched overload and a string describing the reason for failure or success.
Expand All @@ -230,7 +230,7 @@ def get_matching_overload(
# now we assume all inputs are named.
return overloads[
0
], "The node target does not have a schema. Return the first one."
].onnx_function, "The node target does not have a schema. Return the first one."
named_args = _get_named_fx_node_args(node)
# FIXME: Handle when we don't know the names of the arguments
schema_args: dict[str, torch.Argument] = {
Expand All @@ -241,10 +241,10 @@ def get_matching_overload(
for overload in overloads:
assigned_types: dict[str, ir.TypeProtocol] = {}
fail_reason = ""
if not hasattr(overload, "signature"):
if overload.signature is None:
# When an overload does not have a signature, we assume it is a custom op and should be matched
return (
overload,
overload.onnx_function,
"The overload does not have a signature. Assuming it is a custom op and matching it.",
)
for param in overload.signature:
Expand All @@ -266,7 +266,7 @@ def get_matching_overload(
arg = schema_args[param.name].default_value
elif param.has_default():
# Provided in the ONNX op definition
arg = param.default
arg = param.default # type: ignore[assignment]
else:
fail_reason = "Parameter not provided"
break
Expand Down Expand Up @@ -297,8 +297,10 @@ def get_matching_overload(
if not _attribute_type_compatible_with_arg(param, arg): # type: ignore[arg-type]
fail_reason = f"Attribute type not compatible with argument: param=`{param}`, arg=`{arg}`"
break
else:
raise TypeError(f"Unknown parameter type: {type(param)}")
if not fail_reason:
return overload, "Successfully matched overload"
return overload.onnx_function, "Successfully matched overload"
else:
failure_messages.append(
f"- Failed to match overload `{overload}`: {fail_reason}"
Expand Down Expand Up @@ -357,7 +359,5 @@ def dispatch(
"Fast path: Only one decomposition is defined",
)

overload, message = get_matching_overload(
node, [decomp.onnx_function for decomp in decomp_metas]
)
overload, message = get_matching_overload(node, decomp_metas)
return overload, message
6 changes: 3 additions & 3 deletions torch/onnx/_internal/exporter/_ir_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import re
from typing import TYPE_CHECKING

from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
from torch.onnx._internal._lazy_import import onnxscript_ir as ir
from torch.onnx._internal.exporter import _constants


if TYPE_CHECKING:
Expand Down Expand Up @@ -115,8 +116,7 @@ def _maybe_set_opset_version(
# Already set
return
if domain == _ONNX_DOMAIN:
# Set the default opset version for ONNX operators
opset_imports[domain] = onnxscript_apis.torchlib_opset_version()
opset_imports[domain] = _constants.TORCHLIB_OPSET
return
if version is None:
# We don't know the opset version, so set it to 1
Expand Down
121 changes: 63 additions & 58 deletions torch/onnx/_internal/exporter/_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,59 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass
class OnnxDecompMeta:
    """A wrapper of onnx-script function with additional metadata.

    Attributes:
        onnx_function: The onnx-script function from torchlib.
        fx_target: The PyTorch node callable target.
        signature: The ONNX signature of the function. When None, the signature is inferred.
        is_custom: Whether the function is a custom function.
        is_complex: Whether the function is a function that handles complex valued inputs.
        device: The device the function is registered to. If None, it is registered to all devices.
        skip_signature_inference: Whether to skip signature inference for the function.
    """

    onnx_function: Callable
    fx_target: TorchOp
    signature: _schemas.OpSignature | None
    is_custom: bool = False
    is_complex: bool = False
    device: Literal["cuda", "cpu"] | str | None = None  # noqa: PYI051
    skip_signature_inference: bool = False

    def __post_init__(self) -> None:
        """Infer ``signature`` from ``onnx_function`` when it was not supplied.

        Nothing happens when a signature was passed in or when
        ``skip_signature_inference`` is set. Otherwise the signature is built
        with ``_schemas.OpSignature.from_function`` and stored both on this
        instance and on the function object as ``_pt_onnx_signature``.

        Raises:
            Exception: Whatever signature inference raised, re-raised unchanged
                when ``is_custom`` is False. For custom functions the failure
                is logged as a warning and ``signature`` stays None.
        """
        if self.signature is not None or self.skip_signature_inference:
            return

        try:
            if isinstance(self.onnx_function, onnxscript.OnnxFunction):
                # Compiled onnxscript functions carry their own domain, name and opset.
                signature = _schemas.OpSignature.from_function(  # type: ignore[attr-defined]
                    self.onnx_function,
                    self.onnx_function.function_ir.domain,
                    self.onnx_function.name,
                    opset_version=self.onnx_function.opset.version,
                )
            else:
                # Any other callable is filed under the "__traced" domain
                # using its Python name.
                signature = _schemas.OpSignature.from_function(
                    self.onnx_function, "__traced", self.onnx_function.__name__
                )
        except Exception as e:
            # Log a warning if the op is custom. Raise exception for builtin ops.
            if not self.is_custom:
                raise
            # When the function is targeting an HOP, for example, it will accept
            # functions as arguments and fail to generate an ONNX signature.
            # In this case we set signature to None and dispatch to this function always.
            logger.warning(
                "Failed to infer the signature for function '%s' because '%s'. "
                "All nodes targeting `%s` will be dispatched to this function",
                self.onnx_function,
                e,
                self.fx_target,
            )
        else:
            self.signature = signature
            # Also stash the signature on the function object itself.
            # NOTE(review): presumably so code holding only the function can
            # read it back — confirm against the consumers of this attribute.
            self.onnx_function._pt_onnx_signature = signature  # type: ignore[attr-defined]


def _get_overload(qualified_name: str) -> torch._ops.OpOverload | None:
Expand Down Expand Up @@ -120,57 +157,34 @@ def from_torchlib(cls) -> ONNXRegistry:
torchlib_registry: The torchlib registry to use for populating the registry.
"""
registry = cls()
for meta in _torchlib_registry.get_torchlib_ops():
registry._register(meta.fx_target, meta)

# TODO(justinchuby): Remove this once torchlib is migrated to PyTorch
torchlib_ops = onnxscript_apis.get_torchlib_ops()

for meta in torchlib_ops:
qualified_name = meta.qualified_name
overload_func = meta.function
domain = meta.domain
name = meta.name
for torchlib_meta in torchlib_ops:
qualified_name = torchlib_meta.qualified_name
overload_func = torchlib_meta.function
try:
# NOTE: This is heavily guarded with try-except because we don't want
# to fail the entire registry population if one function fails.
target = _get_overload(qualified_name)
if target is None:
continue

if isinstance(overload_func, onnxscript.OnnxFunction):
opset_version = overload_func.opset.version
else:
opset_version = 1

overload_func.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
overload_func,
domain,
name,
opset_version=opset_version,
)
onnx_decomposition = OnnxDecompMeta(
meta = OnnxDecompMeta(
onnx_function=overload_func,
fx_target=target,
signature=None,
is_custom=False,
is_complex=meta.is_complex,
is_complex=torchlib_meta.is_complex,
)
registry._register(target, onnx_decomposition)
registry._register(target, meta)
except Exception:
logger.exception("Failed to register '%s'. Skipped", qualified_name)
continue

# Gather ops from the internal torchlib registry
# TODO(justinchuby): Make this the main registry after torchlib is migrated to PyTorch
# Trigger registration
from torch.onnx._internal.exporter._torchlib import ops

del ops
for target, implementations in _torchlib_registry.registry.items(): # type: ignore[assignment]
for impl in implementations:
onnx_decomposition = OnnxDecompMeta(
onnx_function=impl,
fx_target=target, # type: ignore[arg-type]
)
registry._register(target, onnx_decomposition) # type: ignore[arg-type]

return registry

def _register(
Expand Down Expand Up @@ -209,32 +223,23 @@ def register_op(
function: The onnx-script function to register.
is_complex: Whether the function is a function that handles complex valued inputs.
"""
if not hasattr(function, "signature"):
try:
# TODO(justinchuby): Use the op_signature attribute when onnxscript is updated in CI
if isinstance(function, onnxscript.OnnxFunction):
function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
function,
function.function_ir.domain,
function.name,
opset_version=function.opset.version,
)
else:
function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
function, "__custom", function.__name__
)
except Exception:
logger.exception(
"Failed to infer the signature for function '%s'", function
)
if isinstance(target, torch._ops.OpOverloadPacket):
raise TypeError(
f"Target '{target}' should be provided as an OpOverload instead of an "
"OpOverloadPacket. You can get the default overload with "
"<op>.default"
)

onnx_decomposition = OnnxDecompMeta(
onnx_function=function,
fx_target=target,
is_custom=True,
is_complex=is_complex,
self._register(
target,
OnnxDecompMeta(
onnx_function=function,
fx_target=target,
signature=None,
is_custom=True,
is_complex=is_complex,
),
)
self._register(target, onnx_decomposition)

def get_decomps(self, target: TorchOp) -> list[OnnxDecompMeta]:
"""Returns a list of OnnxDecompMeta for the given op: torch.ops.<namespace>.<op_name>.<overload>.
Expand Down
1 change: 0 additions & 1 deletion torch/onnx/_internal/exporter/_torchlib/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@

79 changes: 64 additions & 15 deletions torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,86 @@
from __future__ import annotations


__all__ = ["registry", "onnx_impl"]
__all__ = ["onnx_impl", "get_torchlib_ops"]

import collections
from typing import Callable, TypeVar
import logging
from typing import Any, Callable, Sequence, TypeVar

import onnxscript

_T = TypeVar("_T", bound=Callable)

import torch
from torch.onnx._internal.exporter import _constants, _registration

class Registry(collections.UserDict[Callable, list[Callable]]):
"""Registry for aten functions."""

def register(self, target: Callable, impl: Callable) -> None:
"""Register a function."""
_T = TypeVar("_T", bound=Callable)

self.data.setdefault(target, []).append(impl)
# Pass the module's ``__name__`` (not the literal string "__name__") so the
# logger is named after this module.
logger = logging.getLogger(__name__)


# Default registry
registry = Registry()
_registry: list[_registration.OnnxDecompMeta] = []


def onnx_impl(
target: Callable,
target: _registration.TorchOp | tuple[_registration.TorchOp, ...],
*,
trace_only: bool = False,
complex: bool = False,
no_compile: bool = False,
private: bool = False,
) -> Callable[[_T], _T]:
"""Register an ONNX implementation of a torch op."""

if isinstance(target, torch._ops.OpOverloadPacket):
raise TypeError(
f"Target '{target}' should be provided as an OpOverload instead of an "
"OpOverloadPacket. You can get the default overload with "
"<op>.default"
)

def wrapper(
func: _T,
) -> _T:
registry.register(target, func)
return func
processed_func: Any
if no_compile:
processed_func = func
else:
torchlib_opset = onnxscript.values.Opset(
domain=_constants.TORCHLIB_DOMAIN, version=1
)

if not trace_only:
# Compile the function
processed_func = onnxscript.script(opset=torchlib_opset)(func)
else:
processed_func = onnxscript.TracedOnnxFunction(torchlib_opset, func)

if not private:
# TODO(justinchuby): Simplify the logic and remove the private attribute
# Skip registration if private
if not isinstance(target, Sequence):
targets = (target,)
else:
targets = target # type: ignore[assignment]

for t in targets:
_registry.append(
_registration.OnnxDecompMeta(
onnx_function=processed_func,
fx_target=t,
signature=None,
is_complex=complex,
skip_signature_inference=no_compile,
)
)
return processed_func # type: ignore[return-value]

return wrapper


def get_torchlib_ops() -> tuple[_registration.OnnxDecompMeta, ...]:
    """Return a tuple snapshot of the torchlib decompositions registered so far.

    The ``ops`` package is imported first, purely for its import-time side
    effect of triggering op registration.
    """
    # Imported only to trigger op registration; the bound name is not needed.
    from torch.onnx._internal.exporter._torchlib import ops as _ops

    del _ops
    registered = tuple(_registry)
    # The registry is expected to be non-empty once the import above has run.
    assert len(registered) != 0
    return registered
Loading
Loading
0