Custom ops support arbitrary input types by migrating to python dispatcher by yanboliang · Pull Request #147927 · pytorch/pytorch

Custom ops support arbitrary input types by migrating to python dispatcher #147927


Open: yanboliang wants to merge 11 commits into main

Conversation

Contributor @yanboliang commented Feb 26, 2025

Test case:

@torch.library.custom_op("mylib::foo", mutates_args=())
def foo(d: dict, t: torch.Tensor) -> torch.Tensor:
    return torch.sin(d["x"] - d["y"] + t)


@foo.register_fake
def _(d: dict, t: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(d["x"])

d = {"x": torch.randn(2, 3, requires_grad=True), "y": torch.randn(2, 3, requires_grad=True)}
t = torch.randn(2, 3, requires_grad=True)

@torch.compile(backend="eager", fullgraph=True)
def fn(d, t):
    return torch.sin(torch.ops.mylib.foo.default(d, t) + 1.5)

y = fn(d, t)
print(y)
y.sum().backward()
print(d["x"].grad)
print(d["y"].grad)
print(t.grad)

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

pytorch-bot bot commented Feb 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147927

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 14c0439 with merge base 84e60ee:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@yanboliang changed the title [Prototype] Custom ops support arbitrary input types by migrating to python dispatcher → Custom ops support arbitrary input types by migrating to python dispatcher Mar 5, 2025

x = Point(x=torch.randn(2, 3), y=torch.randn(2, 3))
y = torch.ops.mylib.foo(x)
self.assertEqual(y, torch.sqrt(torch.sum((x.x - x.y) ** 2)))
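For context, a hypothetical definition that would make this excerpt self-contained: a Point container registered as a pytree node so the custom op can accept it. The actual test defines its own Point; all names below are illustrative.

from dataclasses import dataclass

import torch
import torch.utils._pytree as pytree

@dataclass
class Point:
    x: torch.Tensor
    y: torch.Tensor

# Register Point as a pytree node so torch.library can flatten/unflatten it.
pytree.register_pytree_node(
    Point,
    lambda p: ([p.x, p.y], None),            # flatten: children + context
    lambda children, ctx: Point(*children),  # unflatten
)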
Contributor

Some high-level comments:

  1. Let's wait until after the branch cut (Monday) to merge this, assuming it's ready before then. We don't want this feature to land partially in PyTorch 2.7.
  2. Eager-mode performance is pretty important. Can you do some benchmarking comparing, e.g., custom ops with dict inputs (the new path) to custom ops with list inputs (the C++ dispatcher)? (A rough sketch follows below.)
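A rough shape for such a comparison, as referenced above. The op names are made up, and the dict-input op only registers successfully on this PR's branch:

import timeit

import torch

# Hypothetical ops: foo_dict exercises the new Python-dispatcher path,
# foo_list the existing C++ dispatcher path.
@torch.library.custom_op("mylib::foo_dict", mutates_args=())
def foo_dict(d: dict, t: torch.Tensor) -> torch.Tensor:
    return torch.sin(d["x"] - d["y"] + t)

@torch.library.custom_op("mylib::foo_list", mutates_args=())
def foo_list(xs: list[torch.Tensor], t: torch.Tensor) -> torch.Tensor:
    return torch.sin(xs[0] - xs[1] + t)

d = {"x": torch.randn(2, 3), "y": torch.randn(2, 3)}
xs = [d["x"], d["y"]]
t = torch.randn(2, 3)

dict_s = timeit.timeit(lambda: torch.ops.mylib.foo_dict(d, t), number=10_000)
list_s = timeit.timeit(lambda: torch.ops.mylib.foo_list(xs, t), number=10_000)
print(f"dict input (Python dispatcher): {dict_s:.3f}s")
print(f"list input (C++ dispatcher):    {list_s:.3f}s")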

Contributor Author

Agreed! I'll run the perf benchmark while you do the code review.

Contributor @zou3519 left a comment

My main comments are:

  1. Eager-mode performance: let's make sure this is good.
  2. The fallthrough mechanism isn't faithful to the C++ PyTorch dispatcher. We should make it more faithful.
  3. The FakeTensor registration mechanism is on the sketchy side. In particular, py_impl(FakeTensorMode) completely bypasses FakeTensor caching. Maybe something to handle as a follow-up.

if annotation_type.__origin__ is tuple:
if annotation_type in torch.utils._pytree.SUPPORTED_NODES:
    # TODO: Move to a separate schema type for pytrees.
    schema_type = "Any"
Contributor

When are we going to change this?

Contributor

Tbh I kind of like having "Any" in the type; we could keep it for now.

@@ -46,6 +47,23 @@ def dl_open_guard():
    sys.setdlopenflags(old_flags)


class Kernel:
Contributor

Put this somewhere else, maybe in torch/_dispatch or torch/_library. Otherwise I think this gets exposed as torch.ops.Kernel (things in torch._ops behave a little weirdly).

    It is the thing that is called when you call the operator.
    """

    def __init__(self, func, with_keyset=False):
Contributor

nit: make with_keyset kwarg-only for readability
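I.e., a sketch of the suggested signature:

def __init__(self, func, *, with_keyset: bool = False):
    self.func = func
    self.with_keyset = with_keyset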

Comment on lines +1292 to +1294
if _has_pytree_type_in_args_or_kwargs(args, kwargs):
    op_overload = getattr(op, op.overloads()[0])
    return op_overload(*args, **kwargs)
Contributor

  1. This is probably slow and needs caching on the OpOverload (see the sketch below).
  2. This is wrong if someone tries to register multiple OpOverloads for one OpOverloadPacket, so we should error out in that situation.
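A minimal sketch of what both points could look like together. The cache attribute name is made up:

def _resolve_single_overload(op):
    # Point 1: cache the resolved OpOverload on the packet instead of
    # calling op.overloads() + getattr on every invocation.
    cached = getattr(op, "_pytree_overload_cache", None)  # hypothetical attr
    if cached is None:
        overloads = op.overloads()
        # Point 2: multiple overloads make this resolution ambiguous.
        if len(overloads) != 1:
            raise RuntimeError(
                f"{op}: pytree inputs require exactly one overload, "
                f"got {overloads}"
            )
        cached = getattr(op, overloads[0])
        op._pytree_overload_cache = cached
    return cached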

Comment on lines +626 to +628
need_python_dispatch = isinstance(
    self._opoverload, torch._ops.PythonDispatcherOpOverload
)
Contributor

nit: should probably make this a helper function
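E.g. (the helper name here is hypothetical):

def _uses_python_dispatch(opoverload) -> bool:
    return isinstance(opoverload, torch._ops.PythonDispatcherOpOverload)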

Comment on lines +646 to +649
if need_python_dispatch:
    self._opoverload.py_impl(torch._subclasses.fake_tensor.FakeTensorMode)(
        fake_impl
    )
Contributor

NB: this is kind of a hack

Comment on lines +672 to +681
    self._opoverload.py_impl(_C.DispatchKey.ADInplaceOrView)(
        adinplaceorview_impl
    )
else:
    lib.impl(
        self._name,
        adinplaceorview_impl,
        "ADInplaceOrView",
        with_keyset=True,
    )
Contributor

I feel like we should align the API of py_impl and lib.impl. Otherwise it's kind of annoying

Contributor

Doesn't need to happen in this PR, just a suggestion
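One possible shape for that alignment, as a sketch. register_kernel and its argument names are invented here; only py_impl and Library.impl are real:

def register_kernel(opoverload, lib, name, dispatch_key, fn, *, with_keyset=False):
    # Route to whichever registration mechanism the op actually uses,
    # behind a single calling convention.
    if isinstance(opoverload, torch._ops.PythonDispatcherOpOverload):
        opoverload.py_impl(dispatch_key)(fn)
    else:
        lib.impl(name, fn, dispatch_key.name, with_keyset=with_keyset)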

]

def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
    if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key):
Contributor

If there's a kernel in py_impl as well, then we should not use the fallthrough either
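One way to read this suggestion; py_kernels is assumed to be the mapping that py_impl populates, and the original function body is not shown in full here:

def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
    # A key that has either a C++ kernel or a Python py_impl kernel
    # registered should not be replaced by a fallthrough.
    has_cpp_kernel = torch._C._dispatch_has_kernel_for_dispatch_key(
        self.name(), key
    )
    has_py_kernel = key in self.py_kernels
    return not (has_cpp_kernel or has_py_kernel)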

Comment on lines +955 to +962
# TODO: we should be calling the fallback for these, but a fallthrough is almost close
# enough to the fallback in most cases that we care about.
_DEFAULT_FALLTHROUGH_KEYS = [
    DispatchKey.ADInplaceOrView,
    DispatchKey.BackendSelect,
    DispatchKey.PythonTLSSnapshot,
    DispatchKey.PythonDispatcher,
]
Contributor

We should try to model fallthroughs closer to how the C++ dispatcher does it. That is, the operator doesn't come with fallthrough keys; instead, these are fallbacks.

Contributor

(maybe in a follow-up)
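A sketch of what "fallbacks rather than per-operator fallthrough keys" might look like in Python; all names below are hypothetical:

from typing import Callable, Dict

from torch._C import DispatchKey

# Dispatcher-level fallback table, mirroring C++ backend fallbacks: a
# fallback runs for every operator at `key` unless the operator has its
# own kernel there, whereas a fallthrough just skips to the next key.
_python_fallbacks: Dict[DispatchKey, Callable] = {}

def register_py_fallback(key: DispatchKey, fn: Callable) -> None:
    _python_fallbacks[key] = fn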


def redispatch(self, /, keyset, *args, **kwargs):
    return self._dispatch_in_python(args, kwargs, self._fallthrough_keys())

Contributor @zou3519 Mar 6, 2025

At some point I want to try an exercise of "refactor the python dispatcher to look more like the C++ dispatcher". Concretely:

  • each operator has an OperatorEntry
  • there is a mapping from DispatchKey->kernel on the OperatorEntry
  • OperatorEntry has a DispatchKeyExtractor
  • registering a fallback as a py_kernel should modify the DispatchKeyExtractor
  • one can use the DispatchKeyExtractor to get the next DispatchKey
  • DispatchKeys can have fallbacks in Python
  • etc

Now my question for you is: are you interested in doing this? It could be a good learning experience. If not, I'm happy to try this refactor. (A rough sketch of this shape follows.)
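A rough sketch of that shape, with every name hypothetical:

from dataclasses import dataclass, field
from typing import Callable, Dict

from torch._C import DispatchKey

_fallbacks: Dict[DispatchKey, Callable] = {}  # per-key fallbacks, as above

class DispatchKeyExtractor:
    # Computes the highest-priority dispatch key from the arguments;
    # registering a fallback would also update this extractor's state.
    def compute_key(self, args, kwargs) -> DispatchKey: ...

@dataclass
class OperatorEntry:
    name: str
    kernels: Dict[DispatchKey, Callable] = field(default_factory=dict)
    extractor: DispatchKeyExtractor = field(default_factory=DispatchKeyExtractor)

    def lookup(self, key: DispatchKey) -> Callable:
        # Per-operator kernel first, then the dispatcher-level fallback.
        return self.kernels.get(key) or _fallbacks[key]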

Comment on lines +1028 to +1033
def _has_pytree_type_in_args_or_kwargs(args, kwargs) -> bool:
    return any(
        not isinstance(x, (list, tuple))
        and type(x) in torch.utils._pytree.SUPPORTED_NODES
        for x in itertools.chain(args, kwargs.values())
    )
Collaborator @XuehaiPan Mar 6, 2025

Suggested change (the original implementation is the excerpt shown above):

def _has_pytree_type_in_args_or_kwargs(args, kwargs) -> bool:
    def is_list_or_tuple(x):
        return isinstance(x, (list, tuple))

    return any(
        not torch.utils._pytree.tree_is_leaf(x, is_leaf=is_list_or_tuple)
        for x in itertools.chain(args, kwargs.values())
    )

You can use tree_is_leaf after #113257 gets merged into main.

        self.func = func
        self.with_keyset = with_keyset

    def __call__(self, *args, **kwargs):
Collaborator

Suggested change:

    def __call__(self, /, *args, **kwargs):

Allow 'self' as a key in kwargs.
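Making self positional-only means a keyword argument literally named "self" lands in **kwargs instead of colliding with the bound instance, e.g. (assuming the PR's Kernel class and some kernel function my_func):

k = Kernel(my_func)        # my_func: any Python kernel function (hypothetical)
k(self=torch.randn(2, 3))  # OK: collected into kwargs["self"]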

@zou3519 added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Mar 17, 2025
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions bot added the Stale label May 16, 2025
Labels: ciflow/inductor, module: dynamo, open source, Stale, topic: not user facing, triaged
4 participants