[Inductor] Pattern matcher support for mutable ops with non-view inputs · pytorch/pytorch@eaf12c8 · GitHub
Commit eaf12c8

[Inductor] Pattern matcher support for mutable ops with non-view inputs
ghstack-source-id: daace00
Pull Request resolved: #152767
ghstack-source-id: dbf44a7
Pull Request resolved: #152775
1 parent 56d6d4d commit eaf12c8

File tree

4 files changed: +100 −3 lines changed

test/inductor/test_pattern_matcher.py (+75)

@@ -4,6 +4,7 @@
 import os
 import unittest
 from typing import Callable, Optional
+import functools

 import torch
 import torch._dynamo.config as dynamo_config

@@ -44,6 +45,7 @@
 )
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_BIG_GPU
 from torch.utils import _pytree as pytree
+from torch.library import register_fake


 aten = torch.ops.aten

@@ -1655,6 +1657,79 @@ def my_func_static(x, w, epsilon):
         test, (code,) = run_and_get_code(my_func_static, *inputs)
         self.assertTrue("static_scaled_int8_quant" not in code)

+    def test_mutable_op_nonview_inputs_register_replacement(self):
+        @torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
+        def foo_inplace(x: torch.Tensor) -> None:
+            x.add_(1)
+
+        # NOTE: only returning None is supported; the custom op cannot return `out`.
+        @torch.library.custom_op("mylib::bar", mutates_args={"out"})
+        def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
+            out.copy_(x + 2)
+
+        @register_fake("mylib::bar")
+        def bar_out_fake(x: torch.Tensor, out: torch.Tensor) -> None:
+            return None
+
+        @torch.library.custom_op("mylib::foobar_out", mutates_args={"out"})
+        def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
+            x.add_(1)
+            out.copy_(x + 7)
+
+        def mutable_ops_pattern(x, out):
+            foo_inplace(x)
+            bar_out(x, out)
+            return out
+
+        def mutable_ops_replacement(x, out):
+            foobar_out(x, out)
+            return out
+
+        inp = torch.randn(3)
+
+        my_patterns = PatternMatcherPass()
+        register_replacement(
+            search_fn=mutable_ops_pattern,
+            replace_fn=mutable_ops_replacement,
+            example_inputs=[inp.clone().detach(), inp.clone().detach()],
+            trace_fn=functools.partial(fwd_only, apply_auto_functionalize=True),
+            pass_dicts=my_patterns,
+        )
+
+        count = 0
+
+        def custom_pass(graph: torch.fx.Graph):
+            nonlocal count
+            count = my_patterns.apply(graph)
+
+        def custom_backend(graph: torch.fx.GraphModule, example_inputs):
+            from torch._inductor import config
+
+            current_config = config.shallow_copy_dict()
+            from torch._inductor.compile_fx import compile_fx
+
+            current_config["post_grad_custom_post_pass"] = custom_pass
+            return compile_fx(graph, example_inputs, config_patches=current_config)
+
+        # user function
+        @torch.compile(fullgraph=True, backend=custom_backend)
+        def f(x):
+            x = x.clone()
+            out = torch.zeros_like(x)
+            foo_inplace(x)
+            bar_out(x, out)
+            return out
+
+        def f_replaced(x):
+            x = x.clone()
+            out = torch.zeros_like(x)
+            foobar_out(x, out)
+            return out
+
+        self.assertEqual(f(inp.clone().detach()), f_replaced(inp.clone().detach()))
+        self.assertEqual(count, 1)

 if __name__ == "__main__":
     if IS_LINUX and HAS_GPU:
(new file)

@@ -0,0 +1,6 @@
+DEBUG NOTES:
+1. Even the non-mutable-op version doesn't work right now, because the custom op is automatically wrapped in auto_functionalized / auto_functionalized_v2,
+   while the pattern is looking for vanilla ops.
+   TODO: we should convert the pattern to auto_functionalized_v2 and then do matching.
+   - Richard said we can maybe use torch.func.functionalize + make_fx.
+2. After the non-mutable-op version is fixed, we will move to the mutable-op-nonview version.
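Note 1's suggestion is essentially what the fwd_only change further down implements, except via the internal dispatch_functionalize helper rather than torch.func.functionalize: functionalize the pattern callable before make_fx so the mutable custom op is traced as an auto_functionalized node the matcher can look for. A minimal sketch of that idea (mylib::sketch_inplace is an illustrative op mirroring the test, not something this commit adds):

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._subclasses.functional_tensor import dispatch_functionalize

# Illustrative mutable custom op, shaped like the ones in the new test.
@torch.library.custom_op("mylib::sketch_inplace", mutates_args={"x"})
def sketch_inplace(x: torch.Tensor) -> None:
    x.add_(1)

def pattern(x):
    sketch_inplace(x)
    return x

# Functionalize first, then trace: the traced graph is expected to contain an
# auto_functionalized(_v2) call instead of the raw in-place mylib::sketch_inplace.
gm = make_fx(dispatch_functionalize(pattern), tracing_mode="real")(torch.randn(3))
print(gm.graph)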

torch/_inductor/fx_passes/post_grad.py (+1)

@@ -217,6 +217,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         ),
     )
     gm.graph.lint()
+    print(f"after post_grad_passes: gm: {gm}")


 def prepare_softmax_pattern(x, dim):

torch/_inductor/pattern_matcher.py (+18 −3)

@@ -77,6 +77,7 @@
 from . import config
 from .decomposition import select_decomp_table
 from .lowering import fallback_node_due_to_unsupported_type
+from torch._subclasses.functional_tensor import FunctionalTensorMode, FunctionalTensor, dispatch_functionalize


 log = logging.getLogger(__name__)

@@ -1017,8 +1018,12 @@ def run(obj: PatternExpr, output_name: str = "output") -> str:
         """

         pp = PatternPrettyPrinter()
-        assert hasattr(obj, "pretty_print")
-        out_str = obj.pretty_print(pp=pp)
+        print(f"obj: {obj}, type(obj): {type(obj)}")
+        if isinstance(obj, KeywordArg):
+            out_str = obj.name
+        else:
+            assert hasattr(obj, "pretty_print")
+            out_str = obj.pretty_print(pp=pp)

         output = [
             f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}"

@@ -1072,6 +1077,7 @@ def register(
         target: Union[torch.fx.node.Target, None] = None,
         prepend: bool = False,
     ) -> None:
+        print(f"target: {target}, self.pattern: {self.pattern}")
         if target is None:
             assert hasattr(self.pattern, "fns")
             for fn in self.pattern.fns:

@@ -1902,6 +1908,12 @@ def __getitem__(self, item: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]:
         return self.patterns[item]

     def apply(self, gm: Union[torch.fx.GraphModule, torch.fx.Graph]) -> int:
+        import traceback
+        traceback.print_stack()
+        if "pass_pattern_" in str(traceback.format_stack()):
+            print(f"PatternMatcherPass: apply: entering, pass_pattern_, gm: {gm}")
+            for op, target in self.patterns:
+                print(f"self.patterns: op: {op}, target: {target}")
         if not self.patterns:
             return 0
         if isinstance(gm, torch.fx.GraphModule):

@@ -2082,15 +2094,18 @@ def fwd_only(
     fn: Callable[..., Any],
     args: Sequence[Any],
     *,
+    apply_auto_functionalize: bool = False,
     run_functional_passes: bool = True,
     get_decomp_fn: Optional[Callable[..., Any]] = None,
 ) -> torch.fx.GraphModule:
     """Build a normalized inference graph, for use with fx_to_pattern"""
-    # TODO - look into using aot autograd, asserting no mutating ops here
+
     with enable_python_dispatcher():
         decompositions = (
             get_decomp_fn() if get_decomp_fn is not None else select_decomp_table()
         )
+        if apply_auto_functionalize:
+            fn = dispatch_functionalize(fn)
         gm = make_fx(fn, decompositions, tracing_mode="real")(*args)

     from .fx_passes.post_grad import remove_noop_ops
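For reference, the test above exercises the new apply_auto_functionalize flag through functools.partial(fwd_only, apply_auto_functionalize=True) as the trace_fn. A minimal standalone sketch of calling fwd_only directly with the flag (mylib::demo_inplace is illustrative and not part of the commit); with the flag set, the pattern is functionalized before make_fx, so the traced graph is expected to contain auto_functionalized nodes rather than the raw in-place op:

import torch
from torch._inductor.pattern_matcher import fwd_only

# Illustrative mutable custom op, shaped like the ones in the new test.
@torch.library.custom_op("mylib::demo_inplace", mutates_args={"x"})
def demo_inplace(x: torch.Tensor) -> None:
    x.add_(1)

def pattern(x):
    demo_inplace(x)
    return x

# Trace the mutable pattern through the new functionalizing path.
gm = fwd_only(pattern, (torch.randn(3),), apply_auto_functionalize=True)
print(gm.graph)  # expected to show an auto_functionalized(_v2) call, not an in-place op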
