Custom ops support arbitrary input types by migrating to python dispatcher #147927
Pull Request #147927 · pytorch/pytorch · by yanboliang

Open · yanboliang wants to merge 11 commits into base: main

Changes from 2 commits

49 changes: 49 additions & 0 deletions torch/_dynamo/variables/torch.py
@@ -40,6 +40,7 @@
import torch._refs
import torch.fx
import torch.nn
import torch.utils._pytree as pytree
from torch._guards import TracingContext
from torch._logging import warning_once
from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type
@@ -1112,6 +1113,54 @@
if result:
return result

# If the function is a custom op, we need to wrap it as a flat_apply call in the FX graph.
if isinstance(
self.value, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)
) and any(not isinstance(x, (variables.ConstantVariable)) for x in args):
from torch._higher_order_ops.flat_apply import _ConstantFunction, flat_apply

Check failure on line 1120 in torch/_dynamo/variables/torch.py (lintrunner / MYPY [no-redef]): Name "flat_apply" already defined (by an import)

# fn = self.value.op
fn = self.value.py_kernels[torch._C.DispatchKey.CompositeExplicitAutograd]
packed_input_vt = TupleVariable.build(
tx,
(
variables.TupleVariable.build(tx, args),
variables.ConstDictVariable.build(tx, kwargs),
),
)
flat_args_and_spec = variables.UserFunctionVariable(
pytree.tree_flatten
).call_function(tx, [packed_input_vt], {})

flat_args_vt, in_spec_vt = flat_args_and_spec.items

Check failure on line 1135 in torch/_dynamo/variables/torch.py (lintrunner / MYPY [attr-defined]): "VariableTracker" has no attribute "items"
_, func_spec = pytree.tree_flatten(_ConstantFunction(fn))
in_spec = in_spec_vt.as_python_constant()
func_spec_proxy = tx.output.register_static_attr_and_return_proxy(
f"{fn.__name__}_spec", func_spec
)
in_spec_proxy = tx.output.register_static_attr_and_return_proxy(
fn.__name__ + "_input_spec", in_spec
)

proxified_flat_args = [
flat_arg_vt.as_proxy() for flat_arg_vt in flat_args_vt.items
]

func_spec_proxy.node.type = type(func_spec)
in_spec_proxy.node.type = type(in_spec)
all_args = (func_spec_proxy, in_spec_proxy, *proxified_flat_args)

out_vt = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
flat_apply,
all_args,
{},
),
)
return out_vt

any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)

all_ints_or_floats = all(
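For orientation, a hedged end-to-end sketch of what this hunk enables, assuming the PR's final state. The Point class, the mylib::norm op, and the rendered node names are illustrative, not taken from the diff; the attribute names follow the f"{fn.__name__}_spec" / f"{fn.__name__}_input_spec" pattern registered above.

import torch
import torch.utils._pytree as pytree
from dataclasses import dataclass

@dataclass
class Point:  # hypothetical pytree-registered input type, not from the PR
    x: torch.Tensor
    y: torch.Tensor

pytree.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),          # flatten -> (children, context)
    lambda children, _: Point(*children),  # unflatten(children, context)
)

@torch.library.custom_op("mylib::norm", mutates_args=())
def norm(p: Point) -> torch.Tensor:
    return (p.x * p.x + p.y * p.y).sqrt()

@norm.register_fake
def _(p):
    return torch.empty_like(p.x)

@torch.compile(fullgraph=True)
def f(p):
    return torch.ops.mylib.norm(p)

f(Point(torch.randn(3), torch.randn(3)))
# Dynamo flattens (args, kwargs) with pytree.tree_flatten, registers the
# function spec and input spec as static graph attributes, and the captured
# FX graph should contain a single call_function node shaped roughly like
#   flat_apply(norm_spec, norm_input_spec, p_x, p_y)
# rather than a direct call to mylib::norm.
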
65 changes: 48 additions & 17 deletions torch/_library/custom_ops.py
@@ -346,16 +346,30 @@
)
return result

need_python_dispatch = isinstance(
self._opoverload, torch._ops.CustomOpOverload
)

if device_type is None:
self._lib.impl(
self._name, backend_impl, "CompositeExplicitAutograd"
)
if need_python_dispatch:
self._opoverload.py_impl(
_C.DispatchKey.CompositeExplicitAutograd
)(backend_impl)
else:
self._lib.impl(
self._name, backend_impl, "CompositeExplicitAutograd"
)
else:
self._lib.impl(
self._name,
backend_impl,
_C._dispatch_key_for_device(device_type),
)
if need_python_dispatch:
self._opoverload.py_impl(
_C._dispatch_key_for_device(device_type)

Check failure on line 365 in torch/_library/custom_ops.py (lintrunner / MYPY [arg-type]): Argument 1 to "py_impl" of "OperatorBase" has incompatible type "str"; expected "type[TorchDispatchMode] | type[Tensor] | TransformType | DispatchKey"
)(backend_impl)
else:
self._lib.impl(
self._name,
backend_impl,
_C._dispatch_key_for_device(device_type),
)

# Wrap function to choose between the default implementation or the device-specific
# implementation depending on whether the kernel is disabled.
@@ -609,6 +623,8 @@
)
self._opoverload = utils.lookup_op(self._qualname)

need_python_dispatch = isinstance(self._opoverload, torch._ops.CustomOpOverload)

def fake_impl(*args, **kwargs):
if self._abstract_fn is None:
if utils.can_generate_trivial_fake_impl(self._opoverload):
@@ -619,12 +635,22 @@
f"Please use `{self._init_fn.__name__}.register_fake` to add an "
f"fake impl."
)
if need_python_dispatch:
args = args[1:]
return self._abstract_fn(*args, **kwargs)

lib._register_fake(self._name, fake_impl, _stacklevel=4)

autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)

if need_python_dispatch:
self._opoverload.py_impl(torch._subclasses.fake_tensor.FakeTensorMode)(
fake_impl
)
Comment on lines +646 to +649 (Contributor): NB: this is kind of a hack

self._opoverload.py_impl(_C.DispatchKey.Autograd, with_keyset=True)(
autograd_impl
)
else:
lib._register_fake(self._name, fake_impl, _stacklevel=4)
lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)

schema = self._opoverload._schema
if schema.is_mutable:
@@ -640,12 +666,17 @@
keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
)

lib.impl(
self._name,
adinplaceorview_impl,
"ADInplaceOrView",
with_keyset=True,
)
if need_python_dispatch:
self._opoverload.py_impl(_C.DispatchKey.ADInplaceOrView)(
adinplaceorview_impl
)
else:
lib.impl(
self._name,
adinplaceorview_impl,
"ADInplaceOrView",
with_keyset=True,
)
Comment on lines +672 to +681 (Contributor): I feel like we should align the API of py_impl and lib.impl. Otherwise it's kind of annoying.
Reply (Contributor): Doesn't need to happen in this PR, just a suggestion.


def _register_backend_select_dispatcher(self, device_arg_index: int):
"""
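Note that py_impl writes into the same registry the Dynamo change reads: torch/_dynamo/variables/torch.py fetches self.value.py_kernels[torch._C.DispatchKey.CompositeExplicitAutograd] to recover the backend kernel. A hedged sketch of that coupling, reusing the illustrative mylib::norm op from earlier:

# Hedged sketch, not from the PR's tests: after the backend registration
# above runs for a pytree-typed op, the kernel is retrievable from the
# Python dispatcher registry that Dynamo consults for the flat_apply call.
overload = torch.ops.mylib.norm.default  # assumed to be a CustomOpOverload
kernel = overload.py_kernels[torch._C.DispatchKey.CompositeExplicitAutograd]
out = kernel(Point(torch.randn(3), torch.randn(3)))  # calls backend_impl directly
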
8 changes: 5 additions & 3 deletions torch/_library/infer_schema.py
@@ -124,7 +124,9 @@
annotation_type, _ = unstringify_type(param.annotation)

if annotation_type not in SUPPORTED_PARAM_TYPES:
if annotation_type.__origin__ is tuple:
if annotation_type in torch.utils._pytree.SUPPORTED_NODES:
schema_type = "Any"
Comment (Contributor): when are we going to change this?
Reply (Contributor): tbh I kind of like having "Any" in the type, we could keep it for now

elif annotation_type.__origin__ is tuple:
list_type = tuple_to_list(annotation_type)
example_type_str = "\n\n"
# Only suggest the list type if this type is supported.
@@ -141,25 +143,25 @@
f"Parameter {name} has unsupported type {param.annotation}. "
f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
)

schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
else:
schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
if type(mutates_args) == str:
if mutates_args != UNKNOWN_MUTATES:
raise ValueError(
"mutates_args must either be a sequence of the names of "
"the arguments that are mutated or the string 'unknown'. "
)
if schema_type.startswith("Tensor"):

Check failure on line 154 in torch/_library/infer_schema.py (lintrunner / MYPY [possibly-undefined]): Name "schema_type" may be undefined
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
elif name in mutates_args:
if not schema_type.startswith("Tensor"):

Check failure on line 157 in torch/_library/infer_schema.py (lintrunner / MYPY [possibly-undefined]): Name "schema_type" may be undefined
error_fn(
f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated"
)
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
seen_args.add(name)
if param.default is inspect.Parameter.empty:
params.append(f"{schema_type} {name}")

Check failure on line 164 in torch/_library/infer_schema.py (lintrunner / MYPY [possibly-undefined]): Name "schema_type" may be undefined
else:
default_repr = None
if param.default is None or isinstance(param.default, (int, float, bool)):
@@ -176,7 +178,7 @@
f"Parameter {name} has an unsupported default value type {type(param.default)}. "
f"Please file an issue on GitHub so we can prioritize this."
)
params.append(f"{schema_type} {name}={default_repr}")

Check failure on line 181 in torch/_library/infer_schema.py (lintrunner / MYPY [possibly-undefined]): Name "schema_type" may be undefined
if mutates_args != UNKNOWN_MUTATES:
mutates_args_not_seen = set(mutates_args) - seen_args
if len(mutates_args_not_seen) > 0:
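Concretely, the new branch means a parameter annotated with any pytree-registered class is given the schema type "Any" instead of being rejected. A hedged illustration, reusing the illustrative Point class; the exact rendered string is an assumption, not taken from the PR's tests:

from torch._library.infer_schema import infer_schema

def norm(p: Point) -> torch.Tensor: ...

# With Point registered as a pytree node, the annotation now maps to "Any"
# in the inferred schema, e.g. roughly: "(Any p) -> Tensor"
print(infer_schema(norm, mutates_args=()))
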
127 changes: 117 additions & 10 deletions torch/_ops.py
@@ -4,6 +4,7 @@
import ctypes
import importlib
import inspect
import itertools
import sys
import types
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
@@ -96,6 +97,8 @@
# HigherOrderOperator
self.functorch_table = {}

self.needs_keyset = set()

def __call__(self, *args, **kwargs):
raise NotImplementedError

@@ -116,8 +119,11 @@
TransformType,
DispatchKey,
],
with_keyset: bool = False,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
def inner(fn: Callable[_P, _T]) -> Callable[_P, _T]:
if with_keyset:
self.needs_keyset.add(k)
if inspect.isclass(k) and (
issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
):
@@ -293,7 +299,7 @@
# it to next key. This is only safe to do when PreDispatch key stack has no
# active modes.

def py_impl(

Check failure on line 302 in torch/_ops.py (lintrunner / MYPY [override]): Signature of "py_impl" incompatible with supertype "OperatorBase"
self,
k: Union[
type[TorchDispatchMode],
@@ -449,10 +455,14 @@
if dispatch_key != DispatchKey.PreDispatch:
self._dispatch_cache[dispatch_key] = self.py_kernels[final_key]
kernel = self.py_kernels[final_key]
# It's illegal to register DispatchKey to py_kernels, since there's no
# C++ kernel to call into
assert not isinstance(kernel, DispatchKey)
return kernel(*args, **kwargs)
if final_key in self.needs_keyset:
key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
return kernel(key_set, *args, **kwargs)
else:
# It's illegal to register DispatchKey to py_kernels, since there's no
# C++ kernel to call into
assert not isinstance(kernel, DispatchKey)
return kernel(*args, **kwargs)

@abc.abstractmethod
def __call__(self, /, *args, **kwargs):
@@ -925,6 +935,91 @@
# TODO: add more methods to expose information about input and output arguments


class CustomOpOverload(OpOverload):
def __repr__(self):
return "<CustomOpOverload(op='{}.{}', overload='{}')>".format(
*self._schema.name.split("::"), self._overloadname
)

def _fallthrough_keys(self) -> list[DispatchKey]:
# TODO: we should be calling the fallback for these, but a fallthrough is almost close
# enough to the fallback in most cases that we care about.
_DEFAULT_FALLTHROUGH_KEYS = [
DispatchKey.Autograd,
DispatchKey.AutogradCPU,
DispatchKey.AutogradCUDA,
DispatchKey.ADInplaceOrView,
DispatchKey.BackendSelect,
DispatchKey.PythonTLSSnapshot,
DispatchKey.PythonDispatcher,
]

def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key):
Comment (Contributor): If there's a kernel in py_impl as well, then we should not use the fallthrough either

return torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
self.name(), key
)

return (
key not in self.py_kernels
or self.py_kernels[key] is torch.library.fallthrough_kernel
)

return [
key
for key in _DEFAULT_FALLTHROUGH_KEYS
if _may_use_fallthrough_instead_of_fallback(key)
]

def __call__(self, /, *args, **kwargs):
return self._dispatch_in_python(args, kwargs, self._fallthrough_keys())

def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
non_fallthrough_keys = torch._C._dispatch_keyset_full()
for key in fallthrough_keys:
non_fallthrough_keys = non_fallthrough_keys.remove(key)

dispatch_key_set = _compute_keyset(args, kwargs, non_fallthrough_keys)
dispatch_key = dispatch_key_set.highestPriorityTypeId()

handler = (
self._get_dispatch(dispatch_key)
if dispatch_key not in self._dispatch_cache
else self._dispatch_cache[dispatch_key]
)

if isinstance(handler, DispatchKey):
# fallthrough keys can be registered at runtime via torch.library.impl
# so need to add it to fallthrough_keys and re-dispatch.
if torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
self.name(), dispatch_key
):
return self._dispatch_in_python(
args, kwargs, fallthrough_keys + [dispatch_key]
)

raise RuntimeError(
f"Custom op {self} received a Pytree input when dispatching {handler}."
f" but no python implementation is found."
f" Please file an issue on this when you encounter this error."
)

assert isinstance(handler, Callable) # type: ignore[arg-type]
return handler(*args, **kwargs)


def _has_pytree_object_arg(schema: torch.FunctionSchema) -> bool:
return any(isinstance(arg.type, torch.AnyType) for arg in schema.arguments)


def _has_pytree_type_in_args_or_kwargs(args, kwargs) -> bool:
return any(
not isinstance(x, (list, tuple))
and type(x) in torch.utils._pytree.SUPPORTED_NODES
for x in itertools.chain(args, kwargs.values())
)
Comment on lines +1028 to +1033 (@XuehaiPan, Collaborator, Mar 6, 2025):

Suggested change:

-def _has_pytree_type_in_args_or_kwargs(args, kwargs) -> bool:
-    return any(
-        not isinstance(x, (list, tuple))
-        and type(x) in torch.utils._pytree.SUPPORTED_NODES
-        for x in itertools.chain(args, kwargs.values())
-    )
+def _has_pytree_type_in_args_or_kwargs(args, kwargs) -> bool:
+    def is_list_or_tuple(x):
+        return isinstance(x, (list, tuple))
+
+    return any(
+        not torch.utils._pytree.tree_is_leaf(x, is_leaf=is_list_or_tuple)
+        for x in itertools.chain(args, kwargs.values())
+    )

You can use tree_is_leaf after #113257 gets merged into main.

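To make the check concrete, a hedged set of examples of what _has_pytree_type_in_args_or_kwargs classifies (after this PR it lives in torch._ops; t and Point are illustrative, with Point pytree-registered as earlier):

import torch
from torch._ops import _has_pytree_type_in_args_or_kwargs

t = torch.randn(2)
# Lists and tuples are deliberately excluded even though they are pytree
# nodes, since the existing dispatcher already handles them.
_has_pytree_type_in_args_or_kwargs((t,), {})                  # False: plain tensor
_has_pytree_type_in_args_or_kwargs(([t, t],), {})             # False: list is exempt
_has_pytree_type_in_args_or_kwargs((Point(t, t),), {})        # True: registered node
_has_pytree_type_in_args_or_kwargs((t,), {"p": Point(t, t)})  # True: kwargs checked too
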


# TorchBindOpOverload are those custom ops which have at least one overload's
# schema consists of torch.ScriptObject (i.e. custom class) input.
# TorchBindOpOverload will skip C++ dispatcher and purely dispatched in python
@@ -1059,6 +1154,9 @@
self._has_torchbind_op_overload = any(
_has_script_object_arg(schema) for schema in self._schemas.values()
)
self._has_pytree_arg_overload = any(
_has_pytree_object_arg(schema) for schema in self._schemas.values()
)

# it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
def __deepcopy__(self, memo=None):
@@ -1125,11 +1223,12 @@

op_, op_dk_, tags = op_dk_tags
schema = torch._C._get_schema(self._qualified_op_name, use_key)
overload = (
OpOverload(self, op_, op_dk_, schema, tags)
if not _has_script_object_arg(schema)
else TorchBindOpOverload(self, op_, op_dk_, schema, tags)
)
if _has_pytree_object_arg(schema):
overload = CustomOpOverload(self, op_, op_dk_, schema, tags)
elif _has_script_object_arg(schema):
overload = TorchBindOpOverload(self, op_, op_dk_, schema, tags)

Check failure on line 1229 in torch/_ops.py (lintrunner / MYPY [assignment]): Incompatible types in assignment (expression has type "TorchBindOpOverload", variable has type "CustomOpOverload")
else:
overload = OpOverload(self, op_, op_dk_, schema, tags)

Check failure on line 1231 in torch/_ops.py (lintrunner / MYPY [assignment]): Incompatible types in assignment (expression has type "OpOverload", variable has type "CustomOpOverload")
# cache the overload object
setattr(self, key, overload)
self._dir.append(key)
@@ -1153,7 +1252,11 @@
# Directly calling OverloadPacket goes into C++, which will check
# the schema and cause an error for torchbind ops when inputs contain a FakeScriptObject, so we
# intercept it here and call TorchBindOpOverload instead.
if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
if (
self._has_torchbind_op_overload
and _must_dispatch_in_python(args, kwargs)
or self._has_pytree_arg_overload
):
return _call_overload_packet_from_python(self, args, kwargs)
return self._op(*args, **(kwargs or {}))

Expand All @@ -1173,6 +1276,10 @@
if torch_function_called:
return ret

if _has_pytree_type_in_args_or_kwargs(args, kwargs):
op_overload = getattr(op, op.overloads()[0])
return op_overload(*args, **kwargs)
Comment on lines +1292 to +1294 (Contributor):
1. this is probably slow and needs caching on the OpOverload
2. this is wrong if someone tries to register multiple OpOverloads for one OpOverloadPacket, so we should error out in that situation


# The following mirrors getOpWithStack.
# In C++, we do schema matching on the arguments and call ToIValue to
# check whether the arguments are valid. We need to do similar things here.
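Putting the _ops.py pieces together: CustomOpOverload.__call__ computes the highest-priority dispatch key from the inputs (minus the default fallthrough keys), looks the handler up in py_kernels, and calls it, prepending the computed keyset when the kernel was registered with with_keyset=True. A hedged sketch of direct use, with illustrative names; registering py_impl by hand like this is an assumption about the API, not something the PR demonstrates:

# Hedged sketch, not from the PR: manual py_impl registration on a
# CustomOpOverload and the dispatch path a call then takes.
op = torch.ops.mylib.norm.default  # assumed CustomOpOverload (pytree "Any" arg)

@op.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
def norm_backend(p):
    return (p.x * p.x + p.y * p.y).sqrt()

out = op(Point(torch.randn(3), torch.randn(3)))
# __call__ -> _dispatch_in_python: Autograd, ADInplaceOrView, BackendSelect,
# etc. fall through by default, so dispatch lands on the
# CompositeExplicitAutograd py_kernel registered above.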