[WIP] Pattern matcher support for custom op by yf225 · Pull Request #152767 · pytorch/pytorch · GitHub
[WIP] Pattern matcher support for custom op #152767


Closed · yf225 wants to merge 1 commit
6 changes: 6 additions & 0 deletions test/test_pattern_matcher_custom_op_DEBUG_NOTES.txt
@@ -0,0 +1,6 @@
DEBUG NOTES:
1. Even the nonmutable-op version doesn't work right now, because the custom op is automatically wrapped in auto_functionalized / auto_functionalized_v2,
while the pattern is looking for the vanilla ops.
TODO: we should convert the pattern to auto_functionalized_v2 and then do the matching (see the sketch below).
- Richard said we can maybe use torch.func.functionalize + make_fx.
2. After the nonmutable-op version is fixed, we will move on to the mutable-op-nonview version.
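
A minimal sketch of the torch.func.functionalize + make_fx idea from note 1 (trace_functionalized_pattern is a hypothetical helper; that this produces auto_functionalized / auto_functionalized_v2 nodes for mutable custom ops is exactly the assumption to verify):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def trace_functionalized_pattern(pattern_fn, example_inputs):
    # torch.func.functionalize removes mutations (and views) from the traced
    # function; make_fx then records the functionalized calls into an fx graph,
    # which could be fed to fx_to_pattern instead of the vanilla-op trace.
    functionalized = torch.func.functionalize(pattern_fn, remove="mutations_and_views")
    return make_fx(functionalized, tracing_mode="real")(*example_inputs)

# e.g. trace_functionalized_pattern(pattern, (torch.randn(3), torch.randn(3)))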
51 changes: 51 additions & 0 deletions test/test_pattern_matcher_mutable_op.py
@@ -0,0 +1,51 @@
import torch
from torch.library import register_fake
from torch._inductor.pattern_matcher import register_replacement, fwd_only, PatternMatcherPass

@torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
x.add_(1)

# NOTE: only returning None is supported; the custom op cannot return `out`.
@torch.library.custom_op("mylib::bar", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x + 2)

@register_fake("mylib::bar")
def bar_out_fake(x: torch.Tensor, out: torch.Tensor) -> None:
return None

@torch.library.custom_op("mylib::foobar_out", mutates_args={"out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
x.add_(1)
out.copy_(x + 2)
return out

def pattern(x, out):
foo_inplace(x)
bar_out(x, out)
return out

def replacement(x, out):
return foobar_out(x, out)

patterns = PatternMatcherPass()
register_replacement(
    search_fn=pattern,
    replace_fn=replacement,
    example_inputs=(torch.randn(3), torch.randn(3)),
    trace_fn=fwd_only,
    pass_dicts=patterns,
)

# user-function
@torch.compile(fullgraph=True)
def f(x):
    x = x.clone()
    out = torch.empty_like(x)
    foo_inplace(x)
    bar_out(x, out)
    return out

x = torch.randn(3, device="cpu")
f(x)
49 changes: 49 additions & 0 deletions test/test_pattern_matcher_nonmutable_op.py
@@ -0,0 +1,49 @@
import torch
from torch.library import register_fake
from torch._inductor.pattern_matcher import register_replacement, fwd_only, PatternMatcherPass

@torch.library.custom_op("mylib::foo", mutates_args={})
def foo(x: torch.Tensor) -> torch.Tensor:
return x + 1

@register_fake("mylib::foo")
def foo_fake(x: torch.Tensor) -> torch.Tensor:
return x

@torch.library.custom_op("mylib::bar", mutates_args={})
def bar(x: torch.Tensor) -> torch.Tensor:
return x + 2

@register_fake("mylib::bar")
def bar_fake(x: torch.Tensor) -> torch.Tensor:
return x

@torch.library.custom_op("mylib::foobar", mutates_args={})
def foobar(x: torch.Tensor) -> torch.Tensor:
return x + 3

def pattern(x):
o1 = foo(x)
o2 = bar(o1)
return o2

def replacement(x):
return foobar(x)

patterns = PatternMatcherPass()
register_replacement(
    search_fn=pattern,
    replace_fn=replacement,
    # pattern/replacement take a single tensor, so only one example input is needed
    example_inputs=(torch.randn(3),),
    trace_fn=fwd_only,
    pass_dicts=patterns,
)

# user-function
@torch.compile(fullgraph=True)
def f(x):
    x = x.clone()
    return bar(foo(x))

x = torch.randn(3, device="cpu")
f(x)
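
Neither test script appears to wire `patterns` into compilation, so the registered replacement is not exercised by torch.compile as written. A hedged sketch of one way to hook it up, assuming inductor's post_grad_custom_post_pass accepts a plain callable here (newer releases may require a torch._inductor.custom_graph_pass.CustomGraphPass wrapper):

import torch
import torch._inductor.config as inductor_config

def run_custom_patterns(graph: torch.fx.Graph) -> None:
    # PatternMatcherPass.apply accepts a Graph or GraphModule and returns the match count.
    matched = patterns.apply(graph)
    print(f"custom pattern replacements applied: {matched}")

inductor_config.post_grad_custom_post_pass = run_custom_patterns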
1 change: 1 addition & 0 deletions torch/_inductor/fx_passes/post_grad.py
@@ -217,6 +217,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
        ),
    )
    gm.graph.lint()
+    print(f"after post_grad_passes: gm: {gm}")


def prepare_softmax_pattern(x, dim):
22 changes: 20 additions & 2 deletions torch/_inductor/pattern_matcher.py
@@ -1017,8 +1017,12 @@ def run(obj: PatternExpr, output_name: str = "output") -> str:
        """

        pp = PatternPrettyPrinter()
-        assert hasattr(obj, "pretty_print")
-        out_str = obj.pretty_print(pp=pp)
+        print(f"obj: {obj}, type(obj): {type(obj)}")
+        if isinstance(obj, KeywordArg):
+            out_str = obj.name
+        else:
+            assert hasattr(obj, "pretty_print")
+            out_str = obj.pretty_print(pp=pp)

        output = [
            f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}"
@@ -1072,6 +1076,10 @@ def register(
        target: Union[torch.fx.node.Target, None] = None,
        prepend: bool = False,
    ) -> None:
+        print(f"target: {target}, ")
+        if "auto_functionalized" in str(target):
+            import traceback
+            traceback.print_stack()
        if target is None:
            assert hasattr(self.pattern, "fns")
            for fn in self.pattern.fns:
@@ -1902,6 +1910,11 @@ def __getitem__(self, item: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]:
        return self.patterns[item]

    def apply(self, gm: Union[torch.fx.GraphModule, torch.fx.Graph]) -> int:
+        import traceback
+        if "pass_pattern_" in str(traceback.format_stack()):
+            print(f"PatternMatcherPass: apply: entering, pass_pattern_, gm: {gm}")
+            for op, target in self.patterns:
+                print(f"self.patterns: op: {op}, target: {target}")
        if not self.patterns:
            return 0
        if isinstance(gm, torch.fx.GraphModule):
@@ -2085,13 +2098,18 @@ def fwd_only(
    run_functional_passes: bool = True,
    get_decomp_fn: Optional[Callable[..., Any]] = None,
) -> torch.fx.GraphModule:
+    # import traceback
+    # traceback.print_stack()
+    print(f"here1: Entering fwd_only")
    """Build a normalized inference graph, for use with fx_to_pattern"""
    # TODO - look into using aot autograd, asserting no mutating ops here
    with enable_python_dispatcher():
        decompositions = (
            get_decomp_fn() if get_decomp_fn is not None else select_decomp_table()
        )
+        print(f"here2: fwd_only: will call make_fx")
        gm = make_fx(fn, decompositions, tracing_mode="real")(*args)
+        print(f"here3: fwd_only: called make_fx, gm: {gm}")

    from .fx_passes.post_grad import remove_noop_ops

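
To inspect why a pattern does or does not match without going through torch.compile, the helpers touched above can be driven directly; a small sketch (debug_match is a hypothetical helper, relying on the fwd_only(fn, args) calling convention and PatternMatcherPass.apply shown in this diff):

from torch._inductor.pattern_matcher import fwd_only

def debug_match(fn, example_inputs, pattern_pass):
    # fwd_only builds the same normalized inference graph that register_replacement
    # traces; apply() returns the number of replacements performed on it.
    gm = fwd_only(fn, example_inputs)
    count = pattern_pass.apply(gm.graph)
    print(f"replacements applied: {count}")
    print(gm.graph)
    return gm

# e.g. debug_match(lambda x: bar(foo(x)), (torch.randn(3),), patterns) for the nonmutable-op test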