[Async TP] Fuse matmul-reduce-scatters when reduce scatters have multiple users, and save fused node for backward instead of reduce_scatter node #149946
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149946
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (4 Unrelated Failures) As of commit 921cff4 with merge base 2bd5bfa.
FLAKY - The following job failed but was likely due to flakiness present on trunk.
BROKEN TRUNK - The following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 8c1a679 to fe00fe3.
@@ -211,8 +211,25 @@ class _ReduceScatterMatch:
    group_name: str

    def replace_with(self, new_node: torch.fx.Node) -> None:
        # Replace all uses of the result node (wait_tensor) with the fused node.
        self.res_node.replace_all_uses_with(new_node)
So this is presumably replacing all usages of the original reduce_scatter output node with the new fused node. Does that sound right? If so, maybe we should look into why it is not successfully replacing the usage by the `target == 'output'` node in the graph?
This line is actually replacing all usages of the result node (wait_tensor) with the new fused node. IMO these variable names are a bit ambiguous (`rs_node` vs `res_node`), so I just pushed a change renaming:
`rs_node` => `reduce_scatter_node`
`res_node` => `wait_tensor_node`
(for `_ReduceScatterMatch`, where the result node is explicitly always a wait_tensor).
Hopefully that helps clarify the approach here. Basically, we replace all uses of the result/wait_tensor node with the fused node, and then, if the reduce_scatter node is used by `output` (i.e., saved for backward), we change `output` to use the fused node instead.
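For concreteness, here is a minimal sketch of the replacement flow described above, written against the public torch.fx Node API. The helper name and the exact field access on the match object are assumptions for illustration; the real logic lives in `_ReduceScatterMatch.replace_with` shown in the diff above.

```python
import torch.fx


def _replace_with_fused_node(match, new_node: torch.fx.Node) -> None:
    """Illustrative only: assumes `match` carries the renamed fields
    `wait_tensor_node` and `reduce_scatter_node` from _ReduceScatterMatch."""
    # Route every consumer of the reduce-scatter's wait_tensor result to the fused node.
    match.wait_tensor_node.replace_all_uses_with(new_node)

    # If the reduce_scatter itself was saved for backward, it appears as an input
    # to the graph's output node; rewire that remaining use to the fused node too.
    output_node = None
    for user in match.reduce_scatter_node.users:
        if user.op == "output":
            output_node = user
            break
    if output_node is not None:
        output_node.replace_input_with(match.reduce_scatter_node, new_node)
```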
makes sense, thanks for explaining!
Let me know if you need me to review the PR. Happy to review it - also @kwen2501 might be interested since it touches the asyncTP code.
Ok awesome! Yes if you and @kwen2501 want to review that would be greatly appreciated, thanks
Force-pushed from fe00fe3 to 8943e6d.
                output_node = user
                break
        if output_node is not None:
            output_node.replace_input_with(self.reduce_scatter_node, new_node)
Is it true that the original `reduce_scatter_node` should have only one user at this moment (the original wait_tensor)? IIRC, only `wait_tensor_node` will be used by other nodes. So can we also assert that there is only one user for `reduce_scatter_node`?
No, the reduce-scatter may have 2 users: the wait_tensor, and the final output node (if it is saved for backward)
@fegin I think one alternative here is that we do something to force the partitioner to never save a collective for bw directly, and to always save its corresponding `wait_tensor`. We could try to do this. Two things though:
(1) You could imagine cases where a collective is run in the forward, but its result is not actually needed until the backward. In that case, it would actually be more profitable to delay the sync until the backward, when the collective result is actually used. I can pretty easily construct a case like this, but I'm not sure how likely it is to show up in practice.
(2) I'm not sure how difficult an invariant this would be to maintain in the partitioner. So the tradeoff here is probably more around "increased complexity in the partitioner" vs "increased complexity in the pattern matcher".
@danielvegamyhre I didn't express my question clearly. I meant that after the `output_node.replace_input_with()` call, there should be only one user for the ORIGINAL `reduce_scatter_node`.
@bdhirsh I do believe there are some cases where collective waits are intentionally delayed until the backward, so making the partitioner carry that assumption is not great. I simply want to add a check AFTER the replacement.
> @danielvegamyhre I didn't express my question clearly. I meant that after the `output_node.replace_input_with()` call, there should be only one user for the ORIGINAL `reduce_scatter_node`.

@fegin Sure, that would be a helpful check - I added an assertion here and reran torchtitan training runs for bf16, float8 tensorwise, and float8 rowwise, and validated everything still works as expected (with the exception of all-gather-matmuls not fusing properly for rowwise scales, which is a known issue I'm already tracking in #149990).
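For reference, the kind of check being discussed could look like the hedged sketch below (not necessarily the exact assertion that landed): after rewiring the output node, the original reduce-scatter should only be consumed by its own wait_tensor.

```python
import torch.fx


def _check_reduce_scatter_users(reduce_scatter_node: torch.fx.Node) -> None:
    # Illustrative post-replacement check: by this point the saved-for-backward use
    # has been moved to the fused node, so the only remaining user of the original
    # reduce_scatter should be its wait_tensor.
    assert len(reduce_scatter_node.users) == 1, (
        "Expected reduce_scatter_node to have exactly one user (its wait_tensor) "
        f"after subgraph replacement, found {len(reduce_scatter_node.users)}"
    )
```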
if reduce_scatters and not fused_reduce_scatters:
    raise AssertionError("no successful fusions of matul-reduce-scatters")
fuse_matmul_reduce_scatter(reduce_scatter)
Curious why we decided to change the check of the return value? If no successful fusions occurred, do we still want to raise?
Yeah, I found that there were valid scaled_mm → reduce-scatter patterns only in the forward graph, not in the backward graph, so we can't assert this here. In the backward graph, the reduce-scatters receive as input the addition of various scaled_mms, so it's not what we are looking to fuse.
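To make the distinction concrete, a producer-op check along these lines explains why the backward-graph reduce-scatters are skipped; the function name and the exact set of ATen overloads are illustrative assumptions, not the PR's actual matcher.

```python
import torch
import torch.fx

# Assumed set of fusible producers; in the backward graph the producer is an `add`
# of several scaled_mm outputs, so candidates there fail this check.
_FUSIBLE_PRODUCERS = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_mm.default,
}


def _is_fusible_matmul_reduce_scatter(reduce_scatter_node: torch.fx.Node) -> bool:
    producer = reduce_scatter_node.args[0]
    return isinstance(producer, torch.fx.Node) and producer.target in _FUSIBLE_PRODUCERS
```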
Ah, thanks, this makes sense.
LGTM. Let me know if adding the check on `reduce_scatter_node` makes sense. There are some test errors; at least the linter failure is real. Not sure if the others are broken in trunk.
@pytorchbot merge -f "dr ci confirmed test failures are unrelated"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
[Async TP] Fuse matmul-reduce-scatters when reduce scatters have multiple users, and save fused node for backward instead of reduce_scatter node (pytorch#149946). Fixes pytorch#149876. Pull Request resolved: pytorch#149946. Approved by: https://github.com/fegin
… now that it's supported in core (#1031)
## Summary
In pytorch/pytorch#149876 I found there was a problem with per op SAC, per layer SAC, and no AC, because all of these settings saved reduce_scatter_tensor for backward. This was a problem: it broke the async TP pattern matching, which expects the reduce_scatter node to have only 1 user (wait_tensor), not 2 (wait_tensor and output_node).
In pytorch/pytorch#149946 I addressed this by:
1) Adding new graph patterns to match on which allow reduce_scatter to have 2 users.
2) Updating the subgraph replacement logic to save the "fused matmul reduce scatter" node for backward instead of the reduce scatter node, if it detects the graph is saving reduce_scatter for backward. This allows the original matmul reduce scatter subgraph to be replaced and erased correctly.
Once pytorch/pytorch#149946 is landed, we can add back reduce_scatter_tensor to the op save list for SAC in torchtitan, and it won't break SAC and no AC anymore 👍
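As a rough illustration of the torchtitan side (not torchtitan's actual code), a per-op SAC policy that force-saves reduce_scatter_tensor results could be written with PyTorch's selective checkpointing API roughly as follows; the op list and the wrapper function are assumptions for this sketch.

```python
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Assumed op save list: force-save functional reduce_scatter results for backward.
_OPS_TO_SAVE = {
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
}


def _save_reduce_scatter_policy(ctx, op, *args, **kwargs):
    if op in _OPS_TO_SAVE:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


def checkpointed_forward(module, *inputs):
    # Non-reentrant checkpoint with a selective-checkpoint context applying the policy.
    return checkpoint(
        module,
        *inputs,
        use_reentrant=False,
        context_fn=lambda: create_selective_checkpoint_contexts(_save_reduce_scatter_policy),
    )
```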
Fixes #149876
Stack
- [previous PR in stack] #149247
TL;DR
This PR implements support in async TP for saving the reduce-scatter result for backward, which previously would break the torchtitan AC policies: no AC, per op SAC, and per layer SAC.
Context
In torchtitan's LLama3 per op SAC policy, we want to save the output of `reduce_scatter` ops for backward, which is useful for TP. The reduce_scatter op is also saved for No AC (since all activations are saved) and per layer SAC (since we save the activations for N full layers, which do contain reduce-scatters for TP).
However, doing this causes incompatibility with Async TP for the AC policies above, for 2 reasons:
The graph pattern matching specifically only matches on reduce scatter nodes with 1 user, but reduce_scatter nodes saved for backwards will have 2 users (the 2nd one being the return/output node, which saves it for backward).
The subgraph replacement logic, which replaces the users of the `wait_tensor` after the reduce-scatter with the new fused node, has no mechanism to save the fused node for backward instead of the reduce-scatter node. This means we cannot directly replace the subgraph, since we can't delete nodes which still have users (in this case, the output node is still using the reduce-scatter node).
To fix this, we do 2 things:
Add additional pattern matching logic to also match reduce-scatter nodes with 2 users, so we also perform fusion when reduce-scatter is saved for backward.
When replacing the subgraph with the fused node, detect if the reduce-scatter was saved for backward, and if so, save the result of the fused node for backward instead. This enables us to properly erase the subgraph and prevent the memory leak which occurred in [Async TP] Activations not cleared after backward when reduce_scatter_tensor saved for backward by per op SAC #149876
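A minimal sketch of how the extended matching could classify candidates, under the two cases above; the function and return values are illustrative, not the actual async TP pass.

```python
import torch.fx


def _classify_reduce_scatter(reduce_scatter_node: torch.fx.Node) -> str:
    """Illustrative classification of a reduce-scatter fusion candidate."""
    users = list(reduce_scatter_node.users)
    if len(users) == 1:
        # Only the wait_tensor consumes it: the original single-user pattern.
        return "fusible"
    if len(users) == 2 and any(u.op == "output" for u in users):
        # wait_tensor plus the graph output node: saved for backward, still fusible,
        # but the fused node must be saved for backward in place of the reduce-scatter.
        return "fusible_saved_for_backward"
    return "not_fusible"
```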
Other changes
- Continue to throw an error if we don't find any candidate all-gathers or reduce-scatters for fusion (since TP should have both), but DON'T throw an error if we don't fuse any matmul-reduce-scatters. This is because I've found there are actually valid graphs where we fuse reduce-scatters in the forward graph but not the backward graph (in the backward pass there are reduce-scatters, but the producer op is an "add", not a mm/scaled_mm).
Test plan
1. All unit tests are passing.
2. Visualized the graphs and verified the fusion is occurring properly.
3. Verified via manual torchtitan runs that there is no memory leak / OOM occurring anymore.
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @xmfan