Custom ops support arbitrary input types by migrating to python dispatcher #147927
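A minimal sketch of the kind of usage this change is meant to enable (hypothetical: `mylib::scale_all` is not a real op, and on current PyTorch a bare `dict` annotation is rejected by schema inference, while with this PR it should infer as the `Any` schema type):

```python
import torch

# Hypothetical example of the kind of op this PR targets: `dict` is a
# registered pytree node, so (with this change) schema inference maps it to
# the "Any" schema type instead of raising an unsupported-type error.
@torch.library.custom_op("mylib::scale_all", mutates_args=())
def scale_all(tensors: dict, scale: float) -> torch.Tensor:
    # Assumes every value in the dict is a Tensor.
    return torch.cat([t * scale for t in tensors.values()])
```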
First file:

@@ -346,16 +346,30 @@ def get_module():
             )
             return result
 
+        need_python_dispatch = hasattr(self, "_opoverload") and isinstance(
+            self._opoverload, torch._ops.PythonDispatcherOpOverload
+        )
+
         if device_type is None:
-            self._lib.impl(
-                self._name, backend_impl, "CompositeExplicitAutograd"
-            )
+            if need_python_dispatch:
+                self._opoverload.py_impl(
+                    _C.DispatchKey.CompositeExplicitAutograd
+                )(backend_impl)
+            else:
+                self._lib.impl(
+                    self._name, backend_impl, "CompositeExplicitAutograd"
+                )
         else:
-            self._lib.impl(
-                self._name,
-                backend_impl,
-                _C._dispatch_key_for_device(device_type),
-            )
+            if need_python_dispatch:
+                dispatch_key = _C._dispatch_key_for_device(device_type)
+                dispatch_key = getattr(_C.DispatchKey, dispatch_key)
+                self._opoverload.py_impl(dispatch_key)(backend_impl)
+            else:
+                self._lib.impl(
+                    self._name,
+                    backend_impl,
+                    _C._dispatch_key_for_device(device_type),
+                )
 
         # Wrap function to choose between the default implementation or the device-specific
         # implementation depending on if the kernel is disabled.
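For reference, a runnable illustration of the key translation the new device branch performs; `py_impl` takes a `DispatchKey` enum value, while `lib.impl` accepts the key name as a plain string:

```python
from torch import _C

# _dispatch_key_for_device maps a device type to the dispatch key *name*;
# getattr then resolves that name to the enum value that py_impl expects.
key_name = _C._dispatch_key_for_device("cuda")    # "CUDA"
dispatch_key = getattr(_C.DispatchKey, key_name)  # DispatchKey.CUDA
print(key_name, dispatch_key)
```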
@@ -609,6 +623,10 @@ def _register_to_dispatcher(self) -> None:
         )
         self._opoverload = utils.lookup_op(self._qualname)
 
+        need_python_dispatch = isinstance(
+            self._opoverload, torch._ops.PythonDispatcherOpOverload
+        )
+
         def fake_impl(*args, **kwargs):
             if self._abstract_fn is None:
                 if utils.can_generate_trivial_fake_impl(self._opoverload):

Review comment (on lines +626 to +628): nit: should probably make this a helper function
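One possible shape for that helper (a hypothetical sketch; `_needs_python_dispatch` is a made-up name, and `PythonDispatcherOpOverload` is the class this PR introduces):

```python
import torch

def _needs_python_dispatch(op) -> bool:
    # True when the op lives on the python dispatcher (class introduced by
    # this PR) and must be registered via py_impl rather than Library.impl.
    return isinstance(op, torch._ops.PythonDispatcherOpOverload)
```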
@@ -619,12 +637,22 @@ def fake_impl(*args, **kwargs):
                     f"Please use `{self._init_fn.__name__}.register_fake` to add an "
                     f"fake impl."
                 )
+            if need_python_dispatch:
+                args = args[1:]  # Remove the dispatch key under fake tensor mode.
             return self._abstract_fn(*args, **kwargs)
 
-        lib._register_fake(self._name, fake_impl, _stacklevel=4)
-
         autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
-        lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)
 
+        if need_python_dispatch:
+            self._opoverload.py_impl(torch._subclasses.fake_tensor.FakeTensorMode)(
+                fake_impl
+            )
+            self._opoverload.py_impl(_C.DispatchKey.Autograd, with_keyset=True)(
+                autograd_impl
+            )
+        else:
+            lib._register_fake(self._name, fake_impl, _stacklevel=4)
+            lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)
 
         schema = self._opoverload._schema
         if schema.is_mutable:

Review comment (on lines +646 to +649): NB: this is kind of a hack
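For contrast, the `else` branch keeps the standard `torch.library` flow, which runs on stock PyTorch; a small runnable example (`mylib::mul2` is a made-up op name):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

@torch.library.custom_op("mylib::mul2", mutates_args=())
def mul2(x: torch.Tensor) -> torch.Tensor:
    return x * 2

@mul2.register_fake
def _(x):
    # Under FakeTensorMode this runs instead of the real kernel, producing
    # an output with the right metadata and no actual compute.
    return torch.empty_like(x)

with FakeTensorMode():
    t = torch.empty(2, 3)
    print(mul2(t).shape)  # torch.Size([2, 3])
```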
@@ -640,12 +668,17 @@ def adinplaceorview_impl(keyset, *args, **kwargs):
                     keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
                 )
 
-            lib.impl(
-                self._name,
-                adinplaceorview_impl,
-                "ADInplaceOrView",
-                with_keyset=True,
-            )
+            if need_python_dispatch:
+                self._opoverload.py_impl(_C.DispatchKey.ADInplaceOrView)(
+                    adinplaceorview_impl
+                )
+            else:
+                lib.impl(
+                    self._name,
+                    adinplaceorview_impl,
+                    "ADInplaceOrView",
+                    with_keyset=True,
+                )
 
     def _register_backend_select_dispatcher(self, device_arg_index: int):
         """

Review comment (on lines +672 to +681): I feel like we should align the API of py_impl and lib.impl. Otherwise it's kind of annoying.
Reply: Doesn't need to happen in this PR, just a suggestion.
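A hypothetical sketch of the alignment being suggested: one helper with `lib.impl`-style arguments that routes to `py_impl` when needed (all names here are made up; the `py_impl(..., with_keyset=...)` form follows this PR's usage for the Autograd key):

```python
from torch import _C

# Made-up helper: a lib.impl-style signature that routes to py_impl when the
# op lives on the python dispatcher, so call sites don't branch themselves.
def _register_kernel(lib, opoverload, name, fn, key: str, *,
                     with_keyset: bool = False, use_python_dispatch: bool = False):
    if use_python_dispatch:
        dispatch_key = getattr(_C.DispatchKey, key)
        if with_keyset:
            opoverload.py_impl(dispatch_key, with_keyset=True)(fn)
        else:
            opoverload.py_impl(dispatch_key)(fn)
    else:
        lib.impl(name, fn, key, with_keyset=with_keyset)
```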
Second file:

@@ -124,7 +124,10 @@ def unstringify_type(ty: Union[type[object], str]) -> tuple[typing.Any, bool]:
         annotation_type, _ = unstringify_type(param.annotation)
 
         if annotation_type not in SUPPORTED_PARAM_TYPES:
-            if annotation_type.__origin__ is tuple:
+            if annotation_type in torch.utils._pytree.SUPPORTED_NODES:
+                # TODO: Move to a separate schema type for pytrees.
+                schema_type = "Any"
+            elif annotation_type.__origin__ is tuple:
                 list_type = tuple_to_list(annotation_type)
                 example_type_str = "\n\n"
                 # Only suggest the list type if this type is supported.

Review comment (conversation marked as resolved by yanboliang): when are we going to change this?
Reply: tbh I kind of like having "Any" in the type, we could keep it for now.
@@ -141,25 +144,25 @@ def unstringify_type(ty: Union[type[object], str]) -> tuple[typing.Any, bool]:
                     f"Parameter {name} has unsupported type {param.annotation}. "
                     f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
                 )
-
-        schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
+        else:
+            schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
         if type(mutates_args) == str:
             if mutates_args != UNKNOWN_MUTATES:
                 raise ValueError(
                     "mutates_args must either be a sequence of the names of "
                     "the arguments that are mutated or the string 'unknown'. "
                 )
-            if schema_type.startswith("Tensor"):
-                schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
+            if schema_type.startswith("Tensor"):  # type: ignore[possibly-undefined]
+                schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"  # type: ignore[possibly-undefined]
         elif name in mutates_args:
-            if not schema_type.startswith("Tensor"):
+            if not schema_type.startswith("Tensor"):  # type: ignore[possibly-undefined]
                 error_fn(
                     f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated"
                 )
-            schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
+            schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"  # type: ignore[possibly-undefined]
         seen_args.add(name)
         if param.default is inspect.Parameter.empty:
-            params.append(f"{schema_type} {name}")
+            params.append(f"{schema_type} {name}")  # type: ignore[possibly-undefined]
         else:
             default_repr = None
             if param.default is None or isinstance(param.default, (int, float, bool)):

@@ -176,7 +179,7 @@ def unstringify_type(ty: Union[type[object], str]) -> tuple[typing.Any, bool]:
                     f"Parameter {name} has an unsupported default value type {type(param.default)}. "
                     f"Please file an issue on GitHub so we can prioritize this."
                 )
-            params.append(f"{schema_type} {name}={default_repr}")
+            params.append(f"{schema_type} {name}={default_repr}")  # type: ignore[possibly-undefined]
     if mutates_args != UNKNOWN_MUTATES:
         mutates_args_not_seen = set(mutates_args) - seen_args
         if len(mutates_args_not_seen) > 0:
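A runnable example (on stock PyTorch) of the alias-annotated schema strings this code builds for mutated parameters; `mylib::fill2_` is a made-up op name:

```python
import torch

@torch.library.custom_op("mylib::fill2_", mutates_args=("x",))
def fill2_(x: torch.Tensor) -> None:
    x.fill_(2)

# The mutated parameter is rendered with an alias annotation, which is what
# the schema_type rewriting above produces:
print(torch.ops.mylib.fill2_.default._schema)
# mylib::fill2_(Tensor(a0!) x) -> ()
```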
Review comment: Some high-level comments:
Reply: Agree! I'll do the perf benchmark while you conduct code review.