Custom ops support arbitrary input types by migrating to python dispatcher by yanboliang · Pull Request #147927 · pytorch/pytorch

Open · wants to merge 11 commits into base: main

Changes from all commits
39 changes: 39 additions & 0 deletions test/dynamo/test_misc.py
@@ -282,6 +282,45 @@ def g(x):
g(x) # dynamo falls back on the outermost frame
self.assertEqual(len(counters["graph_break"]), 0)

def test_custom_op_pytree_input(self):
@torch.library.custom_op("mylib::foo", mutates_args=())
def foo(d: dict, t: torch.Tensor) -> torch.Tensor:
return torch.sin(d["x"] - d["y"] + t)

@foo.register_fake
def _(a: dict, t: torch.Tensor) -> torch.Tensor:
return torch.empty_like(t)

d = {"x": torch.randn(2, 3), "y": torch.randn(2, 3)}
t = torch.randn(2, 3)

cnt = CompileCounterWithBackend("eager")

@torch.compile(backend=cnt, fullgraph=True)
def fn(d, t):
return torch.ops.mylib.foo(d, t)

self.assertEqual(fn(d, t), torch.sin(d["x"] - d["y"] + t))

if torch._dynamo.config.assume_static_by_default:
actual_graph = torch._dynamo.testing.normalize_gm(
cnt.graphs[0].print_readable(print_output=False)
)
self.assertExpectedInline(
actual_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_t_: "f32[2, 3]", L_d_y_: "f32[2, 3]", L_d_x_: "f32[2, 3]"):
l_t_ = L_t_
l_d_y_ = L_d_y_
l_d_x_ = L_d_x_

mylib_foo_input_spec : torch.utils._pytree.TreeSpec = self.mylib_foo_input_spec
flat_apply: "f32[2, 3]" = torch.ops.higher_order.flat_apply(torch.ops.mylib.foo.default, mylib_foo_input_spec, l_d_x_, l_d_y_, l_t_); mylib_foo_input_spec = l_d_x_ = l_d_y_ = l_t_ = None
return (flat_apply,)
""",
)

def test_invalid_args_builtin(self):
@torch.compile(backend="eager")
def fn(x):
30 changes: 30 additions & 0 deletions test/test_custom_ops.py
@@ -2,6 +2,7 @@
# ruff: noqa: F841

import collections
import dataclasses
import itertools
import os
import re
@@ -50,6 +51,12 @@
MyTensor = torch.Tensor


@dataclasses.dataclass
class Point:
x: torch.Tensor
y: torch.Tensor


def requires_compile(fun):
fun = unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")(fun)
return fun
@@ -2460,6 +2467,29 @@ def foo3(x: Tensor, *, y: int, z: List[Tensor]) -> Tensor:
lambda info, in_dims, x, *, y: (x, 0),
)

@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_dict_input(self):
@torch.library.custom_op("mylib::foo", mutates_args=())
def foo(d: dict, t: torch.Tensor) -> torch.Tensor:
return torch.sin(d["x"] - d["y"] + t)

d = {"x": torch.randn(2, 3), "y": torch.randn(2, 3)}
t = torch.randn(2, 3)
y = torch.ops.mylib.foo(d, t)
self.assertEqual(y, torch.sin(d["x"] - d["y"] + t))

@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_dataclass_input(self):
torch.utils._pytree.register_dataclass(Point)

@torch.library.custom_op("mylib::foo", mutates_args=())
def foo(a: Point) -> torch.Tensor:
return torch.sqrt(torch.sum((a.x - a.y) ** 2))

x = Point(x=torch.randn(2, 3), y=torch.randn(2, 3))
y = torch.ops.mylib.foo(x)
self.assertEqual(y, torch.sqrt(torch.sum((x.x - x.y) ** 2)))
Contributor:

Some high-level comments:

  1. Let's wait until after the branch cut (Monday) to merge this, assuming it's ready before then. We don't want this feature to be partially in PyTorch 2.7.
  2. Eager-mode performance is pretty important. Can you do some benchmarking comparing, e.g., custom ops with dict input (which uses the new path) to custom ops with list inputs (which use the C++ dispatcher)?

Contributor Author:

Agree! I'll do the perf benchmark while you conduct the code review.
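
For reference, a micro-benchmark along the lines suggested above could look roughly like this (a sketch only: the op names, shapes, and iteration counts are arbitrary, and the dict-input op assumes the support added in this PR):

```python
import timeit
from typing import List

import torch


# Dict input: goes through the new python-dispatcher path added in this PR.
@torch.library.custom_op("mylib::bench_dict", mutates_args=())
def bench_dict(d: dict, t: torch.Tensor) -> torch.Tensor:
    return torch.sin(d["x"] - d["y"] + t)


# List input: already expressible in the schema, dispatched via the C++ dispatcher.
@torch.library.custom_op("mylib::bench_list", mutates_args=())
def bench_list(xs: List[torch.Tensor], t: torch.Tensor) -> torch.Tensor:
    return torch.sin(xs[0] - xs[1] + t)


x, y, t = torch.randn(2, 3), torch.randn(2, 3), torch.randn(2, 3)

dict_s = timeit.timeit(lambda: torch.ops.mylib.bench_dict({"x": x, "y": y}, t), number=10_000)
list_s = timeit.timeit(lambda: torch.ops.mylib.bench_list([x, y], t), number=10_000)
print(f"dict input (python dispatcher): {dict_s:.3f}s")
print(f"list input (C++ dispatcher):    {list_s:.3f}s")
```

Comparable timings would suggest the python-dispatcher path adds little eager-mode overhead; a large gap would point at the Python-side dispatch as the place to optimize.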


@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_register_autograd_kwargonly_low_level(self):
with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
1 change: 1 addition & 0 deletions torch/_dynamo/trace_rules.py
@@ -3225,6 +3225,7 @@ def _module_dir(m: types.ModuleType):
LEGACY_MOD_INLINELIST = {
"torch._dynamo.external_utils",
"torch._export.db.examples",
"torch._export.utils",
"torch._export.wrappers",
"torch._functorch.apis",
"torch._functorch.deprecated",
56 changes: 55 additions & 1 deletion torch/_dynamo/variables/torch.py
@@ -40,6 +40,7 @@
import torch._refs
import torch.fx
import torch.nn
import torch.utils._pytree as pytree
from torch._guards import TracingContext
from torch._logging import warning_once
from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type
@@ -209,6 +210,14 @@ def get_overridable_functions():
return funcs


def is_python_dispatcher_op_overload(op):
return (
isinstance(op, torch._ops.PythonDispatcherOpOverload)
or isinstance(op, torch._ops.OpOverloadPacket)
and op._has_pytree_arg_overload
)


class BaseTorchVariable(VariableTracker):
"""common base for all torch.* functions, classes, modules and other things"""

@@ -982,11 +991,12 @@ def call_function(
args: Sequence[VariableTracker],
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
import torch._higher_order_ops.flat_apply as flat_apply

from . import ConstantVariable, SymNodeVariable, TensorVariable
from .builder import wrap_fx_proxy

if self.nonstrict_traceable:
import torch._higher_order_ops.flat_apply as flat_apply
from torch._higher_order_ops.flat_apply import (
func_to_graphable,
is_graphable_type,
@@ -1112,6 +1122,50 @@ def patched_fn(*args, **kwargs):
if result:
return result

# If the function is a custom op, we need to wrap it as a flat_apply call in the fx graph.
if is_python_dispatcher_op_overload(self.value):
if isinstance(self.value, torch._ops.OpOverload):
opoverload = self.value
else:
assert isinstance(self.value, torch._ops.OpOverloadPacket)
opoverload = self.value.default
packed_input_vt = TupleVariable.build(
tx,
(
variables.TupleVariable.build(tx, args),
variables.ConstDictVariable.build(tx, kwargs),
),
)
flat_args_and_spec = variables.UserFunctionVariable(
pytree.tree_flatten
).call_function(tx, [packed_input_vt], {})

assert isinstance(flat_args_and_spec, variables.TupleVariable)
flat_args_vt, in_spec_vt = flat_args_and_spec.items
assert isinstance(flat_args_vt, variables.ListVariable)
in_spec = in_spec_vt.as_python_constant()
in_spec_proxy = tx.output.register_static_attr_and_return_proxy(
f"{opoverload._namespace}_{opoverload._opname}_input_spec", in_spec
)

proxified_flat_args = [
flat_arg_vt.as_proxy() for flat_arg_vt in flat_args_vt.items
]

in_spec_proxy.node.type = type(in_spec)
all_args = (opoverload, in_spec_proxy, *proxified_flat_args)

out_vt = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
flat_apply,
all_args,
{},
),
)
return out_vt

any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)

all_ints_or_floats = all(
67 changes: 50 additions & 17 deletions torch/_library/custom_ops.py
@@ -346,16 +346,30 @@ def get_module():
)
return result

need_python_dispatch = hasattr(self, "_opoverload") and isinstance(
self._opoverload, torch._ops.PythonDispatcherOpOverload
)

if device_type is None:
self._lib.impl(
self._name, backend_impl, "CompositeExplicitAutograd"
)
if need_python_dispatch:
self._opoverload.py_impl(
_C.DispatchKey.CompositeExplicitAutograd
)(backend_impl)
else:
self._lib.impl(
self._name, backend_impl, "CompositeExplicitAutograd"
)
else:
self._lib.impl(
self._name,
backend_impl,
_C._dispatch_key_for_device(device_type),
)
if need_python_dispatch:
dispatch_key = _C._dispatch_key_for_device(device_type)
dispatch_key = getattr(_C.DispatchKey, dispatch_key)
self._opoverload.py_impl(dispatch_key)(backend_impl)
else:
self._lib.impl(
self._name,
backend_impl,
_C._dispatch_key_for_device(device_type),
)

# Wrap function to choose between the default implementation or the device-specific
# implementation depending on if the kernel is disabled.
@@ -609,6 +623,10 @@ def _register_to_dispatcher(self) -> None:
)
self._opoverload = utils.lookup_op(self._qualname)

need_python_dispatch = isinstance(
self._opoverload, torch._ops.PythonDispatcherOpOverload
)
Comment on lines +626 to +628

Contributor:
nit: should probably make this a helper function
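
For reference, the suggested helper could be as small as the following (a sketch; the helper name is illustrative, and `PythonDispatcherOpOverload` is the class this PR introduces):

```python
def _needs_python_dispatch(opdef) -> bool:
    # True once the op has been registered through the python dispatcher,
    # i.e. it takes pytree/arbitrary inputs. register_kernel and
    # _register_to_dispatcher currently repeat this check inline.
    return hasattr(opdef, "_opoverload") and isinstance(
        opdef._opoverload, torch._ops.PythonDispatcherOpOverload
    )
```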


def fake_impl(*args, **kwargs):
if self._abstract_fn is None:
if utils.can_generate_trivial_fake_impl(self._opoverload):
@@ -619,12 +637,22 @@ def fake_impl(*args, **kwargs):
f"Please use `{self._init_fn.__name__}.register_fake` to add an "
f"fake impl."
)
if need_python_dispatch:
args = args[1:] # Remove the dispatch key under fake tensor mode.
return self._abstract_fn(*args, **kwargs)

lib._register_fake(self._name, fake_impl, _stacklevel=4)

autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)

if need_python_dispatch:
self._opoverload.py_impl(torch._subclasses.fake_tensor.FakeTensorMode)(
fake_impl
)
Comment on lines +646 to +649

Contributor:
NB: this is kind of a hack

self._opoverload.py_impl(_C.DispatchKey.Autograd, with_keyset=True)(
autograd_impl
)
else:
lib._register_fake(self._name, fake_impl, _stacklevel=4)
lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)

schema = self._opoverload._schema
if schema.is_mutable:
@@ -640,12 +668,17 @@ def adinplaceorview_impl(keyset, *args, **kwargs):
keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
)

lib.impl(
self._name,
adinplaceorview_impl,
"ADInplaceOrView",
with_keyset=True,
)
if need_python_dispatch:
self._opoverload.py_impl(_C.DispatchKey.ADInplaceOrView)(
adinplaceorview_impl
)
else:
lib.impl(
self._name,
adinplaceorview_impl,
"ADInplaceOrView",
with_keyset=True,
)
Comment on lines +672 to +681

Contributor:
I feel like we should align the API of py_impl and lib.impl. Otherwise it's kind of annoying

Contributor:
Doesn't need to happen in this PR, just a suggestion
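
Until the two APIs are aligned, a thin adapter local to this file could hide the branching (purely illustrative, not part of this PR; the helper name and signature are made up, and it assumes `py_impl` accepts `with_keyset` the way it is used elsewhere in this diff):

```python
def _impl_either_dispatcher(
    opoverload, lib, name, kernel, dispatch_key: str, *, with_keyset: bool = False
):
    # Pytree-input ops live behind the python dispatcher; everything else keeps
    # using the C++ dispatcher via torch.library.Library.impl.
    if isinstance(opoverload, torch._ops.PythonDispatcherOpOverload):
        key = getattr(torch._C.DispatchKey, dispatch_key)
        if with_keyset:
            opoverload.py_impl(key, with_keyset=True)(kernel)
        else:
            opoverload.py_impl(key)(kernel)
    else:
        lib.impl(name, kernel, dispatch_key, with_keyset=with_keyset)
```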


def _register_backend_select_dispatcher(self, device_arg_index: int):
"""
21 changes: 12 additions & 9 deletions torch/_library/infer_schema.py
@@ -124,7 +124,10 @@ def unstringify_type(ty: Union[type[object], str]) -> tuple[typing.Any, bool]:
annotation_type, _ = unstringify_type(param.annotation)

if annotation_type not in SUPPORTED_PARAM_TYPES:
if annotation_type.__origin__ is tuple:
if annotation_type in torch.utils._pytree.SUPPORTED_NODES:
# TODO: Move to a separate schema type for pytrees.
schema_type = "Any"
Contributor:
when are we going to change this?

Contributor:
tbh I kind of like having "Any" in the type, we could keep it for now
elif annotation_type.__origin__ is tuple:
list_type = tuple_to_list(annotation_type)
example_type_str = "\n\n"
# Only suggest the list type if this type is supported.
@@ -141,25 +144,25 @@ def unstringify_type(ty: Union[type[object], str]) -> tuple[typing.Any, bool]:
f"Parameter {name} has unsupported type {param.annotation}. "
f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
)

schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
else:
schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
if type(mutates_args) == str:
if mutates_args != UNKNOWN_MUTATES:
raise ValueError(
"mutates_args must either be a sequence of the names of "
"the arguments that are mutated or the string 'unknown'. "
)
if schema_type.startswith("Tensor"):
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
if schema_type.startswith("Tensor"): # type: ignore[possibly-undefined]
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" # type: ignore[possibly-undefined]
elif name in mutates_args:
if not schema_type.startswith("Tensor"):
if not schema_type.startswith("Tensor"): # type: ignore[possibly-undefined]
error_fn(
f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated"
)
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" # type: ignore[possibly-undefined]
seen_args.add(name)
if param.default is inspect.Parameter.empty:
params.append(f"{schema_type} {name}")
params.append(f"{schema_type} {name}") # type: ignore[possibly-undefined]
else:
default_repr = None
if param.default is None or isinstance(param.default, (int, float, bool)):
@@ -176,7 +179,7 @@ def unstringify_type(ty: Union[type[object], str]) -> tuple[typing.Any, bool]:
f"Parameter {name} has an unsupported default value type {type(param.default)}. "
f"Please file an issue on GitHub so we can prioritize this."
)
params.append(f"{schema_type} {name}={default_repr}")
params.append(f"{schema_type} {name}={default_repr}") # type: ignore[possibly-undefined]
if mutates_args != UNKNOWN_MUTATES:
mutates_args_not_seen = set(mutates_args) - seen_args
if len(mutates_args_not_seen) > 0: