pytorch/pytorch · Commit 924bc65

[Inductor] Pattern matcher support for mutable ops with non-view inputs
ghstack-source-id: 01c0214
Pull Request resolved: #152767
ghstack-source-id: 01c0214
Pull Request resolved: #152775
1 parent 56d6d4d commit 924bc65
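
In short: register_replacement can now trace search/replace patterns that contain mutable custom ops with non-view inputs, by functionalizing the pattern during tracing via a new apply_auto_functionalize flag on fwd_only. A minimal usage sketch (the demo::inc_ / demo::inc2_ ops are hypothetical; the committed test below is the authoritative example):

    import functools
    import torch
    from torch._inductor.pattern_matcher import (
        PatternMatcherPass,
        fwd_only,
        register_replacement,
    )

    # Hypothetical mutable custom ops standing in for real ones.
    @torch.library.custom_op("demo::inc_", mutates_args={"x"})
    def inc_(x: torch.Tensor) -> None:
        x.add_(1)

    @torch.library.custom_op("demo::inc2_", mutates_args={"x"})
    def inc2_(x: torch.Tensor) -> None:
        x.add_(2)

    def pattern(x):
        inc_(x)
        inc_(x)
        return x

    def replacement(x):
        inc2_(x)
        return x

    patterns = PatternMatcherPass()
    register_replacement(
        search_fn=pattern,
        replace_fn=replacement,
        example_inputs=[torch.randn(3)],
        # New in this commit: functionalize the pattern while tracing it.
        trace_fn=functools.partial(fwd_only, apply_auto_functionalize=True),
        pass_dicts=patterns,
    )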

File tree

2 files changed: +79 −0 lines changed

  test/inductor/test_pattern_matcher.py (+74)
  torch/_inductor/pattern_matcher.py (+5)

test/inductor/test_pattern_matcher.py

Lines changed: 74 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Owner(s): ["module: inductor"]
 import copy
+import functools
 import itertools
 import os
 import unittest
@@ -32,6 +33,7 @@
 from torch._inductor.utils import run_and_get_code
 from torch._inductor.virtualized import V
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.library import register_fake
 from torch.testing import FileCheck
 from torch.testing._internal.common_cuda import SM80OrLater, xfailIfSM89
 from torch.testing._internal.common_device_type import expectedFailureXPU, skipCUDAIf
@@ -1655,6 +1657,78 @@ def my_func_static(x, w, epsilon):
         test, (code,) = run_and_get_code(my_func_static, *inputs)
         self.assertTrue("static_scaled_int8_quant" not in code)
 
+    def test_mutable_op_nonview_inputs_register_replacement(self):
+        @torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
+        def foo_inplace(x: torch.Tensor) -> None:
+            x.add_(1)
+
+        # NOTE: only returning None is supported; the custom op cannot
+        # return `out` because `out` is one of the op's inputs.
+        @torch.library.custom_op("mylib::bar", mutates_args={"out"})
+        def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
+            out.copy_(x + 2)
+
+        @register_fake("mylib::bar")
+        def bar_out_fake(x: torch.Tensor, out: torch.Tensor) -> None:
+            return None
+
+        @torch.library.custom_op("mylib::foobar_out", mutates_args={"out"})
+        def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
+            x.add_(1)
+            out.copy_(x + 7)  # intentionally different from bar_out
+
+        def mutable_ops_pattern(x, out):
+            foo_inplace(x)
+            bar_out(x, out)
+            return out
+
+        def mutable_ops_replacement(x, out):
+            foobar_out(x, out)
+            return out
+
+        inp = torch.randn(3)
+
+        my_patterns = PatternMatcherPass()
+        register_replacement(
+            search_fn=mutable_ops_pattern,
+            replace_fn=mutable_ops_replacement,
+            example_inputs=[inp.clone().detach(), inp.clone().detach()],
+            trace_fn=functools.partial(fwd_only, apply_auto_functionalize=True),
+            pass_dicts=my_patterns,
+        )
+
+        count = 0
+
+        def custom_pass(graph: torch.fx.Graph):
+            nonlocal count
+            count = my_patterns.apply(graph)
+
+        def custom_backend(graph: torch.fx.GraphModule, example_inputs):
+            from torch._inductor import config
+
+            current_config = config.shallow_copy_dict()
+            from torch._inductor.compile_fx import compile_fx
+
+            current_config["post_grad_custom_post_pass"] = custom_pass
+            return compile_fx(graph, example_inputs, config_patches=current_config)
+
+        # user function
+        @torch.compile(fullgraph=True, backend=custom_backend)
+        def f(x):
+            x = x.clone()
+            out = torch.zeros_like(x)
+            foo_inplace(x)
+            bar_out(x, out)
+            return out
+
+        def f_replaced(x):
+            x = x.clone()
+            out = torch.zeros_like(x)
+            foobar_out(x, out)
+            return out
+
+        self.assertEqual(f(inp.clone().detach()), f_replaced(inp.clone().detach()))
+        self.assertEqual(count, 1)
+
 
 if __name__ == "__main__":
     if IS_LINUX and HAS_GPU:
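
The NOTE in the test reflects a torch.library constraint worth spelling out: a custom op's return value must not alias any of its inputs, so an out-variant op that mutates `out` has to return None rather than `out`. A hedged counter-example (mylib::bad_out is hypothetical):

    import torch

    # Expected to be rejected: the returned tensor aliases the mutated
    # input `out`, which torch.library.custom_op does not allow.
    @torch.library.custom_op("mylib::bad_out", mutates_args={"out"})
    def bad_out(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
        out.copy_(x + 2)
        return out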

torch/_inductor/pattern_matcher.py

Lines changed: 5 additions & 0 deletions
@@ -62,6 +62,7 @@
 from torch._dynamo.utils import counters
 from torch._prims_common import is_integer_dtype
 from torch._subclasses.fake_tensor import unset_fake_temporarily
+from torch._subclasses.functional_tensor import dispatch_functionalize
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import statically_known_true
 from torch.fx.graph_module import _get_attr
@@ -2082,6 +2083,7 @@ def fwd_only(
     fn: Callable[..., Any],
     args: Sequence[Any],
     *,
+    apply_auto_functionalize: bool = False,
     run_functional_passes: bool = True,
     get_decomp_fn: Optional[Callable[..., Any]] = None,
 ) -> torch.fx.GraphModule:
@@ -2091,6 +2093,9 @@ def fwd_only(
     decompositions = (
         get_decomp_fn() if get_decomp_fn is not None else select_decomp_table()
     )
+    # When true, apply auto_functionalize to the pattern to functionalize any mutable ops.
+    if apply_auto_functionalize:
+        fn = dispatch_functionalize(fn)
     gm = make_fx(fn, decompositions, tracing_mode="real")(*args)
 
     from .fx_passes.post_grad import remove_noop_ops
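
For intuition, dispatch_functionalize(fn) returns a wrapper that runs fn under functionalization, so make_fx records mutable custom ops as functional nodes (e.g. higher_order.auto_functionalized) instead of in-place mutations, which is what makes such patterns matchable. A rough sketch of the effect (demo::bump_ is a hypothetical op):

    import torch
    from torch._subclasses.functional_tensor import dispatch_functionalize
    from torch.fx.experimental.proxy_tensor import make_fx

    @torch.library.custom_op("demo::bump_", mutates_args={"x"})
    def bump_(x: torch.Tensor) -> None:
        x.add_(1)

    def pattern(x):
        bump_(x)
        return x

    gm = make_fx(dispatch_functionalize(pattern), tracing_mode="real")(torch.randn(3))
    # The printed graph should show a functional rewrite of demo::bump_
    # rather than an in-place mutation.
    print(gm.graph)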
