8000 Update on "[invoke_subgraph] Run missing graph passes recursively" · pytorch/pytorch@239dd4e · GitHub
[go: up one dir, main page]

Skip to content

Commit 239dd4e

Browse files
committed
Update on "[invoke_subgraph] Run missing graph passes recursively"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
2 parents fcdf745 + c253f16 commit 239dd4e

File tree

3 files changed

+31
-14
lines changed

3 files changed

+31
-14
lines changed

torch/_inductor/compile_fx.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -396,20 +396,25 @@ def _unlift_graph(
396396
def _get_subgraph_names(
397397
gm: GraphModule, skip_invoke_subgraph: bool = False
398398
) -> Generator[str, None, None]:
399-
subgraph_names: OrderedSet[str] = OrderedSet()
400-
for child_name, child_mod in gm.named_children():
401-
if isinstance(child_mod, torch.fx.GraphModule):
402-
subgraph_names.add(child_name)
399+
all_subgraph_names: OrderedSet[str] = OrderedSet(
400+
x.target for x in gm.graph.find_nodes(op="get_attr")
401+
)
402+
fx_subgraph_names: OrderedSet[str] = OrderedSet()
403+
for child_name, child_module in gm.named_children():
404+
# Sometimes an owning_module can have unused children. Skip them
405+
# by checking them from get_attr node targets.
406+
if child_name in all_subgraph_names and isinstance(
407+
child_module, torch.fx.GraphModule
408+
):
409+
fx_subgraph_names.add(child_name)
403410

404411
if skip_invoke_subgraph:
405< 10000 /td>-
for node in sorted(
406-
gm.graph.find_nodes(
407-
op="call_function", target=torch.ops.higher_order.invoke_subgraph
408-
)
412+
for node in gm.graph.find_nodes(
413+
op="call_function", target=torch.ops.higher_order.invoke_subgraph
409414
):
410-
subgraph_names.discard(node.args[0].target)
415+
fx_subgraph_names.discard(node.args[0].target)
411416

412-
yield from subgraph_names
417+
yield from fx_subgraph_names
413418

414419

415420
def _recursive_pre_grad_passes(

torch/_inductor/fx_passes/post_grad.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,9 +1202,12 @@ def view_to_reshape(gm):
12021202
"""
12031203
Replace view ops in the GraphModule to reshape ops.
12041204
"""
1205+
subgraph_names: OrderedSet[str] = OrderedSet(
1206+
x.target for x in gm.graph.find_nodes(op="get_attr")
1207+
)
12051208

1206-
for child_mod in gm.children():
1207-
if isinstance(child_mod, torch.fx.GraphModule):
1209+
for child_name, child_mod in gm.named_children():
1210+
if child_name in subgraph_names and isinstance(child_mod, torch.fx.GraphModule):
12081211
view_to_reshape(child_mod)
12091212

12101213
for nd in gm.graph.find_nodes(

torch/fx/graph.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,6 +1779,8 @@ def forward(self, x):
17791779
of functional operations or you supply your own custom
17801780
function for detecting side-effectful nodes.
17811781
"""
1782+
from torch.utils._ordered_set import OrderedSet
1783+
17821784
# Lint the graph first to make sure its topologically sorted, otherwise
17831785
# DCE below will not behave as expected.
17841786
self.lint()
@@ -1803,8 +1805,15 @@ def has_side_effect(node):
18031805

18041806
# Call DCE on the subgraphs
18051807
if self.owning_module is not None:
1806-
for child_module in self.owning_module.children():
1807-
if isinstance(child_module, torch.fx.GraphModule):
1808+
subgraph_names = OrderedSet(
1809+
x.target for x in self.find_nodes(op="get_attr")
1810+
)
1811+
for child_name, child_module in self.owning_module.named_children():
1812+
# Sometimes an owning_module can have unused children. Skip them
1813+
# by checking them from get_attr node targets.
1814+
if child_name in subgraph_names and isinstance(
1815+
child_module, torch.fx.GraphModule
1816+
):
18081817
changed |= child_module.graph.eliminate_dead_code()
18091818
child_module.recompile()
18101819

0 commit comments

Comments
 (0)
0