pytorch/pytorch · Commit 97dfd8d

[invoke_subgraph] Run missing graph passes recursively (#152675)

anijain2305 authored and pytorchmergebot committed

Pull Request resolved: #152675
Approved by: https://github.com/bdhirsh, https://github.com/zou3519
ghstack dependencies: #152772, #152770
1 parent cc254ea · commit 97dfd8d

File tree: 3 files changed, +77 −0 lines changed


test/higher_order_ops/test_invoke_subgraph.py

Lines changed: 40 additions & 0 deletions
@@ -617,6 +617,46 @@ def forward(self, a: "f32[8]", l_y_: "f32[8]"):
 """,
         )

+    def test_view_to_reshape(self):
+        @mark_compile_region
+        def gn(x):
+            x = torch.sin(x)
+            x = x.view(1, 8)
+            return torch.sin(x)
+
+        def fn(x):
+            return gn(x)
+
+        x = torch.randn(8, requires_grad=False)
+
+        torch._dynamo.reset()
+        backend = InductorAndRecordGraphs()
+        torch.compile(fn, backend=backend, fullgraph=True)(x)
+
+        if not TEST_WITH_CROSSREF:
+            self.assertExpectedInline(
+                normalize_gm(
+                    backend.inductor_graphs[0].print_readable(print_output=False)
+                ),
+                """\
+class <lambda>(torch.nn.Module):
+    def forward(self, arg0_1: "f32[8]"):
+        repeated_subgraph0 = self.repeated_subgraph0
+        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1); repeated_subgraph0 = arg0_1 = None
+        getitem: "f32[1, 8]" = invoke_subgraph[0]; invoke_subgraph = None
+        return (getitem,)
+
+    class repeated_subgraph0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[8]"):
+            sin: "f32[8]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
+
+            view: "f32[1, 8]" = torch.ops.aten.reshape.default(sin, [1, 8]); sin = None
+
+            sin_1: "f32[1, 8]" = torch.ops.aten.sin.default(view); view = None
+            return (sin_1,)
+""",
+            )
+
     def test_normalize_gm(self):
         @mark_compile_region
         def gn(x, y):
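For illustration, a minimal standalone sketch (not part of this commit) of the pattern the new test compiles: a region marked with mark_compile_region whose body contains a view, which the Inductor post-grad pass must rewrite to reshape even though it lives inside the nested invoke_subgraph graph. The import path for mark_compile_region is an assumption here (the test's import block is outside this hunk), and the function names are illustrative.

import torch
from torch._higher_order_ops.invoke_subgraph import mark_compile_region

@mark_compile_region
def gn(x):
    # The view inside the compiled region ends up in the nested
    # repeated_subgraph0 graph shown in the expected output above.
    return torch.sin(x.view(1, 8))

@torch.compile(fullgraph=True)
def fn(x):
    return gn(x)

out = fn(torch.randn(8))
print(out.shape)  # expected: torch.Size([1, 8])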

torch/_inductor/fx_passes/post_grad.py

Lines changed: 8 additions & 0 deletions
@@ -1202,6 +1202,14 @@ def view_to_reshape(gm):
     """
     Replace view ops in the GraphModule to reshape ops.
     """
+    subgraph_names: OrderedSet[str] = OrderedSet(
+        x.target for x in gm.graph.find_nodes(op="get_attr")
+    )
+
+    for child_name, child_mod in gm.named_children():
+        if child_name in subgraph_names and isinstance(child_mod, torch.fx.GraphModule):
+            view_to_reshape(child_mod)
+
     for nd in gm.graph.find_nodes(
         op="call_function", target=torch.ops.aten.view.default
     ):
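Only the recursion is new here; the loop over aten.view nodes already existed. As a minimal standalone sketch of the full pattern (the name replace_view_with_reshape is illustrative, and the loop body that swaps the node target is an assumption about what the existing pass does, since that line falls outside this hunk):

import torch
from torch.fx import GraphModule

def replace_view_with_reshape(gm: GraphModule) -> None:
    # First recurse into child GraphModules referenced by get_attr nodes;
    # invoke_subgraph bodies are attached to the outer module this way.
    subgraph_names = {n.target for n in gm.graph.find_nodes(op="get_attr")}
    for child_name, child_mod in gm.named_children():
        if child_name in subgraph_names and isinstance(child_mod, GraphModule):
            replace_view_with_reshape(child_mod)

    # Then rewrite aten.view calls in this graph to aten.reshape.
    for node in gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.view.default
    ):
        node.target = torch.ops.aten.reshape.default
    gm.recompile()  # keep gm's generated code in sync with the mutated graph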

torch/fx/passes/fake_tensor_prop.py

Lines changed: 29 additions & 0 deletions
@@ -7,6 +7,7 @@
 from torch.fx._compatibility import compatibility
 from torch.fx.experimental.proxy_tensor import py_sym_types, snapshot_fake
 from torch.fx.node import map_aggregate
+from torch.utils._ordered_set import OrderedSet


 __all__ = ["FakeTensorProp"]
@@ -36,13 +37,41 @@ def __init__(
         self._mode = mode
         mode.epoch += 1
         mode.reset_nt_tensor_id_counter()
+        self.seen_subgraphs: OrderedSet[str] = OrderedSet()

     def run_node(self, n: Node):
         from torch.fx.experimental.symbolic_shapes import (
             compute_unbacked_bindings,
             rebind_unbacked,
         )

+        if (
+            n.op == "call_function"
+            and n.target is torch.ops.higher_order.invoke_subgraph
+            and n.args[1] not in self.seen_subgraphs
+        ):
+            # Prevent redundant fake tensor prop for invoke_subgraphs. Note that
+            # there is also fake tensor caching for the entire subgraph. That
+            # caching kicks in the next time we call `run_node` for the same
+            # subgraph, which goes through super().run_node and caches the fake
+            # tensor prop. Therefore, we end up propagating fake tensors through
+            # each subgraph twice.
+            assert isinstance(n.args[1], str)
+            assert (
+                isinstance(n.args[0], torch.fx.Node)
+                and n.args[0].op == "get_attr"
+                and isinstance(n.args[0].target, str)
+            )
+            self.seen_subgraphs.add(n.args[1])
+            operands = n.args[2:]
+            example_inputs = []
+            for operand in operands:
+                assert isinstance(operand, torch.fx.Node) and "val" in operand.meta
+                example_inputs.append(operand.meta["val"])
+            return FakeTensorProp(
+                getattr(self.module, n.args[0].target), mode=self._mode
+            ).propagate(*example_inputs)
+
         result = super().run_node(n)
         rebind_unbacked(self._mode.shape_env, n, result)
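For context, a minimal standalone sketch of how FakeTensorProp is driven (the traced function f and its input are illustrative; the FakeTensorProp/FakeTensorMode API and the node.meta["val"] annotation are what the new branch relies on). The invoke_subgraph handling above reads each operand's meta["val"] as an example input and recursively runs this same propagation over the subgraph module named by the get_attr argument.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch._subclasses.fake_tensor import FakeTensorMode

def f(x):
    # illustrative graph: an elementwise op followed by a shape change
    return torch.sin(x).reshape(1, 8)

gm = symbolic_trace(f)

# Propagate fake tensors (metadata only) through the graph, without
# allocating real outputs.
mode = FakeTensorMode()
FakeTensorProp(gm, mode=mode).propagate(torch.randn(8))

for node in gm.graph.nodes:
    # Each node now carries its propagated value in meta["val"]; this is the
    # metadata the invoke_subgraph branch reads from the HOP's operands.
    print(node.op, node.name, node.meta.get("val"))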
4877
