Commit 157bff2

danielvegamyhre authored and pytorchmergebot committed
[Async TP] Fuse matmul-reduce-scatters when reduce scatters have multiple users, and save fused node for backward instead of reduce_scatter node (#149946)
Fixes #149876

## Stack

- [previous PR in stack] #149247

## TL;DR

This PR implements support in async TP for saving the reduce-scatter result for backward, which previously would break the torchtitan AC policies: no AC, per-op SAC, and per-layer SAC.

## Context

In torchtitan's Llama3 per-op SAC policy, we want to save the output of `reduce_scatter` ops for backward, which is useful for TP. The reduce_scatter op is also saved for no AC (since all activations are saved) and per-layer SAC (since we save the activations for N full layers, which do contain reduce-scatters for TP).

However, doing this makes the AC policies above incompatible with async TP, for 2 reasons:

1. The graph pattern matching only matches reduce-scatter nodes with 1 user, but reduce_scatter nodes saved for backward have 2 users (the 2nd being the return/output node, which saves the result for backward).
2. The subgraph replacement logic, which replaces the users of the `wait_tensor` after the reduce-scatter with the new fused node, has no mechanism to save the fused node for backward instead of the reduce-scatter node. This means we cannot directly replace the subgraph, since we can't delete nodes which still have users (in this case, the output node is still using the reduce-scatter node).

To fix this, we do 2 things:

1. Add additional pattern matching logic to also match reduce-scatter nodes with 2 users, so we also perform fusion when the reduce-scatter is saved for backward.
2. When replacing the subgraph with the fused node, detect whether the reduce-scatter was saved for backward, and if so, save the result of the fused node for backward instead. This lets us properly erase the subgraph and prevents the memory leak which occurred in #149876.

## Other changes

- Continue to throw an error if we don't find any candidate all-gathers or reduce-scatters for fusion (since TP should have both), but DON'T throw an error if we don't fuse any matmul-reduce-scatters. There are valid graphs where we fuse reduce-scatters in the forward graph but not in the backward graph (the backward pass does contain reduce-scatters, but their producer op is an "add", not a mm/scaled_mm).

## Test plan

1. All unit tests are passing.
2. Visualized the graphs and verified the fusion is occurring properly.
3. Verified via manual torchtitan runs that there is no memory leak / OOM occurring anymore.

Pull Request resolved: #149946
Approved by: https://github.com/fegin
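To make the subgraph-replacement step concrete, here is a minimal, self-contained `torch.fx` sketch (not code from this PR: `torch.sin`/`torch.relu`/`torch.cos` and the variable names are hypothetical stand-ins for `reduce_scatter_tensor`, `wait_tensor`, and the fused matmul-reduce-scatter node). It shows why a node whose result is also returned by the graph's output node ("saved for backward") cannot simply be erased, and how redirecting the output's use to the replacement node resolves that.

```python
import torch
import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")

# Stand-ins for the reduce_scatter -> wait_tensor chain (hypothetical ops).
rs = g.call_function(torch.sin, (x,))      # plays the role of reduce_scatter_tensor
wait = g.call_function(torch.relu, (rs,))  # plays the role of wait_tensor

# The output returns both the wait result and the reduce-scatter result
# ("saved for backward"), so `rs` has 2 users: `wait` and the output node.
g.output((wait, rs))

# Insert the replacement ("fused") node before the wait_tensor.
with g.inserting_before(wait):
    fused = g.call_function(torch.cos, (x,))  # plays the role of the fused op

# Step 1: all users of the wait_tensor result now consume the fused node.
wait.replace_all_uses_with(fused)

# Step 2: the output node saved `rs` for backward; save `fused` instead.
output_node = next(user for user in rs.users if user.op == "output")
output_node.replace_input_with(rs, fused)

# The old subgraph is now dead and can be erased without dangling users.
g.erase_node(wait)
g.erase_node(rs)
g.lint()
print(g)  # the output now returns (fused, fused); sin/relu are gone
```

The PR's `_update_save_for_backward` applies the same redirect on the real post-grad graph before erasing the matched reduce-scatter subgraph.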
1 parent cbc0964 commit 157bff2

File tree

2 files changed: +123 -88 lines changed


test/distributed/tensor/parallel/test_micro_pipeline_tp.py

Lines changed: 2 additions & 40 deletions
@@ -157,11 +157,11 @@ def func(inp: torch.Tensor) -> torch.Tensor:
             "placeholder",
         )
         self.assertEqual(
-            reduce_scatter.rs_node.target,
+            reduce_scatter.reduce_scatter_node.target,
             torch.ops._c10d_functional.reduce_scatter_tensor.default,
         )
         self.assertEqual(
-            reduce_scatter.res_node.target,
+            reduce_scatter.wait_tensor_node.target,
             torch.ops._c10d_functional.wait_tensor.default,
         )
         self.assertEqual(reduce_scatter.group_name, group.group_name)
@@ -492,44 +492,6 @@ def no_matching_pattern(
             gm.graph,
         )
 
-    @skipIfRocm
-    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @fresh_inductor_cache()
-    def test_unsuccessful_fusion(self):
-        group = dist.group.WORLD
-        scatter_dim = 0
-
-        def no_matching_pattern(
-            A: torch.Tensor,
-            B: torch.Tensor,
-        ) -> torch.Tensor:
-            """
-            Performs 'reshape -> reciprocal -> mm -> reshape -> reduce scatter' pattern,
-            so the extra 'reciprocal' op in the middle should cause pattern matching to fail.
-            """
-            out_shape = [*A.shape[:-1], B.shape[-1]]
-            A = A.reshape(-1, A.shape[-1])
-
-            # insert extra op after reshape that will cause pattern matching to fail
-            A = torch.reciprocal(A)
-
-            C = A @ B
-            C = C.view(out_shape)
-            return reduce_scatter_tensor(C, "sum", scatter_dim, group)
-
-        A = torch.rand(2, 16, 32, device="cuda").to(torch.bfloat16)
-        B = torch.rand(16, 32, device="cuda").to(torch.bfloat16).T
-
-        gm = _make_post_grad_fx(no_matching_pattern, A, B)
-
-        with _test_mode():
-            self.assertRaisesRegex(
-                AssertionError,
-                "no successful fusions of matul-reduce-scatters",
-                micro_pipeline_tp_pass,
-                gm.graph,
-            )
-
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @parametrize("shard_dim", [0, 1])
     @fresh_inductor_cache()

torch/_inductor/fx_passes/micro_pipeline_tp.py

Lines changed: 121 additions & 48 deletions
@@ -204,14 +204,37 @@ def make_cat_pattern(splits):
 class _ReduceScatterMatch:
     match: Match
     input_node: torch.fx.Node
-    rs_node: torch.fx.Node
-    res_node: torch.fx.Node
+    reduce_scatter_node: torch.fx.Node
+    wait_tensor_node: torch.fx.Node
     reduce_op: str
     scatter_dim: int
     group_name: str
 
     def replace_with(self, new_node: torch.fx.Node) -> None:
-        self.res_node.replace_all_uses_with(new_node)
+        # Replace all uses of the result node (wait_tensor) with the fused node.
+        self.wait_tensor_node.replace_all_uses_with(new_node)
+
+        # If the reduce-scatter result is saved for backward, save the fused node for backward instead.
+        self._update_save_for_backward(new_node)
+
+    def _update_save_for_backward(self, new_node: torch.fx.Node) -> None:
+        """
+        If the output node is a user of the reduce_scatter node (indicating the reduce_scatter
+        result is saved for backward), this method will update the output node to use the fused node instead.
+        """
+        output_node = None
+        for user in self.reduce_scatter_node.users:
+            if user.target == "output":
+                output_node = user
+                break
+        if output_node is not None:
+            output_node.replace_input_with(self.reduce_scatter_node, new_node)
+
+        # Assert that now the reduce scatter node has only one user (the wait_tensor) and it's not
+        # saved for backward anymore.
+        assert len(self.reduce_scatter_node.users) == 1, (
+            "Reduce scatter node has multiple users, this is not expected"
+        )
 
     def erase(self) -> None:
         for node in reversed(self.match.nodes):
@@ -222,7 +245,7 @@ def erase(self) -> None:
 def find_reduce_scatter_patterns(graph: torch.fx.Graph):
     c10d = torch.ops._c10d_functional
 
-    def reduce_scatter_template(inp: PatternExpr):
+    def reduce_scatter_template(inp: PatternExpr, users: int):
         return CallFunction(
             c10d.wait_tensor.default,
             CallFunction(
@@ -231,14 +254,43 @@ def reduce_scatter_template(inp: PatternExpr):
             KeywordArg("reduce_op"),
             Ignored(),
             KeywordArg("group_name"),
+            _users=users,
         ),
     )
 
     # Matches funcol.reduce_scatter_tensor with scatter_dim == 0
-    zero_dim_reduce_scatter_pattern = reduce_scatter_template(KeywordArg("input"))
+    zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template(
+        KeywordArg("input"), users=1
+    )
+
+    # Two users will occur when the reduce-scatter result is saved for backward
+    zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template(
+        KeywordArg("input"), users=2
+    )
 
     # Matches funcol.reduce_scatter_tensor with scatter_dim > 0
-    non_zero_dim_reduce_scatter_pattern = reduce_scatter_template(
+    non_zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template(
+        CallFunction(
+            aten.cat.default,
+            ListOf(
+                CallFunction(
+                    operator.getitem,
+                    CallFunction(
+                        aten.split.Tensor,
+                        KeywordArg("input"),
+                        Ignored(),
+                        KeywordArg("scatter_dim"),
+                        _users=MULTIPLE,
+                    ),
+                    Ignored(),
+                )
+            ),
+        ),
+        users=1,
+    )
+
+    # Two users will occur when the reduce-scatter result is saved for backward
+    non_zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template(
        CallFunction(
            aten.cat.default,
            ListOf(
@@ -255,32 +307,59 @@ def reduce_scatter_template(inp: PatternExpr):
                 )
             ),
         ),
+        users=2,
     )
 
     reduce_scatters = []
     for node in reversed(graph.nodes):
         if node.target == c10d.wait_tensor.default:
-            if match := non_zero_dim_reduce_scatter_pattern.match(node):
+            if match := non_zero_dim_reduce_scatter_pattern_single_user.match(node):
+                assert isinstance(match, Match)
+                reduce_scatters.append(
+                    _ReduceScatterMatch(
+                        match=match,
+                        input_node=match.kwargs["input"],
+                        reduce_scatter_node=match.nodes[-2],
+                        wait_tensor_node=node,
+                        reduce_op=match.kwargs["reduce_op"],
+                        scatter_dim=match.kwargs["scatter_dim"],
+                        group_name=match.kwargs["group_name"],
+                    )
+                )
+            elif match := zero_dim_reduce_scatter_pattern_single_user.match(node):
+                assert isinstance(match, Match)
+                reduce_scatters.append(
+                    _ReduceScatterMatch(
+                        match=match,
+                        input_node=match.kwargs["input"],
+                        reduce_scatter_node=match.nodes[0],
+                        wait_tensor_node=node,
+                        reduce_op=match.kwargs["reduce_op"],
+                        scatter_dim=0,
+                        group_name=match.kwargs["group_name"],
+                    )
+                )
+            elif match := non_zero_dim_reduce_scatter_pattern_multi_user.match(node):
                 assert isinstance(match, Match)
                 reduce_scatters.append(
                     _ReduceScatterMatch(
                         match=match,
                         input_node=match.kwargs["input"],
-                        rs_node=match.nodes[-2],
-                        res_node=node,
+                        reduce_scatter_node=match.nodes[-2],
+                        wait_tensor_node=node,
                         reduce_op=match.kwargs["reduce_op"],
                         scatter_dim=match.kwargs["scatter_dim"],
                         group_name=match.kwargs["group_name"],
                     )
                 )
-            elif match := zero_dim_reduce_scatter_pattern.match(node):
+            elif match := zero_dim_reduce_scatter_pattern_multi_user.match(node):
                 assert isinstance(match, Match)
                 reduce_scatters.append(
                     _ReduceScatterMatch(
                         match=match,
                         input_node=match.kwargs["input"],
-                        rs_node=match.nodes[0],
-                        res_node=node,
+                        reduce_scatter_node=match.nodes[0],
+                        wait_tensor_node=node,
                         reduce_op=match.kwargs["reduce_op"],
                         scatter_dim=0,
                         group_name=match.kwargs["group_name"],
@@ -757,7 +836,7 @@ def _insert_fused_matmul_reduce_scatter(
     raise AssertionError(f"Unexpected matmul match type: {type(matmul)}")
 
 
-def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool:
+def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
     """
     Fused the pattern
 
@@ -771,58 +850,57 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool:
 
     Returns boolean indicating if fusion was successful or not.
     """
-    if (
-        not torch.distributed.is_available()
-        or not torch.distributed.is_nccl_available()
-    ):
-        log.debug(
-            "torch.distributed is not available, skipping fuse_matmul_reduce_scatter fusion"
-        )
-        return False
+    assert torch.distributed.is_available() and torch.distributed.is_nccl_available(), (
+        "torch.distributed and NCCL must be available to use async tensor parallelism"
+    )
 
     from torch.distributed._symmetric_memory import (
         is_symm_mem_enabled_for_group,
         restride_A_for_fused_matmul_reduce_scatter,
     )
 
-    input_node, _rs_node, rs_res_node, reduce_op, orig_scatter_dim, group_name = (
+    (
+        input_node,
+        _reduce_scatter_node,
+        rs_wait_tensor_node,
+        reduce_op,
+        orig_scatter_dim,
+        group_name,
+    ) = (
         reduce_scatter.input_node,
-        reduce_scatter.rs_node,
-        reduce_scatter.res_node,
+        reduce_scatter.reduce_scatter_node,
+        reduce_scatter.wait_tensor_node,
         reduce_scatter.reduce_op,
         reduce_scatter.scatter_dim,
         reduce_scatter.group_name,
     )
 
-    if not is_symm_mem_enabled_for_group(group_name):
-        log.debug(
-            "symmetric memory is not enabled for process group %s, skipping fuse_matmul_reduce_scatter fusion",
-            group_name,
-        )
-        return False
+    assert is_symm_mem_enabled_for_group(group_name), (
+        f"symmetric memory is not enabled for process group {group_name}, this is required for async TP"
+    )
 
     # Currently fused_matmul_reduce_scatter doesn't return the matmul result,
     # so we can't apply the fusion if the matmul result is used by multiple
     # users. This is not a fundamental limitation of the fused op and can be
     # addressed if needed.
     if len(input_node.users) != 1:
-        log.debug(
+        log.warning(
             "matmul result has more than one user, skipping fused_matmul_reduce_scatter fusion."
         )
-        return False
+        return
 
     matmul = _find_producer_matmul(input_node)
     if matmul is None:
-        log.debug(
+        log.warning(
            "no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion"
        )
-        return False
+        return
 
-    if rs_res_node in matmul.arg_ancestor_nodes:
-        log.debug(
+    if rs_wait_tensor_node in matmul.arg_ancestor_nodes:
+        log.warning(
             "reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion"
         )
-        return False
+        return
 
     # We need to track 3 values for the fused scaled mm reduce scatter implementation:
     # 1. The scatter dim before the reshape, which was assigned using the original (a,b,c) @ (c,d) = (a,b,d) dims.
@@ -848,8 +926,8 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool:
     B_shape = list(_get_tensor(matmul.B_node).shape)
     output_shape = [*A_orig_shape[:-1], B_shape[-1]]
 
-    graph = rs_res_node.graph
-    with graph.inserting_before(rs_res_node):
+    graph = rs_wait_tensor_node.graph
+    with graph.inserting_before(rs_wait_tensor_node):
         # Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter
         if "val" in matmul.A_node.meta:
             restrided = restride_A_for_fused_matmul_reduce_scatter(
@@ -885,7 +963,6 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool:
         fused_node.prepend(node)
 
     log.debug("successfully fused matmul reduce scatter")
-    return True
 
 
 def _get_node_to_ancestors(
@@ -984,22 +1061,18 @@ def micro_pipeline_tp_pass(graph: torch.fx.Graph):
     unexposed_collectives = _get_unexposed_collectives(graph)
     all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives]
     reduce_scatters = [
-        x for x in reduce_scatters if x.rs_node not in unexposed_collectives
+        x
+        for x in reduce_scatters
+        if x.reduce_scatter_node not in unexposed_collectives
     ]
 
     if not all_gathers and not reduce_scatters:
         raise AssertionError(
             "async TP found no matching all-gather/reduce-scatter patterns for fusion"
         )
 
-    # TODO: raise an exception if we're using async TP but failed to fuse any all-gather-matmuls
     for all_gather in all_gathers:
         fuse_all_gather_matmul(all_gather)
 
-    fused_reduce_scatters = False
     for reduce_scatter in reduce_scatters:
-        if fuse_matmul_reduce_scatter(reduce_scatter):
-            fused_reduce_scatters = True
-
-    if reduce_scatters and not fused_reduce_scatters:
-        raise AssertionError("no successful fusions of matul-reduce-scatters")
+        fuse_matmul_reduce_scatter(reduce_scatter)
