From 7af8fa6a5bbdf2bafe9091dc775805b9ccd43206 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 25 Mar 2025 08:12:50 -0700 Subject: [PATCH 1/8] support saving reduce scatter for backward in async TP --- .../tensor/parallel/test_micro_pipeline_tp.py | 38 ------ .../_inductor/fx_passes/micro_pipeline_tp.py | 125 +++++++++++++----- 2 files changed, 92 insertions(+), 71 deletions(-) diff --git a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py index c8fd02cbc3203d..2cc02d5916fb6b 100644 --- a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py +++ b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py @@ -492,44 +492,6 @@ def no_matching_pattern( gm.graph, ) - @skipIfRocm - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @fresh_inductor_cache() - def test_unsuccessful_fusion(self): - group = dist.group.WORLD - scatter_dim = 0 - - def no_matching_pattern( - A: torch.Tensor, - B: torch.Tensor, - ) -> torch.Tensor: - """ - Performs 'reshape -> reciprocal -> mm -> reshape -> reduce scatter' pattern, - so the extra 'reciprocal' op in the middle should cause pattern matching to fail. - """ - out_shape = [*A.shape[:-1], B.shape[-1]] - A = A.reshape(-1, A.shape[-1]) - - # insert extra op after reshape that will cause pattern matching to fail - A = torch.reciprocal(A) - - C = A @ B - C = C.view(out_shape) - return reduce_scatter_tensor(C, "sum", scatter_dim, group) - - A = torch.rand(2, 16, 32, device="cuda").to(torch.bfloat16) - B = torch.rand(16, 32, device="cuda").to(torch.bfloat16).T - - gm = _make_post_grad_fx(no_matching_pattern, A, B) - - with _test_mode(): - self.assertRaisesRegex( - AssertionError, - "no successful fusions of matul-reduce-scatters", - micro_pipeline_tp_pass, - gm.graph, - ) - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @parametrize("shard_dim", [0, 1]) @fresh_inductor_cache() diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py index b150b1c0abda87..053d89c62e27c5 100644 --- a/torch/_inductor/fx_passes/micro_pipeline_tp.py +++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -211,8 +211,25 @@ class _ReduceScatterMatch: group_name: str def replace_with(self, new_node: torch.fx.Node) -> None: + # Replace all uses of the result node (wait_tensor) with the fused node. self.res_node.replace_all_uses_with(new_node) + # If the reduce-scatter result is saved for backward, save the fused node for backward instead. + self._update_save_for_backward(new_node) + + def _update_save_for_backward(self, new_node: torch.fx.Node) -> None: + """ + If the output node is a user of the reduce_scatter node (indicating the reduce_scatter + result is saved for backward), this method will update the output node to use the fused node instead. 
+ """ + output_node = None + for user in self.rs_node.users: + if user.target == "output": + output_node = user + break + if output_node is not None: + output_node.replace_input_with(self.rs_node, new_node) + def erase(self) -> None: for node in reversed(self.match.nodes): if len(node.users) == 0: @@ -222,7 +239,7 @@ def erase(self) -> None: def find_reduce_scatter_patterns(graph: torch.fx.Graph): c10d = torch.ops._c10d_functional - def reduce_scatter_template(inp: PatternExpr): + def reduce_scatter_template(inp: PatternExpr, users: int): return CallFunction( c10d.wait_tensor.default, CallFunction( @@ -231,14 +248,43 @@ def reduce_scatter_template(inp: PatternExpr): KeywordArg("reduce_op"), Ignored(), KeywordArg("group_name"), + _users=users, ), ) # Matches funcol.reduce_scatter_tensor with scatter_dim == 0 - zero_dim_reduce_scatter_pattern = reduce_scatter_template(KeywordArg("input")) + zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template( + KeywordArg("input"), users=1 + ) + + # Two users will occur when the reduce-scatter result is saved for backward + zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template( + KeywordArg("input"), users=2 + ) # Matches funcol.reduce_scatter_tensor with scatter_dim > 0 - non_zero_dim_reduce_scatter_pattern = reduce_scatter_template( + non_zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template( + CallFunction( + aten.cat.default, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + KeywordArg("input"), + Ignored(), + KeywordArg("scatter_dim"), + _users=MULTIPLE, + ), + Ignored(), + ) + ), + ), + users=1, + ) + + # Two users will occur when the reduce-scatter result is saved for backward + non_zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template( CallFunction( aten.cat.default, ListOf( @@ -255,12 +301,39 @@ def reduce_scatter_template(inp: PatternExpr): ) ), ), + users=2, ) reduce_scatters = [] for node in reversed(graph.nodes): if node.target == c10d.wait_tensor.default: - if match := non_zero_dim_reduce_scatter_pattern.match(node): + if match := non_zero_dim_reduce_scatter_pattern_single_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + rs_node=match.nodes[-2], + res_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=match.kwargs["scatter_dim"], + group_name=match.kwargs["group_name"], + ) + ) + elif match := zero_dim_reduce_scatter_pattern_single_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + rs_node=match.nodes[0], + res_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=0, + group_name=match.kwargs["group_name"], + ) + ) + elif match := non_zero_dim_reduce_scatter_pattern_multi_user.match(node): assert isinstance(match, Match) reduce_scatters.append( _ReduceScatterMatch( @@ -273,7 +346,7 @@ def reduce_scatter_template(inp: PatternExpr): group_name=match.kwargs["group_name"], ) ) - elif match := zero_dim_reduce_scatter_pattern.match(node): + elif match := zero_dim_reduce_scatter_pattern_multi_user.match(node): assert isinstance(match, Match) reduce_scatters.append( _ReduceScatterMatch( @@ -757,7 +830,7 @@ def _insert_fused_matmul_reduce_scatter( raise AssertionError(f"Unexpected matmul match type: {type(matmul)}") -def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool: +def 
fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: """ Fused the pattern @@ -771,14 +844,9 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool: Returns boolean indicating if fusion was successful or not. """ - if ( - not torch.distributed.is_available() - or not torch.distributed.is_nccl_available() - ): - log.debug( - "torch.distributed is not available, skipping fuse_matmul_reduce_scatter fusion" - ) - return False + assert torch.distributed.is_available() or torch.distributed.is_nccl_available(), ( + "torch.distributed must be available to use async tensor parallelism" + ) from torch.distributed._symmetric_memory import ( is_symm_mem_enabled_for_group, @@ -794,35 +862,32 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool: reduce_scatter.group_name, ) - if not is_symm_mem_enabled_for_group(group_name): - log.debug( - "symmetric memory is not enabled for process group %s, skipping fuse_matmul_reduce_scatter fusion", - group_name, - ) - return False + assert is_symm_mem_enabled_for_group(group_name), ( + f"symmetric memory is not enabled for process group {group_name}, skipping fuse_matmul_reduce_scatter fusion" + ) # Currently fused_matmul_reduce_scatter doesn't return the matmul result, # so we can't apply the fusion if the matmul result is used by multiple # users. This is not a fundamental limitation of the fused op and can be # addressed if needed. if len(input_node.users) != 1: - log.debug( + log.warning( "matmul result has more than one user, skipping fused_matmul_reduce_scatter fusion." ) - return False + return matmul = _find_producer_matmul(input_node) if matmul is None: - log.debug( + log.warning( "no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion" ) - return False + return if rs_res_node in matmul.arg_ancestor_nodes: - log.debug( + log.warning( "reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion" ) - return False + return # We need to track 3 values for the fused scaled mm reduce scatter implementation: # 1. The scatter dim before the reshape, which was assigned using the original (a,b,c) @ (c,d) = (a,b,d) dims. 
@@ -885,7 +950,6 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool: fused_node.prepend(node) log.debug("successfully fused matmul reduce scatter") - return True def _get_node_to_ancestors( @@ -996,10 +1060,5 @@ def micro_pipeline_tp_pass(graph: torch.fx.Graph): for all_gather in all_gathers: fuse_all_gather_matmul(all_gather) - fused_reduce_scatters = False for reduce_scatter in reduce_scatters: - if fuse_matmul_reduce_scatter(reduce_scatter): - fused_reduce_scatters = True - - if reduce_scatters and not fused_reduce_scatters: - raise AssertionError("no successful fusions of matul-reduce-scatters") + fuse_matmul_reduce_scatter(reduce_scatter) From d74bccbd663f93b7d25ef12d369dac9de47b46e5 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 25 Mar 2025 15:17:39 -0700 Subject: [PATCH 2/8] rename rs_node to reduce_scatter_node to avoid confusion with res_node --- .../_inductor/fx_passes/micro_pipeline_tp.py | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py index 053d89c62e27c5..22c1ae12763858 100644 --- a/torch/_inductor/fx_passes/micro_pipeline_tp.py +++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -204,7 +204,7 @@ def make_cat_pattern(splits): class _ReduceScatterMatch: match: Match input_node: torch.fx.Node - rs_node: torch.fx.Node + reduce_scatter_node: torch.fx.Node res_node: torch.fx.Node reduce_op: str scatter_dim: int @@ -223,12 +223,12 @@ def _update_save_for_backward(self, new_node: torch.fx.Node) -> None: result is saved for backward), this method will update the output node to use the fused node instead. """ output_node = None - for user in self.rs_node.users: + for user in self.reduce_scatter_node.users: if user.target == "output": output_node = user break if output_node is not None: - output_node.replace_input_with(self.rs_node, new_node) + output_node.replace_input_with(self.reduce_scatter_node, new_node) def erase(self) -> None: for node in reversed(self.match.nodes): @@ -313,7 +313,7 @@ def reduce_scatter_template(inp: PatternExpr, users: int): _ReduceScatterMatch( match=match, input_node=match.kwargs["input"], - rs_node=match.nodes[-2], + reduce_scatter_node=match.nodes[-2], res_node=node, reduce_op=match.kwargs["reduce_op"], scatter_dim=match.kwargs["scatter_dim"], @@ -326,7 +326,7 @@ def reduce_scatter_template(inp: PatternExpr, users: int): _ReduceScatterMatch( match=match, input_node=match.kwargs["input"], - rs_node=match.nodes[0], + reduce_scatter_node=match.nodes[0], res_node=node, reduce_op=match.kwargs["reduce_op"], scatter_dim=0, @@ -339,7 +339,7 @@ def reduce_scatter_template(inp: PatternExpr, users: int): _ReduceScatterMatch( match=match, input_node=match.kwargs["input"], - rs_node=match.nodes[-2], + reduce_scatter_node=match.nodes[-2], res_node=node, reduce_op=match.kwargs["reduce_op"], scatter_dim=match.kwargs["scatter_dim"], @@ -352,7 +352,7 @@ def reduce_scatter_template(inp: PatternExpr, users: int): _ReduceScatterMatch( match=match, input_node=match.kwargs["input"], - rs_node=match.nodes[0], + reduce_scatter_node=match.nodes[0], res_node=node, reduce_op=match.kwargs["reduce_op"], scatter_dim=0, @@ -853,9 +853,16 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: restride_A_for_fused_matmul_reduce_scatter, ) - input_node, _rs_node, rs_res_node, reduce_op, orig_scatter_dim, group_name = ( + ( + input_node, + _reduce_scatter_node, + rs_res_node, + reduce_op, 
+ orig_scatter_dim, + group_name, + ) = ( reduce_scatter.input_node, - reduce_scatter.rs_node, + reduce_scatter.reduce_scatter_node, reduce_scatter.res_node, reduce_scatter.reduce_op, reduce_scatter.scatter_dim, @@ -1048,7 +1055,9 @@ def micro_pipeline_tp_pass(graph: torch.fx.Graph): unexposed_collectives = _get_unexposed_collectives(graph) all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives] reduce_scatters = [ - x for x in reduce_scatters if x.rs_node not in unexposed_collectives + x + for x in reduce_scatters + if x.reduce_scatter_node not in unexposed_collectives ] if not all_gathers and not reduce_scatters: From 15a891b6101de041dd1332c4b49118300c19ba07 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 25 Mar 2025 15:30:30 -0700 Subject: [PATCH 3/8] update tests with new var names --- .../tensor/parallel/test_micro_pipeline_tp.py | 4 ++-- .../_inductor/fx_passes/micro_pipeline_tp.py | 22 +++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py index 2cc02d5916fb6b..79843ebae9f95a 100644 --- a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py +++ b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py @@ -157,11 +157,11 @@ def func(inp: torch.Tensor) -> torch.Tensor: "placeholder", ) self.assertEqual( - reduce_scatter.rs_node.target, + reduce_scatter.reduce_scatter_node.target, torch.ops._c10d_functional.reduce_scatter_tensor.default, ) self.assertEqual( - reduce_scatter.res_node.target, + reduce_scatter.wait_tensor_node.target, torch.ops._c10d_functional.wait_tensor.default, ) self.assertEqual(reduce_scatter.group_name, group.group_name) diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py index 22c1ae12763858..905a2119aa6c03 100644 --- a/torch/_inductor/fx_passes/micro_pipeline_tp.py +++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -205,14 +205,14 @@ class _ReduceScatterMatch: match: Match input_node: torch.fx.Node reduce_scatter_node: torch.fx.Node - res_node: torch.fx.Node + wait_tensor_node: torch.fx.Node reduce_op: str scatter_dim: int group_name: str def replace_with(self, new_node: torch.fx.Node) -> None: # Replace all uses of the result node (wait_tensor) with the fused node. - self.res_node.replace_all_uses_with(new_node) + self.wait_tensor_node.replace_all_uses_with(new_node) # If the reduce-scatter result is saved for backward, save the fused node for backward instead. 
self._update_save_for_backward(new_node) @@ -314,7 +314,7 @@ def reduce_scatter_template(inp: PatternExpr, users: int): match=match, input_node=match.kwargs["input"], reduce_scatter_node=match.nodes[-2], - res_node=node, + wait_tensor_node=node, reduce_op=match.kwargs["reduce_op"], scatter_dim=match.kwargs["scatter_dim"], group_name=match.kwargs["group_name"], @@ -327,7 +327,7 @@ def reduce_scatter_template(inp: PatternExpr, users: int): match=match, input_node=match.kwargs["input"], reduce_scatter_node=match.nodes[0], - res_node=node, + wait_tensor_node=node, reduce_op=match.kwargs["reduce_op"], scatter_dim=0, group_name=match.kwargs["group_name"], @@ -340,7 +340,7 @@ def reduce_scatter_template(inp: PatternExpr, users: int): match=match, input_node=match.kwargs["input"], reduce_scatter_node=match.nodes[-2], - res_node=node, + wait_tensor_node=node, reduce_op=match.kwargs["reduce_op"], scatter_dim=match.kwargs["scatter_dim"], group_name=match.kwargs["group_name"], @@ -353,7 +353,7 @@ def reduce_scatter_template(inp: PatternExpr, users: int): match=match, input_node=match.kwargs["input"], reduce_scatter_node=match.nodes[0], - res_node=node, + wait_tensor_node=node, reduce_op=match.kwargs["reduce_op"], scatter_dim=0, group_name=match.kwargs["group_name"], @@ -856,14 +856,14 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: ( input_node, _reduce_scatter_node, - rs_res_node, + rs_wait_tensor_node, reduce_op, orig_scatter_dim, group_name, ) = ( reduce_scatter.input_node, reduce_scatter.reduce_scatter_node, - reduce_scatter.res_node, + reduce_scatter.wait_tensor_node, reduce_scatter.reduce_op, reduce_scatter.scatter_dim, reduce_scatter.group_name, @@ -890,7 +890,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: ) return - if rs_res_node in matmul.arg_ancestor_nodes: + if rs_wait_tensor_node in matmul.arg_ancestor_nodes: log.warning( "reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion" ) @@ -920,8 +920,8 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: B_shape = list(_get_tensor(matmul.B_node).shape) output_shape = [*A_orig_shape[:-1], B_shape[-1]] - graph = rs_res_node.graph - with graph.inserting_before(rs_res_node): + graph = rs_wait_tensor_node.graph + with graph.inserting_before(rs_wait_tensor_node): # Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter if "val" in matmul.A_node.meta: restrided = restride_A_for_fused_matmul_reduce_scatter( From 5b6264ef332d82e6706eccc73f8744f7ef8b626a Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 28 Mar 2025 08:37:03 -0700 Subject: [PATCH 4/8] update err msg --- torch/_inductor/fx_passes/micro_pipeline_tp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py index 905a2119aa6c03..d0eee124564f7f 100644 --- a/torch/_inductor/fx_passes/micro_pipeline_tp.py +++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -844,8 +844,8 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: Returns boolean indicating if fusion was successful or not. 
""" - assert torch.distributed.is_available() or torch.distributed.is_nccl_available(), ( - "torch.distributed must be available to use async tensor parallelism" + assert torch.distributed.is_available() and torch.distributed.is_nccl_available(), ( + "torch.distributed and NCCL must be available to use async tensor parallelism" ) from torch.distributed._symmetric_memory import ( From 2c4f70b73172fd1d8c24de23aeb28be8a032b66e Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 28 Mar 2025 08:46:07 -0700 Subject: [PATCH 5/8] update assert msg --- torch/_inductor/fx_passes/micro_pipeline_tp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py index d0eee124564f7f..e60cf8d02a537c 100644 --- a/torch/_inductor/fx_passes/micro_pipeline_tp.py +++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -870,7 +870,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: ) assert is_symm_mem_enabled_for_group(group_name), ( - f"symmetric memory is not enabled for process group {group_name}, skipping fuse_matmul_reduce_scatter fusion" + f"symmetric memory is not enabled for process group {group_name}, this is required for async TP" ) # Currently fused_matmul_reduce_scatter doesn't return the matmul result, From 3152d1ba26a6f5e891dc6dc6c149557090de5795 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Sat, 29 Mar 2025 09:32:37 -0700 Subject: [PATCH 6/8] assert rs node only has 1 user after graph replacement --- torch/_inductor/fx_passes/micro_pipeline_tp.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py index e60cf8d02a537c..46a85953776f67 100644 --- a/torch/_inductor/fx_passes/micro_pipeline_tp.py +++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -230,6 +230,10 @@ def _update_save_for_backward(self, new_node: torch.fx.Node) -> None: if output_node is not None: output_node.replace_input_with(self.reduce_scatter_node, new_node) + # Assert that now the reduce scatter node has only one user (the wait_tensor) and it's not + # saved for backward anymore. + assert len(self.reduce_scatter_node.users) == 1, "Reduce scatter node has multiple users, this is not expected" + def erase(self) -> None: for node in reversed(self.match.nodes): if len(node.users) == 0: From 8d7f14ce58365c9c1a329578a2ea0e99f2b4a52b Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Sat, 29 Mar 2025 09:57:37 -0700 Subject: [PATCH 7/8] lint --- torch/_inductor/fx_passes/micro_pipeline_tp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py index 46a85953776f67..3574022b154cbd 100644 --- a/torch/_inductor/fx_passes/micro_pipeline_tp.py +++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -232,7 +232,9 @@ def _update_save_for_backward(self, new_node: torch.fx.Node) -> None: # Assert that now the reduce scatter node has only one user (the wait_tensor) and it's not # saved for backward anymore. 
- assert len(self.reduce_scatter_node.users) == 1, "Reduce scatter node has multiple users, this is not expected" + assert len(self.reduce_scatter_node.users) == 1, ( + "Reduce scatter node has multiple users, this is not expected" + ) def erase(self) -> None: for node in reversed(self.match.nodes): @@ -1069,7 +1071,7 @@ def micro_pipeline_tp_pass(graph: torch.fx.Graph): "async TP found no matching all-gather/reduce-scatter patterns for fusion" ) - # TODO: raise an exception if we're using async TP but failed to fuse any all-gather-matmuls + torch.distributed.breakpoint() for all_gather in all_gathers: fuse_all_gather_matmul(all_gather) From 921cff412f84ad8a8bc87011c8bf1c2d51295f99 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Sat, 29 Mar 2025 15:16:52 -0700 Subject: [PATCH 8/8] remove breakpoint --- torch/_inductor/fx_passes/micro_pipeline_tp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py index 3574022b154cbd..13d18b6d9b9c3a 100644 --- a/torch/_inductor/fx_passes/micro_pipeline_tp.py +++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -1071,7 +1071,6 @@ def micro_pipeline_tp_pass(graph: torch.fx.Graph): "async TP found no matching all-gather/reduce-scatter patterns for fusion" ) - torch.distributed.breakpoint() for all_gather in all_gathers: fuse_all_gather_matmul(all_gather)
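
For reference, the fused matmul-reduce-scatter path touched by this series can be exercised roughly as follows. This is a minimal sketch, not code from the patches: it assumes the `_micro_pipeline_tp` inductor flag and `torch.distributed._symmetric_memory.enable_symm_mem_for_group` keep their current upstream names, and that the script is launched with one process per GPU (e.g. via torchrun).

# usage_sketch.py -- hypothetical example, not part of this patch series.
# Assumes torch._inductor.config._micro_pipeline_tp and
# torch.distributed._symmetric_memory.enable_symm_mem_for_group exist under
# these names. Launch with one process per GPU, e.g.:
#   torchrun --nproc-per-node=2 usage_sketch.py
import os

import torch
import torch.distributed as dist
import torch._inductor.config as inductor_config
from torch.distributed._functional_collectives import reduce_scatter_tensor
from torch.distributed._symmetric_memory import enable_symm_mem_for_group


def mm_reduce_scatter(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # 'mm -> reduce_scatter_tensor -> wait_tensor' is the pattern that
    # micro_pipeline_tp_pass matches and fuses when async TP is enabled.
    C = A @ B
    return reduce_scatter_tensor(C, "sum", 0, dist.group.WORLD)


def main() -> None:
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")

    # Symmetric memory must be enabled for the process group (asserted in
    # fuse_matmul_reduce_scatter above); the inductor flag turns on the
    # micro-pipeline TP pass itself.
    enable_symm_mem_for_group(dist.group.WORLD.group_name)
    inductor_config._micro_pipeline_tp = True

    A = torch.randn(64, 32, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(32, 16, device="cuda", dtype=torch.bfloat16)
    out = torch.compile(mm_reduce_scatter)(A, B)
    print(f"rank {dist.get_rank()}: fused output shape {tuple(out.shape)}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

With the flag enabled, micro_pipeline_tp_pass should rewrite the mm -> reduce_scatter_tensor -> wait_tensor subgraph into torch.ops.symm_mem.fused_matmul_reduce_scatter, and the changes in this series additionally repoint any saved-for-backward reference from the original reduce-scatter node to the fused node.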