[WIP] Pattern matcher support for custom op · pytorch/pytorch@cf77ed9 · GitHub

Commit cf77ed9

[WIP] Pattern matcher support for custom op

ghstack-source-id: daace00
Pull Request resolved: #152767
1 parent 56d6d4d · commit cf77ed9

File tree: 5 files changed, +127 -2 lines changed
Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
DEBUG NOTES:
1. Even the nonmutable-op version doesn't work right now, because the custom op is automatically wrapped in auto_functionalized / auto_functionalized_v2,
   while the pattern is looking for the vanilla ops.
   TODO: we should convert the pattern to auto_functionalized_v2 and then do the matching.
   - Richard said we can maybe use torch.func.functionalize + make_fx.
2. After the nonmutable-op version is fixed, we will move on to the mutable-op-nonview version.
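
Aside, not part of this commit: a rough, untested sketch of the torch.func.functionalize + make_fx idea mentioned in the notes above. The hope is that tracing the search function through functionalization produces a graph closer to the functionalized form that the post-grad graph actually contains. The mylib::foo_inplace and mylib::bar ops are assumed to be registered as in the mutable-op example file below; whether functionalize handles these mutating custom ops is exactly the open question from the notes.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.func import functionalize

# Same shape as the `pattern` search function in the mutable-op example below.
def mutating_pattern(x, out):
    torch.ops.mylib.foo_inplace(x)
    torch.ops.mylib.bar(x, out)
    return out

# functionalize() is expected to remove the in-place mutations from the trace;
# the recorded graph is what the matcher would then need to search for.
traced = make_fx(functionalize(mutating_pattern))(torch.randn(3), torch.randn(3))
print(traced.graph)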
Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
import torch
from torch.library import register_fake
from torch._inductor.pattern_matcher import register_replacement, fwd_only, PatternMatcherPass

@torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
    x.add_(1)

# NOTE: only returning None is supported; the custom op cannot return `out`.
@torch.library.custom_op("mylib::bar", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x + 2)

@register_fake("mylib::bar")
def bar_out_fake(x: torch.Tensor, out: torch.Tensor) -> None:
    return None

@torch.library.custom_op("mylib::foobar_out", mutates_args={"out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    x.add_(1)
    out.copy_(x + 2)
    return out

def pattern(x, out):
    foo_inplace(x)
    bar_out(x, out)
    return out

def replacement(x, out):
    return foobar_out(x, out)

patterns = PatternMatcherPass()
register_replacement(
    search_fn=pattern,
    replace_fn=replacement,
    example_inputs=(torch.randn(3), torch.randn(3)),
    trace_fn=fwd_only,
    pass_dicts=patterns,
)

# user-function
@torch.compile(fullgraph=True)
def f(x):
    x = x.clone()
    out = torch.empty_like(x)
    foo_inplace(x)
    bar_out(x, out)
    return out

x = torch.randn(3, device="cpu")
f(x)
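
Aside, not part of this commit: for the replacement to run during compilation, the `patterns` pass still has to be hooked into Inductor's post-grad passes. One plausible hook is sketched below, assuming torch._inductor.config.post_grad_custom_post_pass accepts a callable over the post-grad graph (some builds may instead require wrapping it in Inductor's custom-pass interface for caching). Per the debug notes above, the match itself does not fire yet because the mutating ops get wrapped in auto_functionalized / auto_functionalized_v2.

from torch._inductor import config

# Assumption: PatternMatcherPass.apply accepts the post-grad torch.fx.Graph,
# so it can serve directly as the custom post-grad pass callable.
with config.patch(post_grad_custom_post_pass=patterns.apply):
    f(torch.randn(3, device="cpu"))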
Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
import torch
from torch.library import register_fake
from torch._inductor.pattern_matcher import register_replacement, fwd_only, PatternMatcherPass

@torch.library.custom_op("mylib::foo", mutates_args={})
def foo(x: torch.Tensor) -> torch.Tensor:
    return x + 1

@register_fake("mylib::foo")
def foo_fake(x: torch.Tensor) -> torch.Tensor:
    return x

@torch.library.custom_op("mylib::bar", mutates_args={})
def bar(x: torch.Tensor) -> torch.Tensor:
    return x + 2

@register_fake("mylib::bar")
def bar_fake(x: torch.Tensor) -> torch.Tensor:
    return x

@torch.library.custom_op("mylib::foobar", mutates_args={})
def foobar(x: torch.Tensor) -> torch.Tensor:
    return x + 3

def pattern(x):
    o1 = foo(x)
    o2 = bar(o1)
    return o2

def replacement(x):
    return foobar(x)

patterns = PatternMatcherPass()
register_replacement(
    search_fn=pattern,
    replace_fn=replacement,
    example_inputs=(torch.randn(3), torch.randn(3)),
    trace_fn=fwd_only,
    pass_dicts=patterns,
)

# user-function
@torch.compile(fullgraph=True)
def f(x):
    x = x.clone()
    return bar(foo(x))

x = torch.randn(3, device="cpu")
f(x)
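
Aside, not part of this commit: one way to see what register_replacement is searching for is to build the pattern the same way it does and pretty-print it. gen_pattern is assumed to be the existing helper in torch/_inductor/pattern_matcher.py that register_replacement uses internally.

from torch._inductor.pattern_matcher import PatternPrettyPrinter, gen_pattern

# Assumption: gen_pattern(search_fn, example_inputs, trace_fn) returns the
# PatternExpr the pass matches against; printing it makes it easy to compare
# the expected foo -> bar chain with what the post-grad graph actually contains.
pattern_expr = gen_pattern(pattern, (torch.randn(3),), fwd_only)
print(PatternPrettyPrinter.run(pattern_expr))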

torch/_inductor/fx_passes/post_grad.py

Lines changed: 1 addition & 0 deletions
@@ -217,6 +217,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
        ),
    )
    gm.graph.lint()
+   print(f"after post_grad_passes: gm: {gm}")


def prepare_softmax_pattern(x, dim):
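
Aside, not part of this change: for ad-hoc inspection, the existing graph-logging artifacts may be a less invasive alternative to the added print, e.g. running the repro scripts with the TORCH_LOGS environment variable set to "post_grad_graphs" (assuming that artifact name), which dumps the graph after the post-grad passes.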

torch/_inductor/pattern_matcher.py

Lines changed: 20 additions & 2 deletions
@@ -1017,8 +1017,12 @@ def run(obj: PatternExpr, output_name: str = "output") -> str:
        """

        pp = PatternPrettyPrinter()
-       assert hasattr(obj, "pretty_print")
-       out_str = obj.pretty_print(pp=pp)
+       print(f"obj: {obj}, type(obj): {type(obj)}")
+       if isinstance(obj, KeywordArg):
+           out_str = obj.name
+       else:
+           assert hasattr(obj, "pretty_print")
+           out_str = obj.pretty_print(pp=pp)

        output = [
            f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}"
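
Aside, not part of this commit: a minimal illustration of the new KeywordArg branch, assuming KeywordArg and PatternPrettyPrinter are importable from torch._inductor.pattern_matcher. A pattern whose output is just a bare argument is now rendered through its name instead of requiring a pretty_print method.

from torch._inductor.pattern_matcher import KeywordArg, PatternPrettyPrinter

# With the branch above, a bare argument pattern goes through obj.name ("x")
# instead of hitting the hasattr(obj, "pretty_print") assert.
print(PatternPrettyPrinter.run(KeywordArg("x")))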
@@ -1072,6 +1076,10 @@ def register(
        target: Union[torch.fx.node.Target, None] = None,
        prepend: bool = False,
    ) -> None:
+       print(f"target: {target}, ")
+       if "auto_functionalized" in str(target):
+           import traceback
+           traceback.print_stack()
        if target is None:
            assert hasattr(self.pattern, "fns")
            for fn in self.pattern.fns:
@@ -1902,6 +1910,11 @@ def __getitem__(self, item: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]:
        return self.patterns[item]

    def apply(self, gm: Union[torch.fx.GraphModule, torch.fx.Graph]) -> int:
+       import traceback
+       if "pass_pattern_" in str(traceback.format_stack()):
+           print(f"PatternMatcherPass: apply: entering, pass_pattern_, gm: {gm}")
+       for op, target in self.patterns:
+           print(f"self.patterns: op: {op}, target: {target}")
        if not self.patterns:
            return 0
        if isinstance(gm, torch.fx.GraphModule):
@@ -2085,13 +2098,18 @@ def fwd_only(
    run_functional_passes: bool = True,
    get_decomp_fn: Optional[Callable[..., Any]] = None,
) -> torch.fx.GraphModule:
+   # import traceback
+   # traceback.print_stack()
+   print(f"here1: Entering fwd_only")
    """Build a normalized inference graph, for use with fx_to_pattern"""
    # TODO - look into using aot autograd, asserting no mutating ops here
    with enable_python_dispatcher():
        decompositions = (
            get_decomp_fn() if get_decomp_fn is not None else select_decomp_table()
        )
+       print(f"here2: fwd_only: will call make_fx")
        gm = make_fx(fn, decompositions, tracing_mode="real")(*args)
+       print(f"here3: fwd_only: called make_fx, gm: {gm}")

        from .fx_passes.post_grad import remove_noop_ops

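Aside, not part of this commit: the prints above bracket the make_fx call that fwd_only uses to build the search graph. A quick way to look at that graph for the non-mutating example, without going through register_replacement, is to trace the search function directly; a sketch, assuming the make_fx import path below and the mylib ops from the non-mutating example script.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Mirrors the core of fwd_only: trace the search function with real inputs to
# get the aten-level graph the pattern is derived from (fwd_only additionally
# applies decompositions and a few cleanup passes).
def search(x):
    return torch.ops.mylib.bar(torch.ops.mylib.foo(x))

gm = make_fx(search, tracing_mode="real")(torch.randn(3))
print(gm.graph)
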
0 commit comments

Comments
 (0)
0