[invoke_subgraph] Run missing graph passes recursively by anijain2305 · Pull Request #152675 · pytorch/pytorch · GitHub

[invoke_subgraph] Run missing graph passes recursively #152675


Closed · anijain2305 wants to merge 20 commits
Changes from all commits (20 commits)
57f8556  [invoke_subgraph] Run missing graph passes recursively (anijain2305, May 2, 2025)
d786c21  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 2, 2025)
e7fa1e3  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 3, 2025)
3639e98  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 4, 2025)
1ef794a  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 4, 2025)
3c8fe85  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 4, 2025)
f046bd8  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 4, 2025)
c495166  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 4, 2025)
19bbfc4  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 4, 2025)
5f878b2  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 4, 2025)
b4ae64d  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 4, 2025)
df2583f  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 5, 2025)
29555ce  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 5, 2025)
0aa67bc  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 5, 2025)
36d6535  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 5, 2025)
bde719c  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 5, 2025)
814e68f  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 5, 2025)
cbf8ffc  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 5, 2025)
fcdf745  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 5, 2025)
239dd4e  Update on "[invoke_subgraph] Run missing graph passes recursively" (anijain2305, May 5, 2025)
40 changes: 40 additions & 0 deletions test/higher_order_ops/test_invoke_subgraph.py
@@ -617,6 +617,46 @@ def forward(self, a: "f32[8]", l_y_: "f32[8]"):
""",
)

def test_view_to_reshape(self):
@mark_compile_region
def gn(x):
x = torch.sin(x)
x = x.view(1, 8)
return torch.sin(x)

def fn(x):
return gn(x)

x = torch.randn(8, requires_grad=False)

torch._dynamo.reset()
backend = InductorAndRecordGraphs()
torch.compile(fn, backend=backend, fullgraph=True)(x)

if not TEST_WITH_CROSSREF:
self.assertExpectedInline(
normalize_gm(
backend.inductor_graphs[0].print_readable(print_output=False)
),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[8]"):
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1); repeated_subgraph0 = arg0_1 = None
getitem: "f32[1, 8]" = invoke_subgraph[0]; invoke_subgraph = None
return (getitem,)

class repeated_subgraph0(torch.nn.Module):
def forward(self, arg0_1: "f32[8]"):
sin: "f32[8]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None

view: "f32[1, 8]" = torch.ops.aten.reshape.default(sin, [1, 8]); sin = None

sin_1: "f32[1, 8]" = torch.ops.aten.sin.default(view); view = None
return (sin_1,)
""",
)

def test_normalize_gm(self):
@mark_compile_region
def gn(x, y):
8 changes: 8 additions & 0 deletions torch/_inductor/fx_passes/post_grad.py
@@ -1202,6 +1202,14 @@ def view_to_reshape(gm):
"""
Replace view ops in the GraphModule to reshape ops.
"""
subgraph_names: OrderedSet[str] = OrderedSet(
x.target for x in gm.graph.find_nodes(op="get_attr")
)

for child_name, child_mod in gm.named_children():
if child_name in subgraph_names and isinstance(child_mod, torch.fx.GraphModule):
view_to_reshape(child_mod)
Review comment on lines +1205 to +1211 (Contributor):
I would prefer some automated way of doing this, e.g. GraphTransformObserver accepting view_to_reshape and then automatically applying the pass to subgraphs (a hypothetical sketch of such a helper follows this diff hunk).

How many more times do we need to do this manually?


    for nd in gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.view.default
    ):
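As a rough illustration of the automation the reviewer is asking for, here is a minimal sketch of a hypothetical helper (not part of this PR or of Inductor) that applies an arbitrary graph pass to a GraphModule and to every subgraph module referenced from its graph via get_attr nodes:

import torch.fx


def run_pass_recursively(pass_fn, gm: torch.fx.GraphModule) -> None:
    # Hypothetical helper: recurse into child GraphModules that the graph
    # references through get_attr nodes (e.g. invoke_subgraph bodies),
    # then run the pass on the outer module itself.
    subgraph_names = {n.target for n in gm.graph.find_nodes(op="get_attr")}
    for child_name, child_mod in gm.named_children():
        if child_name in subgraph_names and isinstance(child_mod, torch.fx.GraphModule):
            run_pass_recursively(pass_fn, child_mod)
    pass_fn(gm)

With a helper like this, the manual recursion added to view_to_reshape above could be written as run_pass_recursively(view_to_reshape, gm), though wiring it through GraphTransformObserver as suggested would be the more general fix.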
29 changes: 29 additions & 0 deletions torch/fx/passes/fake_tensor_prop.py
@@ -7,6 +7,7 @@
from torch.fx._compatibility import compatibility
from torch.fx.experimental.proxy_tensor import py_sym_types, snapshot_fake
from torch.fx.node import map_aggregate
from torch.utils._ordered_set import OrderedSet


__all__ = ["FakeTensorProp"]
@@ -36,13 +37,41 @@ def __init__(
        self._mode = mode
        mode.epoch += 1
        mode.reset_nt_tensor_id_counter()
        self.seen_subgraphs: OrderedSet[str] = OrderedSet()

    def run_node(self, n: Node):
        from torch.fx.experimental.symbolic_shapes import (
            compute_unbacked_bindings,
            rebind_unbacked,
        )

        if (
            n.op == "call_function"
            and n.target is torch.ops.higher_order.invoke_subgraph
            and n.args[1] not in self.seen_subgraphs
        ):
            # Prevent redundant fake tensor prop for invoke_subgraphs. Note that
            # there is also fake tensor caching for the entire subgraph. That
            # caching happens the next time we call `run_node` for the same
            # subgraph, which goes through super().run_node and caches the fake
            # tensor prop. Therefore, we end up propagating fake tensors through
            # each subgraph twice.
            assert isinstance(n.args[1], str)
            assert (
                isinstance(n.args[0], torch.fx.Node)
                and n.args[0].op == "get_attr"
                and isinstance(n.args[0].target, str)
            )
            self.seen_subgraphs.add(n.args[1])
            operands = n.args[2:]
            example_inputs = []
            for operand in operands:
                assert isinstance(operand, torch.fx.Node) and "val" in operand.meta
                example_inputs.append(operand.meta["val"])
            return FakeTensorProp(
                getattr(self.module, n.args[0].target), mode=self._mode
            ).propagate(*example_inputs)
Review comment (Contributor):
This reminds me: Inductor also has a way of doing incremental fake tensor updates through FakeTensorUpdater.incremental_update. The incremental updater might not know to properly perform incremental updates on subgraphs as well today, which could lead to silent correctness issues? https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_utils.py#L153 (a conceptual sketch of incremental fake tensor updating appears at the end of this diff)

Reply (Contributor, Author):
I checked that one. Anything that is run during the post_grad passes will recurse into subgraphs.

view_to_reshape runs outside of post_grad; therefore, we need this PR.

Reply (Contributor):
FakeTensorUpdater doesn't work on HOPs; we should fix that at some point.


        result = super().run_node(n)
        rebind_unbacked(self._mode.shape_env, n, result)

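For context on what the recursive FakeTensorProp call in run_node above does, here is a small, hedged usage sketch (the toy function is illustrative, not from the PR): FakeTensorProp walks a GraphModule with fake tensors and records a FakeTensor under node.meta["val"] for every node, which is exactly the metadata the subgraph propagation above relies on.

import torch
import torch.fx
from torch.fx.passes.fake_tensor_prop import FakeTensorProp


def toy(x):
    return torch.sin(x).view(1, 8)


gm = torch.fx.symbolic_trace(toy)

# propagate() converts the real example input into a fake tensor and then
# runs the graph under a FakeTensorMode, stamping node.meta["val"] on each node.
FakeTensorProp(gm).propagate(torch.randn(8))

for node in gm.graph.nodes:
    print(node.name, node.meta.get("val"))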
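On the FakeTensorUpdater discussion above: the snippet below is a purely conceptual sketch of what an incremental fake tensor update means (it is not Inductor's FakeTensorUpdater; the helper name and dirty-set interface are invented for illustration). The idea is to recompute node.meta["val"] only for nodes a pass has touched, reusing the cached fake tensors of their inputs; a real implementation would also have to recurse into invoke_subgraph bodies, which is the gap the reviewers point out.

import torch
import torch.fx
from torch.fx.node import map_arg
from torch._subclasses.fake_tensor import FakeTensorMode


def incremental_fake_update(gm: torch.fx.GraphModule, dirty_nodes, mode: FakeTensorMode):
    # Hypothetical helper: refresh meta["val"] only for the given dirty
    # call_function nodes, using the already-computed fake tensors of their
    # inputs instead of re-propagating the whole graph.
    dirty = set(dirty_nodes)
    for node in gm.graph.nodes:
        if node not in dirty or node.op != "call_function":
            continue
        fake_args = map_arg(node.args, lambda n: n.meta["val"])
        fake_kwargs = map_arg(node.kwargs, lambda n: n.meta["val"])
        with mode:
            node.meta["val"] = node.target(*fake_args, **fake_kwargs)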