Add torch.library.register_kernel (#124299) · pytorch/pytorch@bad8d25 · GitHub

Commit bad8d25

zou3519 authored and pytorchmergebot committed
Add torch.library.register_kernel (#124299)
This mirrors the .register_kernel method on the object produced by the custom_op decorator.

Test Plan:
- new tests

Pull Request resolved: #124299
Approved by: https://github.com/albanD
ghstack dependencies: #124180, #124200
1 parent 3918dfe commit bad8d25
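
In short: the commit exposes a free-function form of kernel registration. A condensed sketch of the intended usage, adapted from the docstring added in torch/library.py below (the second kernel is only exercised on a CUDA build):

```python
import numpy as np
import torch
from torch import Tensor

# A custom op whose decorated body only covers cpu.
@torch.library.custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
def numpy_sin(x: Tensor) -> Tensor:
    return torch.from_numpy(np.sin(x.numpy()))

# The new free function registers a kernel for another device type;
# it mirrors numpy_sin.register_kernel("cuda", ...).
@torch.library.register_kernel("mylib::numpy_sin", "cuda")
def _(x: Tensor) -> Tensor:
    return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)

x = torch.randn(3)
assert torch.allclose(numpy_sin(x), x.sin())
```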

File tree

4 files changed: 170 additions, 5 deletions

4 files changed

+170
-5
lines changed

docs/source/library.rst

Lines changed: 5 additions & 3 deletions

```diff
@@ -21,12 +21,12 @@ Use :func:`torch.library.custom_op` to create new custom ops.
 Extending custom ops (created from Python or C++)
 -------------------------------------------------
 
-Use the impl methods, such as :func:`torch.library.impl` and
-:func:`torch.library.impl_abstract`, to add implementations
+Use the register.* methods, such as :func:`torch.library.register_kernel` and
+:func:`torch.library.register_fake`, to add implementations
 for any operators (they may have been created using :func:`torch.library.custom_op` or
 via PyTorch's C++ operator registration APIs).
 
-.. autofunction:: impl
+.. autofunction:: register_kernel
 .. autofunction:: register_autograd
 .. autofunction:: register_fake
 .. autofunction:: impl_abstract
@@ -53,3 +53,5 @@ A tutorial that walks you through some examples on how to use this API is availa
 .. autofunction:: fallthrough_kernel
 
 .. autofunction:: define
+
+.. autofunction:: impl
```
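
To make the reworded docs concrete: the register.* functions all accept a `namespace::name` qualname, so they compose the same way for Python- and C++-defined operators. A minimal sketch, assuming a hypothetical op `mylib::linear_relu` with schema `(Tensor x, Tensor w) -> Tensor` has already been defined elsewhere:

```python
import torch

# register_kernel supplies the actual device computation ...
@torch.library.register_kernel("mylib::linear_relu", "cpu")
def _(x, w):
    return (x @ w).relu()

# ... while register_fake supplies the shape/dtype propagation rule
# used by FakeTensor, meta tensors, and torch.compile.
@torch.library.register_fake("mylib::linear_relu")
def _(x, w):
    return x.new_empty(x.shape[0], w.shape[1])
```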

test/test_custom_ops.py

Lines changed: 104 additions & 0 deletions

```diff
@@ -2359,6 +2359,110 @@ def _(x, y):
         self.assertEqual(z.shape, x.shape)
         self.assertTrue(called)
 
+    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
+    def test_library_register_kernel(self):
+        modes = ["function", "qualname", "opoverload"]
+        calls = ["decorator", "function"]
+        device_types_options = ["cpu", None]
+
+        for mode, call, device_types in itertools.product(
+            modes, calls, device_types_options
+        ):
+
+            @torch.library.custom_op(
+                "_torch_testing::add", mutates_args=(), device_types="cuda"
+            )
+            def add(x: Tensor, y: float) -> Tensor:
+                x_np = x.cpu().numpy()
+                out_np = x_np + y
+                return torch.from_numpy(out_np).to(x.device)
+
+            if mode == "function":
+                op = add
+            elif mode == "qualname":
+                op = "_torch_testing::add"
+            else:
+                assert mode == "opoverload"
+                op = torch.ops._torch_testing.add.default
+
+            called = False
+
+            if call == "decorator":
+
+                @torch.library.register_kernel(op, device_types)
+                def _(x, y):
+                    nonlocal called
+                    called = True
+                    x_np = x.numpy()
+                    out_np = x_np + y
+                    return torch.from_numpy(out_np)
+
+            else:
+                assert call == "function"
+
+                def add_cpu(x, y):
+                    nonlocal called
+                    called = True
+                    x_np = x.numpy()
+                    out_np = x_np + y
+                    return torch.from_numpy(out_np)
+
+                torch.library.register_kernel(op, device_types, add_cpu)
+
+            x = torch.randn(3)
+            y = 3.14
+            z = add(x, y)
+            self.assertEqual(z, x + y)
+            self.assertTrue(called)
+
+    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
+    def test_library_register_kernel_low_level(self):
+        modes = ["qualname", "opoverload"]
+        calls = ["decorator", "function"]
+        device_types_options = [("cpu", "cuda"), "cpu", None]
+
+        for mode, call, device_types in itertools.product(
+            modes, calls, device_types_options
+        ):
+            with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
+                lib.define("add9(Tensor x, float y) -> Tensor")
+
+                if mode == "qualname":
+                    op = "_torch_testing::add9"
+                else:
+                    assert mode == "opoverload"
+                    op = torch.ops._torch_testing.add9.default
+
+                called = False
+
+                if call == "decorator":
+
+                    @torch.library.register_kernel(op, device_types, lib=lib)
+                    def _(x, y):
+                        nonlocal called
+                        called = True
+                        x_np = x.numpy()
+                        out_np = x_np + y
+                        return torch.from_numpy(out_np)
+
+                else:
+                    assert call == "function"
+
+                    def add_cpu(x, y):
+                        nonlocal called
+                        called = True
+                        x_np = x.numpy()
+                        out_np = x_np + y
+                        return torch.from_numpy(out_np)
+
+                    torch.library.register_kernel(op, device_types, add_cpu, lib=lib)
+
+                x = torch.randn(3)
+                y = 3.14
+                z = torch.ops._torch_testing.add9.default(x, y)
+                self.assertEqual(z, x + y)
+                self.assertTrue(called)
+
     @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
     def test_library_register_autograd(self):
         for mode in ["function", "qualname", "opoverload"]:
```
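
The `modes` loop in both tests pins down the three accepted spellings of `op`. As a sketch mirroring the test setup, any one of the following names the same operator when passed to `torch.library.register_kernel`:

```python
import torch
from torch import Tensor

@torch.library.custom_op("_torch_testing::add", mutates_args=(), device_types="cuda")
def add(x: Tensor, y: float) -> Tensor:
    return x + y

def add_cpu(x: Tensor, y: float) -> Tensor:
    return x + y

# Three equivalent ways to name the operator; any one works.
op = add                                    # the CustomOpDef returned by the decorator
op = "_torch_testing::add"                  # the "namespace::name" qualname string
op = torch.ops._torch_testing.add.default   # the resolved OpOverload

torch.library.register_kernel(op, "cpu", add_cpu)
assert torch.equal(add(torch.ones(3), 1.0), torch.full((3,), 2.0))
```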

torch/_library/custom_ops.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -191,7 +191,7 @@ def register_kernel(
         >>> from torch.library import custom_op
         >>> import numpy as np
         >>>
-        >>> # Example of split cpu and cuda definitions
+        >>> # Create a custom op that works on cpu
         >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
         >>> def numpy_sin(x: Tensor) -> Tensor:
         >>>     x_np = x.numpy()
```

torch/library.py

Lines changed: 60 additions & 1 deletion

```diff
@@ -9,7 +9,7 @@
 import contextlib
 import sys
 import warnings
-from torch._library.custom_ops import custom_op, _maybe_get_opdef
+from torch._library.custom_ops import custom_op, _maybe_get_opdef, device_types_t
 import torch._library as _library
 
 
@@ -424,6 +424,65 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
 _op_identifier = Union[str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"]
 
 
+def register_kernel(
+    op: _op_identifier,
+    device_types: device_types_t,
+    func: Optional[Callable] = None,
+    /,
+    *,
+    lib: Optional[Library] = None):
+    """Register an implementation for a device type for this operator.
+
+    Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
+    This API may be used as a decorator.
+
+    Args:
+        func (Callable): The function to register as the implementation for
+            the given device types.
+        device_types (None | str | Sequence[str]): The device_types to register an impl to.
+            If None, we will register to all device types -- please only use
+            this option if your implementation is truly device-type-agnostic.
+
+    Examples::
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
+        >>> import torch
+        >>> from torch import Tensor
+        >>> from torch.library import custom_op
+        >>> import numpy as np
+        >>>
+        >>> # Create a custom op that works on cpu
+        >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
+        >>> def numpy_sin(x: Tensor) -> Tensor:
+        >>>     x_np = x.numpy()
+        >>>     y_np = np.sin(x_np)
+        >>>     return torch.from_numpy(y_np)
+        >>>
+        >>> # Add implementations for the cuda device
+        >>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
+        >>> def _(x):
+        >>>     x_np = x.cpu().numpy()
+        >>>     y_np = np.sin(x_np)
+        >>>     return torch.from_numpy(y_np).to(device=x.device)
+        >>>
+        >>> x_cpu = torch.randn(3)
+        >>> x_cuda = x_cpu.cuda()
+        >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
+        >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
+
+    """
+
+    if not isinstance(op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)):
+        raise ValueError(f"register_kernel(op): got unexpected type for op: {type(op)}")
+    if isinstance(op, torch._ops.OpOverload):
+        op = op._name
+    opdef = _maybe_get_opdef(op)
+    if opdef is not None:
+        return opdef.register_kernel(device_types, func)
+    assert isinstance(op, str)
+    if device_types is None:
+        device_types = "CompositeExplicitAutograd"
+    return impl(op, device_types, func, lib=lib)
+
 
 def register_fake(
     op: _op_identifier,
```
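
One behavior worth calling out in the body above: when `op` is not a `CustomOpDef` and `device_types` is None, the registration falls through to `impl` under the `CompositeExplicitAutograd` dispatch key, so a single kernel serves every backend. A sketch of that low-level path, modeled on `test_library_register_kernel_low_level`:

```python
import torch

with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
    lib.define("add9(Tensor x, float y) -> Tensor")

    def add9_impl(x, y):
        return x + y

    # device_types=None -> registered as CompositeExplicitAutograd:
    # one kernel covers all device types.
    torch.library.register_kernel("_torch_testing::add9", None, add9_impl, lib=lib)

    x = torch.randn(3)
    assert torch.allclose(torch.ops._torch_testing.add9(x, 3.14), x + 3.14)
```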

0 commit comments