-
Notifications
You must be signed in to change notification settings - Fork 24.8k
Custom ops support arbitrary input types by migrating to python dispatcher #147927
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
cb34942
cd1dc0c
858b0b0
db58378
8156410
2f51c4a
71fcd73
d84db1a
5fc5fa2
d883874
14c0439
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -346,16 +346,30 @@ | |
) | ||
return result | ||
|
||
need_python_dispatch = isinstance( | ||
self._opoverload, torch._ops.CustomOpOverload | ||
) | ||
|
||
if device_type is None: | ||
self._lib.impl( | ||
self._name, backend_impl, "CompositeExplicitAutograd" | ||
) | ||
if need_python_dispatch: | ||
self._opoverload.py_impl( | ||
_C.DispatchKey.CompositeExplicitAutograd | ||
)(backend_impl) | ||
else: | ||
self._lib.impl( | ||
self._name, backend_impl, "CompositeExplicitAutograd" | ||
) | ||
else: | ||
self._lib.impl( | ||
self._name, | ||
backend_impl, | ||
_C._dispatch_key_for_device(device_type), | ||
) | ||
if need_python_dispatch: | ||
self._opoverload.py_impl( | ||
_C._dispatch_key_for_device(device_type) | ||
Check failure on line 365 in torch/_library/custom_ops.py
|
||
)(backend_impl) | ||
else: | ||
self._lib.impl( | ||
self._name, | ||
backend_impl, | ||
_C._dispatch_key_for_device(device_type), | ||
) | ||
|
||
# Wrap function to choose between the default implementation or the device-specific | ||
# implementation depending on if the kernel is disabled. | ||
|
@@ -609,6 +623,8 @@ | |
) | ||
self._opoverload = utils.lookup_op(self._qualname) | ||
|
||
need_python_dispatch = isinstance(self._opoverload, torch._ops.CustomOpOverload) | ||
|
||
def fake_impl(*args, **kwargs): | ||
if self._abstract_fn is None: | ||
if utils.can_generate_trivial_fake_impl(self._opoverload): | ||
|
@@ -619,12 +635,22 @@ | |
f"Please use `{self._init_fn.__name__}.register_fake` to add an " | ||
f"fake impl." | ||
) | ||
if need_python_dispatch: | ||
args = args[1:] | ||
return self._abstract_fn(*args, **kwargs) | ||
|
||
lib._register_fake(self._name, fake_impl, _stacklevel=4) | ||
|
||
autograd_impl = autograd.make_autograd_impl(self._opoverload, self) | ||
lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True) | ||
|
||
if need_python_dispatch: | ||
self._opoverload.py_impl(torch._subclasses.fake_tensor.FakeTensorMode)( | ||
fake_impl | ||
) | ||
Comment on lines
+646
to
+649
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NB: this is kind of a hack |
||
self._opoverload.py_impl(_C.DispatchKey.Autograd, with_keyset=True)( | ||
autograd_impl | ||
) | ||
else: | ||
lib._register_fake(self._name, fake_impl, _stacklevel=4) | ||
lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True) | ||
|
||
schema = self._opoverload._schema | ||
if schema.is_mutable: | ||
|
@@ -640,12 +666,17 @@ | |
keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs | ||
) | ||
|
||
lib.impl( | ||
self._name, | ||
adinplaceorview_impl, | ||
"ADInplaceOrView", | ||
with_keyset=True, | ||
) | ||
if need_python_dispatch: | ||
self._opoverload.py_impl(_C.DispatchKey.ADInplaceOrView)( | ||
adinplaceorview_impl | ||
) | ||
else: | ||
lib.impl( | ||
self._name, | ||
adinplaceorview_impl, | ||
"ADInplaceOrView", | ||
with_keyset=True, | ||
) | ||
Comment on lines
+672
to
+681
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like we should align the API of py_impl and lib.impl. Otherwise it's kind of annoying There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doesn't need to happen in this PR, just a suggestion |
||
|
||
def _register_backend_select_dispatcher(self, device_arg_index: int): | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -124,7 +124,9 @@ | |
annotation_type, _ = unstringify_type(param.annotation) | ||
|
||
if annotation_type not in SUPPORTED_PARAM_TYPES: | ||
if annotation_type.__origin__ is tuple: | ||
if annotation_type in torch.utils._pytree.SUPPORTED_NODES: | ||
schema_type = "Any" | ||
yanboliang marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. when are we going to change this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. tbh I kind of like having "Any" in the type, we could keep it for now |
||
elif annotation_type.__origin__ is tuple: | ||
list_type = tuple_to_list(annotation_type) | ||
example_type_str = "\n\n" | ||
# Only suggest the list type if this type is supported. | ||
|
@@ -141,25 +143,25 @@ | |
f"Parameter {name} has unsupported type {param.annotation}. " | ||
f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}." | ||
) | ||
|
||
schema_type = SUPPORTED_PARAM_TYPES[annotation_type] | ||
else: | ||
schema_type = SUPPORTED_PARAM_TYPES[annotation_type] | ||
if type(mutates_args) == str: | ||
if mutates_args != UNKNOWN_MUTATES: | ||
raise ValueError( | ||
"mutates_args must either be a sequence of the names of " | ||
"the arguments that are mutated or the string 'unknown'. " | ||
) | ||
if schema_type.startswith("Tensor"): | ||
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" | ||
elif name in mutates_args: | ||
if not schema_type.startswith("Tensor"): | ||
error_fn( | ||
f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated" | ||
) | ||
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" | ||
seen_args.add(name) | ||
if param.default is inspect.Parameter.empty: | ||
params.append(f"{schema_type} {name}") | ||
else: | ||
default_repr = None | ||
if param.default is None or isinstance(param.default, (int, float, bool)): | ||
|
@@ -176,7 +178,7 @@ | |
f"Parameter {name} has an unsupported default value type {type(param.default)}. " | ||
f"Please file an issue on GitHub so we can prioritize this." | ||
) | ||
params.append(f"{schema_type} {name}={default_repr}") | ||
if mutates_args != UNKNOWN_MUTATES: | ||
mutates_args_not_seen = set(mutates_args) - seen_args | ||
if len(mutates_args_not_seen) > 0: | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -4,6 +4,7 @@ | |||||||||||||||||||||||||||||
import ctypes | ||||||||||||||||||||||||||||||
import importlib | ||||||||||||||||||||||||||||||
import inspect | ||||||||||||||||||||||||||||||
import itertools | ||||||||||||||||||||||||||||||
import sys | ||||||||||||||||||||||||||||||
import types | ||||||||||||||||||||||||||||||
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union | ||||||||||||||||||||||||||||||
|
@@ -96,6 +97,8 @@ | |||||||||||||||||||||||||||||
# HigherOrderOperator | ||||||||||||||||||||||||||||||
self.functorch_table = {} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
self.needs_keyset = set() | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def __call__(self, *args, **kwargs): | ||||||||||||||||||||||||||||||
raise NotImplementedError | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -116,8 +119,11 @@ | |||||||||||||||||||||||||||||
TransformType, | ||||||||||||||||||||||||||||||
DispatchKey, | ||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||
with_keyset: bool = False, | ||||||||||||||||||||||||||||||
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: | ||||||||||||||||||||||||||||||
def inner(fn: Callable[_P, _T]) -> Callable[_P, _T]: | ||||||||||||||||||||||||||||||
if with_keyset: | ||||||||||||||||||||||||||||||
self.needs_keyset.add(k) | ||||||||||||||||||||||||||||||
if inspect.isclass(k) and ( | ||||||||||||||||||||||||||||||
issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor) | ||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||
|
@@ -293,7 +299,7 @@ | |||||||||||||||||||||||||||||
# it to next key. This is only safe to do when PreDispatch key stack has no | ||||||||||||||||||||||||||||||
# active modes. | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def py_impl( | ||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||
k: Union[ | ||||||||||||||||||||||||||||||
type[TorchDispatchMode], | ||||||||||||||||||||||||||||||
|
@@ -449,10 +455,14 @@ | |||||||||||||||||||||||||||||
if dispatch_key != DispatchKey.PreDispatch: | ||||||||||||||||||||||||||||||
self._dispatch_cache[dispatch_key] = self.py_kernels[final_key] | ||||||||||||||||||||||||||||||
kernel = self.py_kernels[final_key] | ||||||||||||||||||||||||||||||
# It's illegal to register DispatchKey to py_kernels, since there's no | ||||||||||||||||||||||||||||||
# C++ kernel to call into | ||||||||||||||||||||||||||||||
assert not isinstance(kernel, DispatchKey) | ||||||||||||||||||||||||||||||
return kernel(*args, **kwargs) | ||||||||||||||||||||||||||||||
if final_key in self.needs_keyset: | ||||||||||||||||||||||||||||||
key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys) | ||||||||||||||||||||||||||||||
return kernel(key_set, *args, **kwargs) | ||||||||||||||||||||||||||||||
yanboliang marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||
# It's illegal to register DispatchKey to py_kernels, since there's no | ||||||||||||||||||||||||||||||
# C++ kernel to call into | ||||||||||||||||||||||||||||||
assert not isinstance(kernel, DispatchKey) | ||||||||||||||||||||||||||||||
return kernel(*args, **kwargs) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
@abc.abstractmethod | ||||||||||||||||||||||||||||||
def __call__(self, /, *args, **kwargs): | ||||||||||||||||||||||||||||||
|
@@ -925,6 +935,91 @@ | |||||||||||||||||||||||||||||
# TODO: add more methods to expose information about input and output arguments | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
class CustomOpOverload(OpOverload): | ||||||||||||||||||||||||||||||
yanboliang marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||
def __repr__(self): | ||||||||||||||||||||||||||||||
return "<CustomOpOverload(op='{}.{}', overload='{}')>".format( | ||||||||||||||||||||||||||||||
*self._schema.name.split("::"), self._overloadname | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _fallthrough_keys(self) -> list[DispatchKey]: | ||||||||||||||||||||||||||||||
# TODO: we should be calling the fallback for these, but a fallthrough is almost close | ||||||||||||||||||||||||||||||
# enough to the fallback in most cases that we care about. | ||||||||||||||||||||||||||||||
_DEFAULT_FALLTHROUGH_KEYS = [ | ||||||||||||||||||||||||||||||
DispatchKey.Autograd, | ||||||||||||||||||||||||||||||
DispatchKey.AutogradCPU, | ||||||||||||||||||||||||||||||
DispatchKey.AutogradCUDA, | ||||||||||||||||||||||||||||||
DispatchKey.ADInplaceOrView, | ||||||||||||||||||||||||||||||
DispatchKey.BackendSelect, | ||||||||||||||||||||||||||||||
DispatchKey.PythonTLSSnapshot, | ||||||||||||||||||||||||||||||
DispatchKey.PythonDispatcher, | ||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||
yanboliang marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _may_use_fallthrough_instead_of_fallback(key: DispatchKey): | ||||||||||||||||||||||||||||||
if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key): | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If there's a kernel in py_impl as well, then we should not use the fallthrough either |
||||||||||||||||||||||||||||||
return torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough( | ||||||||||||||||||||||||||||||
self.name(), key | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
return ( | ||||||||||||||||||||||||||||||
key not in self.py_kernels | ||||||||||||||||||||||||||||||
or self.py_kernels[key] is torch.library.fallthrough_kernel | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
return [ | ||||||||||||||||||||||||||||||
key | ||||||||||||||||||||||||||||||
for key in _DEFAULT_FALLTHROUGH_KEYS | ||||||||||||||||||||||||||||||
if _may_use_fallthrough_instead_of_fallback(key) | ||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def __call__(self, /, *args, **kwargs): | ||||||||||||||||||||||||||||||
return self._dispatch_in_python(args, kwargs, self._fallthrough_keys()) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _dispatch_in_python(self, args, kwargs, fallthrough_keys): | ||||||||||||||||||||||||||||||
non_fallthrough_keys = torch._C._dispatch_keyset_full() | ||||||||||||||||||||||||||||||
for key in fallthrough_keys: | ||||||||||||||||||||||||||||||
non_fallthrough_keys = non_fallthrough_keys.remove(key) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
dispatch_key_set = _compute_keyset(args, kwargs, non_fallthrough_keys) | ||||||||||||||||||||||||||||||
dispatch_key = dispatch_key_set.highestPriorityTypeId() | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
handler = ( | ||||||||||||||||||||||||||||||
self._get_dispatch(dispatch_key) | ||||||||||||||||||||||||||||||
if dispatch_key not in self._dispatch_cache | ||||||||||||||||||||||||||||||
else self._dispatch_cache[dispatch_key] | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
if isinstance(handler, DispatchKey): | ||||||||||||||||||||||||||||||
# fallthrough keys can be registered at runtime via torch.library.impl | ||||||||||||||||||||||||||||||
# so need to add it to fallthrough_keys and re-dispatch. | ||||||||||||||||||||||||||||||
if torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough( | ||||||||||||||||||||||||||||||
self.name(), dispatch_key | ||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||
zou3519 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||
return self._dispatch_in_python( | ||||||||||||||||||||||||||||||
args, kwargs, fallthrough_keys + [dispatch_key] | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
raise RuntimeError( | ||||||||||||||||||||||||||||||
f"Custom op {self} received a Pytree input when dispatching {handler}." | ||||||||||||||||||||||||||||||
f" but no python implementation is found." | ||||||||||||||||||||||||||||||
f" Please file an issue on this when you encounter this error." | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
assert isinstance(handler, Callable) # type: ignore[arg-type] | ||||||||||||||||||||||||||||||
return handler(*args, **kwargs) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _has_pytree_object_arg(schema: torch.FunctionSchema) -> bool: | ||||||||||||||||||||||||||||||
return any(isinstance(arg.type, torch.AnyType) for arg in schema.arguments) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _has_pytree_type_in_args_or_kwargs(args, kwargs) -> bool: | ||||||||||||||||||||||||||||||
return any( | ||||||||||||||||||||||||||||||
not isinstance(x, (list, tuple)) | ||||||||||||||||||||||||||||||
and type(x) in torch.utils._pytree.SUPPORTED_NODES | ||||||||||||||||||||||||||||||
for x in itertools.chain(args, kwargs.values()) | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
Comment on lines
+1028
to
+1033
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
You can use |
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
# TorchBindOpOverload are those custom ops which have at least one overload's | ||||||||||||||||||||||||||||||
# schema consists of torch.ScriptObject (i.e. custom class) input. | ||||||||||||||||||||||||||||||
# TorchBindOpOverload will skip C++ dispatcher and purely dispatched in python | ||||||||||||||||||||||||||||||
|
@@ -1059,6 +1154,9 @@ | |||||||||||||||||||||||||||||
self._has_torchbind_op_overload = any( | ||||||||||||||||||||||||||||||
_has_script_object_arg(schema) for schema in self._schemas.values() | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
self._has_pytree_arg_overload = any( | ||||||||||||||||||||||||||||||
_has_pytree_object_arg(schema) for schema in self._schemas.values() | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
# it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op. | ||||||||||||||||||||||||||||||
def __deepcopy__(self, memo=None): | ||||||||||||||||||||||||||||||
|
@@ -1125,11 +1223,12 @@ | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
op_, op_dk_, tags = op_dk_tags | ||||||||||||||||||||||||||||||
schema = torch._C._get_schema(self._qualified_op_name, use_key) | ||||||||||||||||||||||||||||||
overload = ( | ||||||||||||||||||||||||||||||
OpOverload(self, op_, op_dk_, schema, tags) | ||||||||||||||||||||||||||||||
if not _has_script_object_arg(schema) | ||||||||||||||||||||||||||||||
else TorchBindOpOverload(self, op_, op_dk_, schema, tags) | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
if _has_pytree_object_arg(schema): | ||||||||||||||||||||||||||||||
overload = CustomOpOverload(self, op_, op_dk_, schema, tags) | ||||||||||||||||||||||||||||||
elif _has_script_object_arg(schema): | ||||||||||||||||||||||||||||||
overload = TorchBindOpOverload(self, op_, op_dk_, schema, tags) | ||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||
overload = OpOverload(self, op_, op_dk_, schema, tags) | ||||||||||||||||||||||||||||||
# cache the overload object | ||||||||||||||||||||||||||||||
setattr(self, key, overload) | ||||||||||||||||||||||||||||||
self._dir.append(key) | ||||||||||||||||||||||||||||||
|
@@ -1153,7 +1252,11 @@ | |||||||||||||||||||||||||||||
# Directly calling OverloadPacket goes into C++, which will check | ||||||||||||||||||||||||||||||
# the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we | ||||||||||||||||||||||||||||||
# intercept it here and call TorchBindOpverload instead. | ||||||||||||||||||||||||||||||
if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs): | ||||||||||||||||||||||||||||||
if ( | ||||||||||||||||||||||||||||||
self._has_torchbind_op_overload | ||||||||||||||||||||||||||||||
and _must_dispatch_in_python(args, kwargs) | ||||||||||||||||||||||||||||||
or self._has_pytree_arg_overload | ||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||
return _call_overload_packet_from_python(self, args, kwargs) | ||||||||||||||||||||||||||||||
return self._op(*args, **(kwargs or {})) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -1173,6 +1276,10 @@ | |||||||||||||||||||||||||||||
if torch_function_called: | ||||||||||||||||||||||||||||||
return ret | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
if _has_pytree_type_in_args_or_kwargs(args, kwargs): | ||||||||||||||||||||||||||||||
op_overload = getattr(op, op.overloads()[0]) | ||||||||||||||||||||||||||||||
return op_overload(*args, **kwargs) | ||||||||||||||||||||||||||||||
Comment on lines
+1292
to
+1294
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
# The following mirrors getOpWithStack. | ||||||||||||||||||||||||||||||
# In cpp, we do a schema matching for the arguments, and call ToIValue to | ||||||||||||||||||||||||||||||
# to check whether the arguments are valid. But need to do similar things here | ||||||||||||||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.