Custom ops support arbitrary input types by migrating to python dispatcher #147927
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147927
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 14c0439 with merge base 84e60ee.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
x = Point(x=torch.randn(2, 3), y=torch.randn(2, 3))
y = torch.ops.mylib.foo(x)
self.assertEqual(y, torch.sqrt(torch.sum((x.x - x.y) ** 2)))
Some high-level comments:
- Let's wait until after the branch cut (Monday) to merge this, assuming it's ready before then. We don't want this feature to be partially in PyTorch 2.7.
- Eager-mode performance is pretty important. Can you do some benchmarking comparing, e.g., custom ops with dict inputs (which use the new path) to custom ops with list inputs (which use the C++ dispatcher)?
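A rough sketch of the kind of comparison being asked for, not code from this PR. It assumes dict-typed arguments are accepted by torch.library.custom_op once this change lands; the bench::* op names and tensor shapes are made up, and the list-typed op serves as the C++-dispatcher baseline.

import timeit
from typing import Dict, List

import torch
from torch import Tensor


# Baseline: list input, served by the C++ dispatcher today.
@torch.library.custom_op("bench::list_op", mutates_args=())
def list_op(xs: List[Tensor]) -> Tensor:
    return xs[0] + xs[1]


# Hypothetical: dict input, which would route through the new Python-dispatcher path.
@torch.library.custom_op("bench::dict_op", mutates_args=())
def dict_op(xs: Dict[str, Tensor]) -> Tensor:
    return xs["a"] + xs["b"]


a, b = torch.randn(8), torch.randn(8)
print("list:", timeit.timeit(lambda: torch.ops.bench.list_op([a, b]), number=10_000))
print("dict:", timeit.timeit(lambda: torch.ops.bench.dict_op({"a": a, "b": b}), number=10_000))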
Agreed! I'll do the perf benchmark while you conduct the code review.
My main comments are:
- Eager-mode performance. Let's make sure this is good.
- The fallthrough mechanism isn't faithful to the C++ PyTorch dispatcher. We should make it more faithful.
- The FakeTensor registration mechanism is on the sketchy side. In particular, the py_impl(FakeTensorMode) registration completely bypasses FakeTensor caching. Maybe something to handle as a follow-up.
if annotation_type.__origin__ is tuple:
if annotation_type in torch.utils._pytree.SUPPORTED_NODES:
    # TODO: Move to a separate schema type for pytrees.
    schema_type = "Any"
when are we going to change this?
tbh I kind of like having "Any" in the type; we could keep it for now
@@ -46,6 +47,23 @@ def dl_open_guard():
    sys.setdlopenflags(old_flags)


class Kernel:
Put this somewhere else, maybe in torch/_dispatch or torch/_library. Otherwise I think this gets exposed as torch.ops.Kernel (things in torch._ops behave a little weirdly).
    It is the thing that is called when you call the operator.
    """

    def __init__(self, func, with_keyset=False):
nit: make with_keyset kwarg-only for readability
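A minimal sketch of that nit, assuming the Kernel constructor quoted above:

class Kernel:
    """Wraps the callable that is invoked when the operator is called."""

    def __init__(self, func, *, with_keyset: bool = False):  # keyword-only flag
        self.func = func
        self.with_keyset = with_keyset

# Call sites then spell the flag out, e.g. Kernel(adinplaceorview_impl, with_keyset=True).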
if _has_pytree_type_in_args_or_kwargs(args, kwargs):
    op_overload = getattr(op, op.overloads()[0])
    return op_overload(*args, **kwargs)
- this is probably slow and needs caching on the OpOverload
- this is wrong if someone tries to register multiple OpOverloads for one OpOverloadPacket, so we should error out in that situation
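A sketch of both points under stated assumptions: the helper name is invented, and it relies on OpOverloadPacket objects being hashable (they are plain Python objects) so the resolved overload can be memoized.

import functools

import torch


@functools.lru_cache(maxsize=None)
def _resolve_single_overload(packet: "torch._ops.OpOverloadPacket") -> "torch._ops.OpOverload":
    overloads = packet.overloads()
    if len(overloads) != 1:
        # Pytree-typed args cannot disambiguate between overloads, so refuse.
        raise RuntimeError(
            f"{packet} has {len(overloads)} overloads; expected exactly one "
            "when dispatching pytree-typed arguments"
        )
    return getattr(packet, overloads[0])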
need_python_dispatch = isinstance(
    self._opoverload, torch._ops.PythonDispatcherOpOverload
)
nit: should probably make this a helper function
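One possible shape of that helper (the name is invented; it assumes the surrounding module already imports torch):

def _needs_python_dispatch(opoverload) -> bool:
    # True when the op is served by the Python dispatcher path introduced in
    # this PR rather than the C++ dispatcher.
    return isinstance(opoverload, torch._ops.PythonDispatcherOpOverload)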
if need_python_dispatch:
    self._opoverload.py_impl(torch._subclasses.fake_tensor.FakeTensorMode)(
        fake_impl
    )
NB: this is kind of a hack
    self._opoverload.py_impl(_C.DispatchKey.ADInplaceOrView)(
        adinplaceorview_impl
    )
else:
    lib.impl(
        self._name,
        adinplaceorview_impl,
        "ADInplaceOrView",
        with_keyset=True,
    )
I feel like we should align the API of py_impl and lib.impl. Otherwise it's kind of annoying
Doesn't need to happen in this PR, just a suggestion
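Purely illustrative sketch of what an aligned entry point could look like; the function name is invented and keyset plumbing on the py_impl branch is elided.

import torch


def _register_impl(opoverload, lib, name, kernel, dispatch_key: str, *, with_keyset: bool = False):
    # Hypothetical shim: one signature, routed to either py_impl or Library.impl.
    if isinstance(opoverload, torch._ops.PythonDispatcherOpOverload):
        opoverload.py_impl(getattr(torch._C.DispatchKey, dispatch_key))(kernel)
    else:
        lib.impl(name, kernel, dispatch_key, with_keyset=with_keyset)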
]


def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
    if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key):
If there's a kernel in py_impl as well, then we should not use the fallthrough either
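A sketch of how the check could incorporate that, mirroring the analogous logic already in torch/_ops.py. It is written as it would appear inside the op's method scope (hence self), and assumes a py_kernels mapping (DispatchKey -> kernel) on the op plus the torch.library.fallthrough_kernel sentinel.

def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
    # If the C++ dispatcher has a kernel for this key, only fall through when
    # that kernel is itself a fallthrough.
    if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key):
        return torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
            self.name(), key
        )
    # A kernel registered via py_impl also blocks the fallthrough, unless it is
    # explicitly the fallthrough sentinel.
    return (
        key not in self.py_kernels
        or self.py_kernels[key] is torch.library.fallthrough_kernel
    )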
# TODO: we should be calling the fallback for these, but a fallthrough is almost close
# enough to the fallback in most cases that we care about.
_DEFAULT_FALLTHROUGH_KEYS = [
    DispatchKey.ADInplaceOrView,
    DispatchKey.BackendSelect,
    DispatchKey.PythonTLSSnapshot,
    DispatchKey.PythonDispatcher,
]
We should try to model fallthroughs closer to how the C++ dispatcher does it. That is, the operator doesn't come with fallthrough keys; instead, these are fallbacks.
(maybe in a follow-up)
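A hypothetical sketch of what "fallbacks, not per-operator fallthrough keys" could look like in Python: a per-DispatchKey fallback table consulted when an operator has no kernel for a key. All names here are invented.

from typing import Callable, Dict

from torch._C import DispatchKey

_python_fallbacks: Dict[DispatchKey, Callable] = {}


def register_python_fallback(key: DispatchKey, fn: Callable) -> None:
    # A fallback belongs to the dispatch key, not to any particular operator.
    _python_fallbacks[key] = fn


def _kernel_or_fallback(op, key: DispatchKey):
    # Operator-specific kernel wins; otherwise use the key's registered fallback.
    kernel = op.py_kernels.get(key)
    return kernel if kernel is not None else _python_fallbacks.get(key)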
def redispatch(self, /, keyset, *args, **kwargs):
    return self._dispatch_in_python(args, kwargs, self._fallthrough_keys())
At some point I want to try an exercise of "refactor the Python dispatcher to look more like the C++ dispatcher". Concretely:
- each operator has an OperatorEntry
- there is a mapping from DispatchKey -> kernel on the OperatorEntry
- the OperatorEntry has a DispatchKeyExtractor
- registering a fallback as a py_kernel should modify the DispatchKeyExtractor
- one can use the DispatchKeyExtractor to get the next DispatchKey
- DispatchKeys can have fallbacks in Python
- etc.
Now my question for you is: are you interested in doing this? It could be a good learning experience. If not, I'm happy to try this refactor.
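A very rough, self-contained skeleton of the shape described above. The class names follow the reviewer's list of C++-dispatcher concepts; the implementation itself is invented and does not mirror existing torch internals.

from dataclasses import dataclass, field
from typing import Callable, Dict, List, Set

from torch._C import DispatchKey


@dataclass
class DispatchKeyExtractor:
    # Fallthrough behavior lives on the extractor, as in the C++ dispatcher,
    # rather than as a per-operator list of keys.
    fallthrough_keys: Set[DispatchKey] = field(default_factory=set)

    def next_key(self, candidate_keys: List[DispatchKey]) -> DispatchKey:
        for key in candidate_keys:
            if key not in self.fallthrough_keys:
                return key
        raise RuntimeError("no non-fallthrough dispatch key available")


@dataclass
class OperatorEntry:
    # Per-operator DispatchKey -> kernel table, per-key fallbacks, and the extractor.
    kernels: Dict[DispatchKey, Callable] = field(default_factory=dict)
    fallbacks: Dict[DispatchKey, Callable] = field(default_factory=dict)
    extractor: DispatchKeyExtractor = field(default_factory=DispatchKeyExtractor)

    def lookup(self, candidate_keys: List[DispatchKey]) -> Callable:
        key = self.extractor.next_key(candidate_keys)
        return self.kernels.get(key, self.fallbacks.get(key))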
def _has_pytree_type_in_args_or_kwargs(args, kwargs) -> bool:
    return any(
        not isinstance(x, (list, tuple))
        and type(x) in torch.utils._pytree.SUPPORTED_NODES
        for x in itertools.chain(args, kwargs.values())
    )
Suggested change:
-def _has_pytree_type_in_args_or_kwargs(args, kwargs) -> bool:
-    return any(
-        not isinstance(x, (list, tuple))
-        and type(x) in torch.utils._pytree.SUPPORTED_NODES
-        for x in itertools.chain(args, kwargs.values())
-    )
+def _has_pytree_type_in_args_or_kwargs(args, kwargs) -> bool:
+    def is_list_or_tuple(x):
+        return isinstance(x, (list, tuple))
+
+    return any(
+        not torch.utils._pytree.tree_is_leaf(x, is_leaf=is_list_or_tuple)
+        for x in itertools.chain(args, kwargs.values())
+    )
You can use tree_is_leaf after #113257 gets merged into main.
        self.func = func
        self.with_keyset = with_keyset

    def __call__(self, *args, **kwargs):
Suggested change:
-    def __call__(self, *args, **kwargs):
+    def __call__(self, /, *args, **kwargs):
Allow 'self' as a key in kwargs.
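A small demonstration of why the `/` matters: it makes self positional-only, so a caller can pass a keyword argument literally named "self" through **kwargs. The Wrapper class below is a stand-in for illustration, not this PR's Kernel.

class Wrapper:
    def __init__(self, func):
        self.func = func

    def __call__(self, /, *args, **kwargs):
        return self.func(*args, **kwargs)


w = Wrapper(lambda **kw: kw)
assert w(self="ok") == {"self": "ok"}  # would raise TypeError without the `/`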
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as Stale.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames