[Async TP] Fuse matmul-reduce-scatters when reduce scatters have multiple users, and save fused node for backward instead of reduce_scatter node #149946
Changes from all commits
@@ -204,14 +204,37 @@ def make_cat_pattern(splits):
 class _ReduceScatterMatch:
     match: Match
     input_node: torch.fx.Node
-    rs_node: torch.fx.Node
-    res_node: torch.fx.Node
+    reduce_scatter_node: torch.fx.Node
+    wait_tensor_node: torch.fx.Node
     reduce_op: str
     scatter_dim: int
     group_name: str
 
     def replace_with(self, new_node: torch.fx.Node) -> None:
-        self.res_node.replace_all_uses_with(new_node)
+        # Replace all uses of the result node (wait_tensor) with the fused node.
+        self.wait_tensor_node.replace_all_uses_with(new_node)
+
+        # If the reduce-scatter result is saved for backward, save the fused node for backward instead.
+        self._update_save_for_backward(new_node)
+
+    def _update_save_for_backward(self, new_node: torch.fx.Node) -> None:
+        """
+        If the output node is a user of the reduce_scatter node (indicating the reduce_scatter
+        result is saved for backward), this method will update the output node to use the fused node instead.
+        """
+        output_node = None
+        for user in self.reduce_scatter_node.users:
+            if user.target == "output":
+                output_node = user
+                break
+        if output_node is not None:
+            output_node.replace_input_with(self.reduce_scatter_node, new_node)
+
+        # Assert that now the reduce scatter node has only one user (the wait_tensor) and it's not
+        # saved for backward anymore.
+        assert len(self.reduce_scatter_node.users) == 1, (
+            "Reduce scatter node has multiple users, this is not expected"
+        )
 
     def erase(self) -> None:
         for node in reversed(self.match.nodes):
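For readers less familiar with FX graph surgery, here is a minimal standalone sketch (not part of this PR) of what `_update_save_for_backward` does: an `output` node that directly consumes a value is rewired via `replace_input_with`, after which the original node is back to a single user. Plain arithmetic ops stand in for the collectives.

```python
import operator
import torch.fx

def fwd(x):
    y = x + 1      # stand-in for the reduce_scatter node
    z = y * 2      # stand-in for the wait_tensor node
    return z, y    # y is also returned directly, i.e. "saved for backward"

gm = torch.fx.symbolic_trace(fwd)
graph = gm.graph
rs = next(n for n in graph.nodes if n.target == operator.add)
output_node = next(n for n in graph.nodes if n.op == "output")
assert output_node in rs.users and len(rs.users) == 2  # wait stand-in + output

# Mimic replace_with: point the output node at a replacement node instead.
with graph.inserting_after(rs):
    fused = graph.call_function(operator.sub, (rs.args[0], 1))
output_node.replace_input_with(rs, fused)

# The invariant the new assertion checks: only the wait-tensor-like user remains.
assert len(rs.users) == 1
graph.lint()
```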
@@ -222,7 +245,7 @@ def erase(self) -> None:
 def find_reduce_scatter_patterns(graph: torch.fx.Graph):
     c10d = torch.ops._c10d_functional
 
-    def reduce_scatter_template(inp: PatternExpr):
+    def reduce_scatter_template(inp: PatternExpr, users: int):
         return CallFunction(
             c10d.wait_tensor.default,
             CallFunction(
@@ -231,14 +254,43 @@ def reduce_scatter_template(inp: PatternExpr):
                 KeywordArg("reduce_op"),
                 Ignored(),
                 KeywordArg("group_name"),
+                _users=users,
             ),
         )
 
     # Matches funcol.reduce_scatter_tensor with scatter_dim == 0
-    zero_dim_reduce_scatter_pattern = reduce_scatter_template(KeywordArg("input"))
+    zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template(
+        KeywordArg("input"), users=1
+    )
+
+    # Two users will occur when the reduce-scatter result is saved for backward
+    zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template(
+        KeywordArg("input"), users=2
+    )
 
     # Matches funcol.reduce_scatter_tensor with scatter_dim > 0
-    non_zero_dim_reduce_scatter_pattern = reduce_scatter_template(
+    non_zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template(
         CallFunction(
             aten.cat.default,
             ListOf(
                 CallFunction(
                     operator.getitem,
                     CallFunction(
                         aten.split.Tensor,
                         KeywordArg("input"),
                         Ignored(),
                         KeywordArg("scatter_dim"),
                         _users=MULTIPLE,
                     ),
                     Ignored(),
                 )
             ),
         ),
+        users=1,
+    )
+
+    # Two users will occur when the reduce-scatter result is saved for backward
+    non_zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template(
+        CallFunction(
+            aten.cat.default,
+            ListOf(
@@ -255,32 +307,59 @@ def reduce_scatter_template(inp: PatternExpr):
                 )
             ),
         ),
+        users=2,
     )
 
     reduce_scatters = []
     for node in reversed(graph.nodes):
         if node.target == c10d.wait_tensor.default:
-            if match := non_zero_dim_reduce_scatter_pattern.match(node):
+            if match := non_zero_dim_reduce_scatter_pattern_single_user.match(node):
+                assert isinstance(match, Match)
+                reduce_scatters.append(
+                    _ReduceScatterMatch(
+                        match=match,
+                        input_node=match.kwargs["input"],
+                        reduce_scatter_node=match.nodes[-2],
+                        wait_tensor_node=node,
+                        reduce_op=match.kwargs["reduce_op"],
+                        scatter_dim=match.kwargs["scatter_dim"],
+                        group_name=match.kwargs["group_name"],
+                    )
+                )
+            elif match := zero_dim_reduce_scatter_pattern_single_user.match(node):
+                assert isinstance(match, Match)
+                reduce_scatters.append(
+                    _ReduceScatterMatch(
+                        match=match,
+                        input_node=match.kwargs["input"],
+                        reduce_scatter_node=match.nodes[0],
+                        wait_tensor_node=node,
+                        reduce_op=match.kwargs["reduce_op"],
+                        scatter_dim=0,
+                        group_name=match.kwargs["group_name"],
+                    )
+                )
+            elif match := non_zero_dim_reduce_scatter_pattern_multi_user.match(node):
                 assert isinstance(match, Match)
                 reduce_scatters.append(
                     _ReduceScatterMatch(
                         match=match,
                         input_node=match.kwargs["input"],
-                        rs_node=match.nodes[-2],
-                        res_node=node,
+                        reduce_scatter_node=match.nodes[-2],
+                        wait_tensor_node=node,
                         reduce_op=match.kwargs["reduce_op"],
                         scatter_dim=match.kwargs["scatter_dim"],
                         group_name=match.kwargs["group_name"],
                     )
                 )
-            elif match := zero_dim_reduce_scatter_pattern.match(node):
+            elif match := zero_dim_reduce_scatter_pattern_multi_user.match(node):
                 assert isinstance(match, Match)
                 reduce_scatters.append(
                     _ReduceScatterMatch(
                         match=match,
                         input_node=match.kwargs["input"],
-                        rs_node=match.nodes[0],
-                        res_node=node,
+                        reduce_scatter_node=match.nodes[0],
+                        wait_tensor_node=node,
                         reduce_op=match.kwargs["reduce_op"],
                         scatter_dim=0,
                         group_name=match.kwargs["group_name"],
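As a rough illustration of the new `users` parameter: it forwards to the pattern matcher's `_users` argument, which constrains how many consumers the matched node may have, so a reduce-scatter saved for backward (two users: the wait_tensor and the graph output) needs the `users=2` variant. The sketch below is not part of the PR; it uses stand-in ops and assumes the `torch._inductor.pattern_matcher` helpers behave as they are used in this file.

```python
import operator
import torch
import torch.fx
from torch._inductor.pattern_matcher import CallFunction, Ignored, KeywordArg

def fwd(x):
    y = torch.relu(x)      # stand-in for the reduce_scatter node
    return y * 2, y + 1    # two users of y, as if it were saved for backward

gm = torch.fx.symbolic_trace(fwd)
mul_node = next(n for n in gm.graph.nodes if n.target == operator.mul)

# Same shape as the patterns above: the _users constraint sits on the inner node.
single_user = CallFunction(
    operator.mul, CallFunction(torch.relu, KeywordArg("input"), _users=1), Ignored()
)
multi_user = CallFunction(
    operator.mul, CallFunction(torch.relu, KeywordArg("input"), _users=2), Ignored()
)

assert not single_user.match(mul_node)  # inner node has 2 users -> rejected
assert multi_user.match(mul_node)       # accepted; binds kwargs["input"]
```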
@@ -757,7 +836,7 @@ def _insert_fused_matmul_reduce_scatter(
         raise AssertionError(f"Unexpected matmul match type: {type(matmul)}")
 
 
-def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool:
+def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
     """
     Fused the pattern
@@ -771,58 +850,57 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool:
 
     Returns boolean indicating if fusion was successful or not.
     """
-    if (
-        not torch.distributed.is_available()
-        or not torch.distributed.is_nccl_available()
-    ):
-        log.debug(
-            "torch.distributed is not available, skipping fuse_matmul_reduce_scatter fusion"
-        )
-        return False
+    assert torch.distributed.is_available() and torch.distributed.is_nccl_available(), (
+        "torch.distributed and NCCL must be available to use async tensor parallelism"
+    )
 
     from torch.distributed._symmetric_memory import (
         is_symm_mem_enabled_for_group,
         restride_A_for_fused_matmul_reduce_scatter,
     )
 
-    input_node, _rs_node, rs_res_node, reduce_op, orig_scatter_dim, group_name = (
+    (
+        input_node,
+        _reduce_scatter_node,
+        rs_wait_tensor_node,
+        reduce_op,
+        orig_scatter_dim,
+        group_name,
+    ) = (
         reduce_scatter.input_node,
-        reduce_scatter.rs_node,
-        reduce_scatter.res_node,
+        reduce_scatter.reduce_scatter_node,
+        reduce_scatter.wait_tensor_node,
         reduce_scatter.reduce_op,
         reduce_scatter.scatter_dim,
         reduce_scatter.group_name,
     )
 
-    if not is_symm_mem_enabled_for_group(group_name):
-        log.debug(
-            "symmetric memory is not enabled for process group %s, skipping fuse_matmul_reduce_scatter fusion",
-            group_name,
-        )
-        return False
+    assert is_symm_mem_enabled_for_group(group_name), (
+        f"symmetric memory is not enabled for process group {group_name}, this is required for async TP"
+    )
 
     # Currently fused_matmul_reduce_scatter doesn't return the matmul result,
     # so we can't apply the fusion if the matmul result is used by multiple
     # users. This is not a fundamental limitation of the fused op and can be
     # addressed if needed.
     if len(input_node.users) != 1:
-        log.debug(
+        log.warning(
             "matmul result has more than one user, skipping fused_matmul_reduce_scatter fusion."
         )
-        return False
+        return
 
     matmul = _find_producer_matmul(input_node)
     if matmul is None:
-        log.debug(
+        log.warning(
             "no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion"
         )
-        return False
+        return
 
-    if rs_res_node in matmul.arg_ancestor_nodes:
-        log.debug(
+    if rs_wait_tensor_node in matmul.arg_ancestor_nodes:
+        log.warning(
             "reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion"
        )
-        return False
+        return
 
     # We need to track 3 values for the fused scaled mm reduce scatter implementation:
     # 1. The scatter dim before the reshape, which was assigned using the original (a,b,c) @ (c,d) = (a,b,d) dims.
@@ -848,8 +926,8 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool:
     B_shape = list(_get_tensor(matmul.B_node).shape)
     output_shape = [*A_orig_shape[:-1], B_shape[-1]]
 
-    graph = rs_res_node.graph
-    with graph.inserting_before(rs_res_node):
+    graph = rs_wait_tensor_node.graph
+    with graph.inserting_before(rs_wait_tensor_node):
         # Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter
         if "val" in matmul.A_node.meta:
             restrided = restride_A_for_fused_matmul_reduce_scatter(
@@ -885,7 +963,6 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool:
         fused_node.prepend(node)
 
     log.debug("successfully fused matmul reduce scatter")
-    return True
 
 
 def _get_node_to_ancestors(
@@ -984,22 +1061,18 @@ def micro_pipeline_tp_pass(graph: torch.fx.Graph):
     unexposed_collectives = _get_unexposed_collectives(graph)
     all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives]
     reduce_scatters = [
-        x for x in reduce_scatters if x.rs_node not in unexposed_collectives
+        x
+        for x in reduce_scatters
+        if x.reduce_scatter_node not in unexposed_collectives
     ]
 
     if not all_gathers and not reduce_scatters:
         raise AssertionError(
             "async TP found no matching all-gather/reduce-scatter patterns for fusion"
         )
 
     # TODO: raise an exception if we're using async TP but failed to fuse any all-gather-matmuls
     for all_gather in all_gathers:
         fuse_all_gather_matmul(all_gather)
 
-    fused_reduce_scatters = False
     for reduce_scatter in reduce_scatters:
-        if fuse_matmul_reduce_scatter(reduce_scatter):
-            fused_reduce_scatters = True
-
-    if reduce_scatters and not fused_reduce_scatters:
-        raise AssertionError("no successful fusions of matul-reduce-scatters")
+        fuse_matmul_reduce_scatter(reduce_scatter)
Review thread on the removed fusion-success assertion:

Reviewer: Curious why we decided to change the check of the return value? If no successful fusion occurred, do we still want to raise?

Author: Yeah, I found that there were only valid scaled_mm-reduce-scatter patterns in the forward graph, but not in the backward graph, so we can't assert this here. In the backward graph, the reduce-scatters receive as input the addition of various scaled_mms, so they are not what we are looking to fuse. See diagram below:

Reviewer: Ah, thanks, this makes sense.
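A minimal sketch of the backward-graph shape described above (stand-in ops, no collectives, not part of the PR): the reduce-scatter input is produced by an add of two matmuls, so there is no single producer matmul to fuse, and the pass now just logs a warning and returns instead of asserting.

```python
import operator
import torch.fx

def backward_like(a, b, c, d):
    grad = a @ b + c @ d   # the reduce-scatter input is an add of two matmuls
    return grad * 3        # stand-in for reduce_scatter_tensor + wait_tensor

gm = torch.fx.symbolic_trace(backward_like)
rs_input = next(n for n in gm.graph.nodes if n.target == operator.add)

# The immediate producer is an add, not a matmul, so a producer-matmul search
# finds nothing and fuse_matmul_reduce_scatter skips with a warning.
assert rs_input.target is operator.add
```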
Review thread on `_update_save_for_backward`:

fegin: Is it true that the original `reduce_scatter_node` should have only one user at this moment (the original wait_tensor)? IIRC, only `wait_tensor_node` will be used by other nodes. So can we also assert that there is only one user for `reduce_scatter_node`?

danielvegamyhre: No, the reduce-scatter may have 2 users: the wait_tensor, and the final output node (if it is saved for backward).

bdhirsh: @fegin I think one alternative here is that we do something to force the partitioner to never save a collective for backward directly, and to always save its corresponding `wait_tensor`. We could try to do this. Two things though:

(1) You could imagine cases where a collective is run in the forward, but its result is not actually needed until the backward. In that case, it would actually be more profitable to delay the sync until the backward, when the collective is actually used. I can pretty easily construct a case like this, but I'm not sure how likely it is to show up in practice.

(2) I'm not sure how difficult an invariant this would be to maintain in the partitioner. So the tradeoff here is probably more around "increased complexity in the partitioner" vs. "increased complexity in the pattern matcher".

fegin: @danielvegamyhre I didn't express my question clearly. I meant that after the `output_node.replace_input_with()` call, there should be only one user of the ORIGINAL `reduce_scatter_node`. @bdhirsh I do believe there are some cases where collective waits are intentionally delayed until the backward, so making the partitioner assume otherwise is not great. I simply want to add a check AFTER the replacement.

danielvegamyhre: @fegin Sure, that would be a helpful check. I added an assertion here and reran torchtitan training runs for bf16, float8 tensorwise, and float8 rowwise, and validated everything still works as expected (with the exception of all-gather-matmuls not fusing properly for rowwise scales, which is a known issue I'm already tracking in #149990).