[invoke_subgraph] Run missing graph passes recursively · pytorch/pytorch@b428bb0 · GitHub

Commit b428bb0

[invoke_subgraph] Run missing graph passes recursively
ghstack-source-id: b0a382c
Pull Request resolved: #152675

[invoke_subgraph] Force the output to have same strides as meta
1 parent 6a5d145 commit b428bb0
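
As context for the diffs below, a hedged sketch of how an invoke_subgraph call site sits in an FX graph; the layout is inferred from how the changed code indexes n.args, and all names here are illustrative, not from the commit:

# Hypothetical FX graph fragment:
#   %repeated_subgraph0 : get_attr target naming a GraphModule on the parent
#   %out = call_function[target=torch.ops.higher_order.invoke_subgraph](
#              args=(%repeated_subgraph0, "subgraph_0", %x, %y))
# args[0] is the get_attr node for the subgraph submodule, args[1] is an
# identifier, and args[2:] are the operands, which matches the n.args[0]
# and n.args[2:] accesses in the diffs below.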

File tree

torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py
torch/_inductor/fx_passes/post_grad.py
torch/_inductor/ir.py
torch/fx/passes/fake_tensor_prop.py

4 files changed: +48 -0 lines changed

torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py

Lines changed: 14 additions & 0 deletions

@@ -13,6 +13,7 @@
 from torch import Tensor
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code
+from torch._inductor.utils import OrderedSet
 from torch._logging import getArtifactLogger, trace_structured
 from torch._subclasses.functional_tensor import FunctionalTensorMode
 from torch.fx.experimental.proxy_tensor import make_fx

@@ -183,6 +184,19 @@ def aot_dispatch_base_graph(
     # there should be *NO* mutating ops in the graph at this point.
     copy_count = assert_functional_graph(fw_module.graph)
     fw_module.graph.eliminate_dead_code()
+
+    # Call DCE on the subgraphs
+    # TODO - Consider updating the eliminate_dead_code to work recursively.
+    seen_subgraphs: OrderedSet[str] = OrderedSet()
+    for nd in fw_module.graph.find_nodes(
+        op="call_function", target=torch.ops.higher_order.invoke_subgraph
+    ):
+        subgraph_name = nd.args[0].target
+        if subgraph_name not in seen_subgraphs:
+            seen_subgraphs.add(subgraph_name)
+            subgraph = getattr(fw_module, nd.args[0].target)
+            subgraph.graph.eliminate_dead_code()
+            subgraph.recompile()
     fw_module.recompile()

     copy_count2 = assert_functional_graph(fw_module.graph)
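
The outer fw_module.graph.eliminate_dead_code() call stops at the graph boundary, which is why the new loop visits each subgraph module explicitly. A minimal standalone sketch of the per-graph behavior (plain FX, not from the commit):

import torch
from torch.fx import symbolic_trace

class Inner(torch.nn.Module):
    def forward(self, x):
        dead = x + 1  # no users: removable only by DCE on *this* graph
        return x * 2

inner = symbolic_trace(Inner())
inner.graph.eliminate_dead_code()  # mirrors subgraph.graph.eliminate_dead_code()
inner.recompile()                  # mirrors subgraph.recompile()
print(inner.code)                  # the unused `x + 1` node is gone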

torch/_inductor/fx_passes/post_grad.py

Lines changed: 10 additions & 0 deletions

@@ -1207,6 +1207,16 @@ def view_to_reshape(gm):
     ):
         nd.target = torch.ops.aten.reshape.default

+    seen_subgraphs: OrderedSet[str] = OrderedSet()
+    for nd in gm.graph.find_nodes(
+        op="call_function", target=torch.ops.higher_order.invoke_subgraph
+    ):
+        subgraph_name = nd.args[0].target
+        if subgraph_name not in seen_subgraphs:
+            seen_subgraphs.add(subgraph_name)
+            subgraph = getattr(gm, nd.args[0].target)
+            view_to_reshape(subgraph)
+

 def should_prefer_unfused_addmm(match):
     inp = match.kwargs["inp"]
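
This pass and the DCE change above share the same visit-once traversal; a hedged generic form of that pattern (the helper name and apply_pass parameter are hypothetical):

import torch
from torch._inductor.utils import OrderedSet

def for_each_invoke_subgraph(gm, apply_pass):
    # Run `apply_pass` once per distinct subgraph submodule; the same
    # get_attr target can back several invoke_subgraph call sites.
    seen: OrderedSet[str] = OrderedSet()
    for nd in gm.graph.find_nodes(
        op="call_function", target=torch.ops.higher_order.invoke_subgraph
    ):
        name = nd.args[0].target  # args[0]: get_attr node for the subgraph
        if name not in seen:
            seen.add(name)
            apply_pass(getattr(gm, name))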

torch/_inductor/ir.py

Lines changed: 15 additions & 0 deletions

@@ -7431,6 +7431,10 @@ def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool:

 @ir_dataclass(frozen=False)
 class InvokeSubgraph(ExternKernel):
+    """
+    Implementation of InvokeSubgraph HOP
+    """
+
     subgraph: Optional[Subgraph] = None
     operands: Optional[list[TensorBox]] = None
     outputs: Optional[list[MultiOutput]] = None

@@ -7515,6 +7519,17 @@ def create_output(output: IRNode, ind: int):
             skip_size_stride_alignment_checks=True,
         )

+        # Force the output strides to be same as the original strides
+        new_outputs = []
+        fake_outputs = V.graph.current_node.meta["val"]
+        for idx, output in enumerate(outputs):
+            if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)):
+                new_outputs.append(output)
+            else:
+                example_stride = handle_sym_expr(fake_outputs[idx].stride())
+                new_outputs.append(cls.require_exact_strides(output, example_stride))
+        outputs = new_outputs
+
         outputs = [create_output(output, i) for i, output in enumerate(outputs)]
         invoke_subgraph.outputs = outputs
         return outputs
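
The motivation for pinning strides: downstream code was traced against the meta ("fake") outputs, so a real output whose layout diverges from fake_outputs[idx].stride() can break consumers specialized on that layout. A minimal stride illustration in plain PyTorch (not the Inductor IR path itself):

import torch

a = torch.empty(4, 6)         # contiguous: strides (6, 1)
b = a.t().contiguous().t()    # same shape, but strides (1, 4)
assert a.shape == b.shape and a.stride() != b.stride()

# require_exact_strides(output, example_stride) plays the role of an
# explicit copy into the expected layout:
fixed = torch.empty_strided(b.shape, a.stride()).copy_(b)
assert fixed.stride() == a.stride()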

torch/fx/passes/fake_tensor_prop.py

Lines changed: 9 additions & 0 deletions

@@ -43,6 +43,15 @@ def run_node(self, n: Node):
             rebind_unbacked,
         )

+        if (
+            n.op == "call_function"
+            and n.target is torch.ops.higher_order.invoke_subgraph
+        ):
+            subgraph_example_inputs = [a.meta["val"] for a in n.args[2:]]  # type: ignore[union-attr,arg-type]
+            FakeTensorProp(
+                getattr(self.module, n.args[0].target), mode=self._mode  # type: ignore[union-attr,arg-type]
+            ).propagate(*subgraph_example_inputs)
+
         result = super().run_node(n)
         rebind_unbacked(self._mode.shape_env, n, result)
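
What the new branch accomplishes, shown on a flat graph: FakeTensorProp records a FakeTensor in each node's meta["val"], and the branch above re-runs that propagation inside every invoke_subgraph submodule using the call site's already-fake inputs. A small standalone sketch:

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

def f(x, y):
    return x @ y + 1

gm = symbolic_trace(f)
FakeTensorProp(gm).propagate(torch.empty(2, 3), torch.empty(3, 4))

# Every node now carries a fake value with shape/stride metadata:
out = next(n for n in gm.graph.nodes if n.op == "output").args[0]
print(out.meta["val"].shape)  # torch.Size([2, 4])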
