[invoke_subgraph] Run missing graph passes recursively · pytorch/pytorch@9b0f583 · GitHub
Commit 9b0f583

[invoke_subgraph] Run missing graph passes recursively

ghstack-source-id: 48ffd03
Pull Request resolved: #152675

[invoke_subgraph] Force the output to have same strides as meta

1 parent 99e6c92 commit 9b0f583
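
The two passes touched here (Inductor's view_to_reshape and FX's FakeTensorProp) previously walked only the top-level graph, while an invoke_subgraph region keeps its body as a nested GraphModule that the parent graph reaches through a get_attr node. A minimal sketch of the traversal both fixes rely on (the helper name and shape are mine, for illustration only):

import torch.fx as fx

def nested_subgraph_modules(gm: fx.GraphModule):
    # Yield (attr_name, child GraphModule) pairs reachable from top-level
    # get_attr nodes; invoke_subgraph bodies hang off the parent module this
    # way, so a pass that only scans gm.graph never sees their nodes.
    for node in gm.graph.nodes:
        if node.op == "get_attr" and isinstance(node.target, str) and "." not in node.target:
            sub = getattr(gm, node.target, None)
            if isinstance(sub, fx.GraphModule):
                yield node.target, sub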

File tree

4 files changed: +93 -0 lines changed

test/higher_order_ops/test_invoke_subgraph.py

Lines changed: 40 additions & 0 deletions
@@ -617,6 +617,46 @@ def forward(self, a: "f32[8]", l_y_: "f32[8]"):
 """,
             )
 
+    def test_view_to_reshape(self):
+        @mark_compile_region
+        def gn(x):
+            x = torch.sin(x)
+            x = x.view(1, 8)
+            return torch.sin(x)
+
+        def fn(x):
+            return gn(x)
+
+        x = torch.randn(8, requires_grad=False)
+
+        torch._dynamo.reset()
+        backend = InductorAndRecordGraphs()
+        torch.compile(fn, backend=backend, fullgraph=True)(x)
+
+        if not TEST_WITH_CROSSREF:
+            self.assertExpectedInline(
+                normalize_gm(
+                    backend.inductor_graphs[0].print_readable(print_output=False)
+                ),
+                """\
+class <lambda>(torch.nn.Module):
+    def forward(self, arg0_1: "f32[8]"):
+        repeated_subgraph0 = self.repeated_subgraph0
+        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1); repeated_subgraph0 = arg0_1 = None
+        getitem: "f32[1, 8]" = invoke_subgraph[0]; invoke_subgraph = None
+        return (getitem,)
+
+    class repeated_subgraph0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[8]"):
+            sin: "f32[8]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
+
+            view: "f32[1, 8]" = torch.ops.aten.reshape.default(sin, [1, 8]); sin = None
+
+            sin_1: "f32[1, 8]" = torch.ops.aten.sin.default(view); view = None
+            return (sin_1,)
+""",
+            )
+
     def test_normalize_gm(self):
         @mark_compile_region
         def gn(x, y):
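
The interesting part of the expected output is inside repeated_subgraph0: the x.view(1, 8) written in gn shows up as torch.ops.aten.reshape.default, i.e. the view_to_reshape pass now reaches the subgraph body. For comparison, tracing the same body to aten ops outside of Inductor still shows the original view op (a quick illustrative check, not part of the test):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def body(x):
    x = torch.sin(x)
    x = x.view(1, 8)
    return torch.sin(x)

gm = make_fx(body)(torch.randn(8))
# The aten-level graph still contains torch.ops.aten.view.default here;
# Inductor's post-grad view_to_reshape pass is what rewrites it to reshape.
print(gm.graph)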

torch/_inductor/fx_passes/post_grad.py

Lines changed: 9 additions & 0 deletions
@@ -1207,6 +1207,15 @@ def view_to_reshape(gm):
     ):
         nd.target = torch.ops.aten.reshape.default
 
+    subgraph_names: OrderedSet[str] = OrderedSet()
+    for node in sorted(gm.graph.find_nodes(op="get_attr")):
+        attr_name = node.target
+        if "." not in attr_name and attr_name not in subgraph_names:
+            sub_mod = getattr(gm, attr_name)
+            if isinstance(sub_mod, torch.fx.GraphModule):
+                subgraph_names.add(attr_name)
+                view_to_reshape(sub_mod)
+
 
 def should_prefer_unfused_addmm(match):
     inp = match.kwargs["inp"]
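
The pre-existing part of the pass (the retarget shown in the context lines above) only touched nodes of gm.graph; the added loop re-invokes view_to_reshape on every child GraphModule found behind a top-level get_attr node, deduplicating by attribute name. A self-contained approximation of the combined behavior, exercised on a flat aten graph built with make_fx (function name and toy graph are mine; the real pass lives in this file and handles more cases):

import torch
import torch.fx as fx
from torch.fx.experimental.proxy_tensor import make_fx

def view_to_reshape_sketch(gm: fx.GraphModule) -> None:
    # Retarget aten.view -> aten.reshape in this graph ...
    for nd in gm.graph.nodes:
        if nd.op == "call_function" and nd.target is torch.ops.aten.view.default:
            nd.target = torch.ops.aten.reshape.default
    # ... then recurse into nested GraphModules referenced by get_attr nodes.
    for nd in gm.graph.nodes:
        if nd.op == "get_attr" and isinstance(nd.target, str) and "." not in nd.target:
            sub = getattr(gm, nd.target, None)
            if isinstance(sub, fx.GraphModule):
                view_to_reshape_sketch(sub)
    gm.recompile()

x = torch.randn(8)
gm = make_fx(lambda t: t.view(2, 4).mul(2))(x)   # contains aten.view.default
view_to_reshape_sketch(gm)
assert all(nd.target is not torch.ops.aten.view.default for nd in gm.graph.nodes)
assert torch.allclose(gm(x), x.view(2, 4).mul(2))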

torch/_inductor/ir.py

Lines changed: 15 additions & 0 deletions
@@ -7431,6 +7431,10 @@ def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool:
 
 @ir_dataclass(frozen=False)
 class InvokeSubgraph(ExternKernel):
+    """
+    Implementation of InvokeSubgraph HOP
+    """
+
     subgraph: Optional[Subgraph] = None
     operands: Optional[list[TensorBox]] = None
     outputs: Optional[list[MultiOutput]] = None
@@ -7515,6 +7519,17 @@ def create_output(output: IRNode, ind: int):
                     skip_size_stride_alignment_checks=True,
                 )
 
+        # Force the output strides to be same as the original strides
+        new_outputs = []
+        fake_outputs = V.graph.current_node.meta["val"]
+        for idx, output in enumerate(outputs):
+            if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)):
+                new_outputs.append(output)
+            else:
+                example_stride = handle_sym_expr(fake_outputs[idx].stride())
+                new_outputs.append(cls.require_exact_strides(output, example_stride))
+        outputs = new_outputs
+
         outputs = [create_output(output, i) for i, output in enumerate(outputs)]
         invoke_subgraph.outputs = outputs
         return outputs
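
The second hunk enforces that each realized output of the HOP carries the strides recorded on the current node's meta "val" (the fake-tensor outputs), skipping constant and None outputs. Conceptually this is the same contract as the eager-level helper sketched below, which materializes a copy with the promised layout when the strides disagree (my illustration only; the real code goes through cls.require_exact_strides on Inductor IR and handle_sym_expr for symbolic strides):

import torch

def match_example_strides(out: torch.Tensor, example: torch.Tensor) -> torch.Tensor:
    # If the computed output's strides differ from what the meta/fake tensor
    # promised, materialize a tensor with the promised strides and copy into it.
    if out.stride() == example.stride():
        return out
    fixed = torch.empty_strided(
        example.size(), example.stride(), dtype=out.dtype, device=out.device
    )
    fixed.copy_(out)
    return fixed

example = torch.empty(4, 6)      # contiguous: stride (6, 1)
out = torch.randn(6, 4).t()      # same shape (4, 6) but stride (1, 6)
assert match_example_strides(out, example).stride() == example.stride()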

torch/fx/passes/fake_tensor_prop.py

Lines changed: 29 additions & 0 deletions
@@ -7,6 +7,7 @@
 from torch.fx._compatibility import compatibility
 from torch.fx.experimental.proxy_tensor import py_sym_types, snapshot_fake
 from torch.fx.node import map_aggregate
+from torch.utils._ordered_set import OrderedSet
 
 
 __all__ = ["FakeTensorProp"]
@@ -36,13 +37,41 @@ def __init__(
         self._mode = mode
         mode.epoch += 1
         mode.reset_nt_tensor_id_counter()
+        self.seen_subgraphs: OrderedSet[str] = OrderedSet()
 
     def run_node(self, n: Node):
         from torch.fx.experimental.symbolic_shapes import (
             compute_unbacked_bindings,
             rebind_unbacked,
         )
 
+        if (
+            n.op == "call_function"
+            and n.target is torch.ops.higher_order.invoke_subgraph
+            and n.args[1] not in self.seen_subgraphs
+        ):
+            # Prevent redundant fake tensor prop for invoke_subgraphs. Note that
+            # there is also fake tensor caching for the entire subgraph. This
+            # happens the next time we call `run_node` for the same subgraph,
+            # which goes through super.run_node and caches the fake tensor prop.
+            # Therefore, we are propagating fake tensor through the subgraphs
+            # twice.
+            assert isinstance(n.args[1], str)
+            assert (
+                isinstance(n.args[0], torch.fx.Node)
+                and n.args[0].op == "get_attr"
+                and isinstance(n.args[0].target, str)
+            )
+            self.seen_subgraphs.add(n.args[1])
+            operands = n.args[2:]
+            example_inputs = []
+            for operand in operands:
+                assert isinstance(operand, torch.fx.Node) and "val" in operand.meta
+                example_inputs.append(operand.meta["val"])
+            return FakeTensorProp(
+                getattr(self.module, n.args[0].target), mode=self._mode
+            ).propagate(*example_inputs)
+
         result = super().run_node(n)
         rebind_unbacked(self._mode.shape_env, n, result)