[Inductor] Pattern matcher support for mutable ops with non-view inputs · pytorch/pytorch@7a3d412 · GitHub

Commit 7a3d412
[Inductor] Pattern matcher support for mutable ops with non-view inputs
ghstack-source-id: 65eb9ad
Pull Request resolved: #152767
ghstack-source-id: 65eb9ad
Pull Request resolved: #152775
1 parent 56d6d4d commit 7a3d412

File tree

2 files changed: +100 −0 lines changed

test/inductor/test_pattern_matcher.py

Lines changed: 95 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Owner(s): ["module: inductor"]
 import copy
+import functools
 import itertools
 import os
 import unittest
@@ -32,6 +33,7 @@
 from torch._inductor.utils import run_and_get_code
 from torch._inductor.virtualized import V
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.library import register_fake
 from torch.testing import FileCheck
 from torch.testing._internal.common_cuda import SM80OrLater, xfailIfSM89
 from torch.testing._internal.common_device_type import expectedFailureXPU, skipCUDAIf
@@ -1655,6 +1657,99 @@ def my_func_static(x, w, epsilon):
         test, (code,) = run_and_get_code(my_func_static, *inputs)
         self.assertTrue("static_scaled_int8_quant" not in code)
 
+    def test_mutable_op_nonview_inputs_register_replacement(self):
+        @torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
+        def foo_inplace(x: torch.Tensor) -> None:
+            x.add_(1)
+
+        # NOTE: only returning None is supported; the custom op cannot return `out` because it is part of the op's inputs.
+        @torch.library.custom_op("mylib::bar", mutates_args={"out"})
+        def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
+            out.copy_(x + 2)
+
+        @torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
+        def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
+            x.add_(1)
+            out.copy_(x + 7)  # intentionally different from bar_out, so output equality reveals whether the replacement fired
+
+        def mutable_ops_pattern(x, out):
+            foo_inplace(x)
+            bar_out(x, out)
+            return x, out
+
+        def mutable_ops_replacement(x, out):
+            foobar_out(x, out)
+            return x, out
+
+        inp = torch.randn(3)
+
+        my_patterns = PatternMatcherPass()
+        register_replacement(
+            search_fn=mutable_ops_pattern,
+            replace_fn=mutable_ops_replacement,
+            example_inputs=[inp.clone().detach(), inp.clone().detach()],
+            trace_fn=functools.partial(fwd_only, apply_auto_functionalize=True),
+            pass_dicts=my_patterns,
+        )
+
+        count = 0
+
+        def custom_pass(graph: torch.fx.Graph):
+            nonlocal count
+            count = my_patterns.apply(graph)
+
+        def custom_backend(graph: torch.fx.GraphModule, example_inputs):
+            from torch._inductor import config
+
+            current_config = config.shallow_copy_dict()
+            from torch._inductor.compile_fx import compile_fx
+
+            current_config["post_grad_custom_post_pass"] = custom_pass
+            return compile_fx(graph, example_inputs, config_patches=current_config)
+
+        # Case 1: mutates a clone of the graph input
+        @torch.compile(fullgraph=True, backend=custom_backend)
+        def f1(x):
+            x = x.clone()
+            out = torch.zeros_like(x)
+            foo_inplace(x)
+            bar_out(x, out)
+            return out
+
+        def f1_replaced(x):
+            x = x.clone()
+            out = torch.zeros_like(x)
+            foobar_out(x, out)
+            return out
+
+        f1_inp = inp.clone().detach()
+        f1_replaced_inp = inp.clone().detach()
+        f1_out = f1(f1_inp)
+        f1_replaced_out = f1_replaced(f1_replaced_inp)
+        self.assertEqual(f1_inp, f1_replaced_inp)
+        self.assertEqual(f1_out, f1_replaced_out)
+        self.assertEqual(count, 1)
+
+        # Case 2: mutates the graph input directly (not supported yet)
+        @torch.compile(fullgraph=True, backend=custom_backend)
+        def f2(x):
+            out = torch.zeros_like(x)
+            foo_inplace(x)
+            bar_out(x, out)
+            return out
+
+        def f2_replaced(x):
+            out = torch.zeros_like(x)
+            foobar_out(x, out)
+            return out
+
+        f2_inp = inp.clone().detach()
+        f2_replaced_inp = inp.clone().detach()
+        f2_out = f2(f2_inp)
+        f2_replaced_out = f2_replaced(f2_replaced_inp)
+        self.assertEqual(f2_inp, f2_replaced_inp)
+        self.assertEqual(f2_out, f2_replaced_out)
+        self.assertEqual(count, 1)
 
 if __name__ == "__main__":
     if IS_LINUX and HAS_GPU:
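
The mechanism the test relies on, in a minimal standalone sketch (not part of the commit; the op name mylib2::inc is hypothetical): passing a function that calls a mutable custom op through dispatch_functionalize before make_fx, exactly what the patched fwd_only below does, should yield a graph in which the mutation appears as a functional higher_order.auto_functionalized node, the form the pattern matcher can actually match.

import torch
from torch._subclasses.functional_tensor import dispatch_functionalize
from torch.fx.experimental.proxy_tensor import make_fx

# Hypothetical mutable custom op, analogous to foo_inplace in the test above.
@torch.library.custom_op("mylib2::inc", mutates_args={"x"})
def inc(x: torch.Tensor) -> None:
    x.add_(1)

def pattern(x):
    inc(x)
    return x

# Functionalize, then trace: the printed graph should contain an
# auto_functionalized(torch.ops.mylib2.inc.default, ...) call rather than an
# in-place mutation of the input.
gm = make_fx(dispatch_functionalize(pattern), tracing_mode="real")(torch.randn(3))
print(gm.graph)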

torch/_inductor/pattern_matcher.py

Lines changed: 5 additions & 0 deletions
@@ -62,6 +62,7 @@
 from torch._dynamo.utils import counters
 from torch._prims_common import is_integer_dtype
 from torch._subclasses.fake_tensor import unset_fake_temporarily
+from torch._subclasses.functional_tensor import dispatch_functionalize
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import statically_known_true
 from torch.fx.graph_module import _get_attr
@@ -2082,6 +2083,7 @@ def fwd_only(
     fn: Callable[..., Any],
     args: Sequence[Any],
     *,
+    apply_auto_functionalize: bool = False,
     run_functional_passes: bool = True,
     get_decomp_fn: Optional[Callable[..., Any]] = None,
 ) -> torch.fx.GraphModule:
@@ -2091,6 +2093,9 @@ def fwd_only(
     decompositions = (
         get_decomp_fn() if get_decomp_fn is not None else select_decomp_table()
     )
+    # When True, apply auto_functionalize to the pattern to functionalize any mutable ops.
+    if apply_auto_functionalize:
+        fn = dispatch_functionalize(fn)
     gm = make_fx(fn, decompositions, tracing_mode="real")(*args)
 
     from .fx_passes.post_grad import remove_noop_ops
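
For reference, a short sketch (reusing the hypothetical pattern from the sketch above) of exercising the new keyword directly; with apply_auto_functionalize=True the GraphModule that fwd_only returns contains auto_functionalized nodes instead of raw in-place ops, and binding the flag with functools.partial is how the test passes it as a trace_fn:

import functools
from torch._inductor.pattern_matcher import fwd_only

# Trace the mutable-op pattern into its functionalized form.
gm = fwd_only(pattern, [torch.randn(3)], apply_auto_functionalize=True)
print(gm.graph)

# Bound as a trace_fn for register_replacement, as in the test above.
trace_fn = functools.partial(fwd_only, apply_auto_functionalize=True)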
