[Async TP] Fuse matmul-reduce-scatters when reduce scatters have multiple users, and save fused node for backward instead of reduce_scatter node by danielvegamyhre · Pull Request #149946 · pytorch/pytorch · GitHub

[Async TP] Fuse matmul-reduce-scatters when reduce scatters have multiple users, and save fused node for backward instead of reduce_scatter node #149946


Closed
wants to merge 8 commits

Conversation

@danielvegamyhre (Contributor) commented Mar 25, 2025

Fixes #149876

Stack

  • Previous PR in stack: #149247

TL;DR

This PR implements support in async TP for saving the reduce-scatter result for backward, which previously would break the torchtitan AC policies: no AC, per op SAC, and per layer SAC.

Context

In torchtitan's LLama3 per op SAC policy, we want to save the output of reduce_scatter ops for backward, which is useful for TP. The reduce_scatter op is also saved for No AC (since all activations are saved) and per layer SAC (since we save the activations for N full layers, which do contain reduce-scatters for TP).
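For reference, a per-op SAC policy along these lines can be expressed with the selective checkpointing API in torch.utils.checkpoint. This is a minimal sketch, not torchtitan's exact policy (its real save list contains more ops):

```python
import torch
from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

# Ops whose outputs should be kept for backward (illustrative subset).
_save_list = {
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
}

def _policy_fn(ctx, op, *args, **kwargs):
    # Save outputs of the listed ops; prefer recomputing everything else.
    return CheckpointPolicy.MUST_SAVE if op in _save_list else CheckpointPolicy.PREFER_RECOMPUTE

def context_fn():
    return create_selective_checkpoint_contexts(_policy_fn)

# Usage: torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False, context_fn=context_fn)
```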

However, doing this causes incompatibility with Async TP for the AC policies above, for 2 reasons:

  1. The graph pattern matching only matches reduce-scatter nodes with exactly 1 user, but reduce_scatter nodes saved for backward will have 2 users (the 2nd one being the return/output node, which saves it for backward).

  2. The subgraph replacement logic, which replaces the users of the wait_tensor after the reduce-scatter with the new fused node, has no mechanism to save the fused node for backward instead of the reduce-scatter node. This means we cannot directly replace the subgraph, since we can't delete nodes which still have users (in this case, the output node is still using the reduce-scatter node).

To fix this, we do 2 things:

  1. Add additional pattern matching logic to also match reduce-scatter nodes with 2 users, so we also perform fusion when the reduce-scatter is saved for backward (see the sketch after this list).

  2. When replacing the subgraph with the fused node, detect if the reduce-scatter was saved for backward, and if so, save the result of the fused node for backward instead. This enables us to properly erase the subgraph and prevent the memory leak which occurred in [Async TP] Activations not cleared after backward when reduce_scatter_tensor saved for backward by per op SAC #149876
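A minimal sketch of the relaxed matching condition from item 1, using a hypothetical helper name (the real matcher in torch/_inductor is more involved):

```python
import torch.fx

def _reduce_scatter_is_fusible(reduce_scatter_node: torch.fx.Node) -> bool:
    users = list(reduce_scatter_node.users)
    if len(users) == 1:
        # Classic case: only the wait_tensor consumes the reduce-scatter.
        return True
    if len(users) == 2:
        # Saved-for-backward case: the extra user must be the graph's output node.
        return any(user.op == "output" for user in users)
    return False
```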

Other changes

  • Continue to throw an error if we don't find any candidate all-gathers or reduce-scatters for fusion (since TP should have both), but DON'T throw an error if we don't fuse any matmul-reduce-scatters. I've found there are valid graphs where we fuse matmul-reduce-scatters in the forward graph but not the backward graph: in the backward pass the reduce-scatters are fed by an "add", not an mm/scaled_mm.

Test plan

  1. All unit tests are passing
  2. Visualized the graphs and verified the fusion is occurring properly.
  3. Verified via manual torchtitan runs there is no memory leak / OOM occurring anymore.

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @xmfan

pytorch-bot bot commented Mar 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149946

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit 921cff4 with merge base 2bd5bfa:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@danielvegamyhre (Contributor Author)

@bdhirsh I'd appreciate your thoughts on this section in particular; let me know if you think this solution to the issue we discussed in #149876 makes sense.

@@ -211,8 +211,25 @@ class _ReduceScatterMatch:
    group_name: str

    def replace_with(self, new_node: torch.fx.Node) -> None:
        # Replace all uses of the result node (wait_tensor) with the fused node.
        self.res_node.replace_all_uses_with(new_node)
Contributor

So this node is presumably replacing all usages of the original reduce_scatter output node with the new fused node. Does that sound right? If so, maybe we should look into why it is not successfully replacing the usage by the target == 'output' node in the graph?

Contributor Author

This line is actually replacing all usages of the result node (wait_tensor) with the new fused node. IMO these variable names are a bit ambiguous (rs_node vs res_node), so I just pushed a change renaming:

  • rs_node => reduce_scatter_node
  • res_node => wait_tensor_node (for _ReduceScatterMatch, where the result node is explicitly always a wait_tensor).

Hopefully that helps clarify the approach here. Basically we are replacing all uses of the result/wait_tensor with the fused node, and then if the reduce_scatter node is being used by output (i.e., saved for backward), we change the output node to use the fused node instead.
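A simplified sketch of that replacement flow, using the renamed attributes above (an approximation, not the exact PR code):

```python
def replace_with(self, new_node: torch.fx.Node) -> None:
    # Route every consumer of the reduce-scatter's wait_tensor through the fused node.
    self.wait_tensor_node.replace_all_uses_with(new_node)

    # If the reduce_scatter output itself was saved for backward, it is also an input
    # of the graph's output node; repoint that use to the fused node as well.
    output_node = None
    for user in self.reduce_scatter_node.users:
        if user.op == "output":
            output_node = user
            break
    if output_node is not None:
        output_node.replace_input_with(self.reduce_scatter_node, new_node)
```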

Contributor

makes sense, thanks for explaining!

Let me know if you need me to review the PR. Happy to review it - also @kwen2501 might be interested since it touches the asyncTP code.

Contributor Author

Ok awesome! Yes if you and @kwen2501 want to review that would be greatly appreciated, thanks

        output_node = user
        break
if output_node is not None:
    # The reduce_scatter result was saved for backward; save the fused node instead.
    output_node.replace_input_with(self.reduce_scatter_node, new_node)
Contributor

Is it true that the original reduce_scatter_node should have only one user at this moment (the original wait_tensor)? IIRC, only the wait_tensor_node will be used by other nodes. So can we also assert that there is only one user for reduce_scatter_node?

Contributor Author

No, the reduce-scatter may have 2 users: the wait_tensor, and the final output node (if it is saved for backward)

Contributor

@fegin I think one alternative here is that we do something to force the partitioner to never save a collective for bw directly, and to always save its corresponding wait_tensor. We could try to do this. Two things though:

(1) You could imagine cases where a collective is run in the forward, but its result is not actually needed until the backward. In that case, it would actually be more profitable to delay the sync until the backward when the collective is actually used. I can pretty easily construct a case like this, but I'm not sure how likely it is to show up in practice.

(2) I'm not sure how difficult of an invariant this would be to maintain in the partitioner. So the tradeoff here is probably more around "increased complexity in the partitioner" vs "increased complexity in the pattern matcher".

Contributor

@danielvegamyhre I didn't express my question clearly. I meant that after the output_node.replace_input_with(), there should be only one user left on the ORIGINAL reduce_scatter_node.

@bdhirsh I do believe there are some cases where collective waits are intentionally delayed until the backward, so making the partitioner rely on that assumption is not great. I simply want to add a check AFTER the replacement.

Contributor Author

> @danielvegamyhre I didn't express my question clearly. I meant that after the output_node.replace_input_with(), there should be only one user left on the ORIGINAL reduce_scatter_node.

@fegin Sure, that would be a helpful check - I added an assertion here and reran torchtitan training runs for bf16, float8 tensorwise, and float8 rowwise, and validated everything still works as expected (with the exception of all-gather-matmuls not fusing properly for rowwise scales, which is a known issue I'm already tracking in #149990).
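For reference, the added check amounts to something like the following (a sketch; the exact assertion and message in the PR may differ):

```python
# After repointing the output node, the original reduce_scatter node should only
# be consumed by its wait_tensor.
assert len(self.reduce_scatter_node.users) == 1, (
    "fused-node replacement left unexpected users on the original reduce_scatter node"
)
```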


if reduce_scatters and not fused_reduce_scatters:
    raise AssertionError("no successful fusions of matul-reduce-scatters")
fuse_matmul_reduce_scatter(reduce_scatter)
Contributor

Curious why we decided to change the check of the return value? If no successful fusion occurred, do we still want to raise?

Contributor Author (@danielvegamyhre, Mar 28, 2025)

Yeah, I found that there were only valid scaled_mm-reduce-scatter patterns in the forward graph, but not in the backward graph, so we can't assert this here. In the backward graph, the reduce-scatters receive as input the addition of several scaled_mm outputs, so they're not what we are looking to fuse. See diagram below:

[Screenshot, 2025-03-28: diagram of the backward graph described above]
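To illustrate the distinction, the candidate check effectively requires the tensor being reduce-scattered to come directly from a matmul. A hedged sketch with a hypothetical helper name (not the actual inductor code):

```python
import torch
import torch.fx

# In the backward graph the producer is an aten.add of several scaled_mm outputs,
# so reduce-scatters there are legitimately skipped for fusion.
_MATMUL_OPS = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_mm.default,
}

def _producer_is_matmul(scatter_input: torch.fx.Node) -> bool:
    return scatter_input.op == "call_function" and scatter_input.target in _MATMUL_OPS
```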

Contributor

Ah, thanks, this makes sense.

@fegin (Contributor) left a comment

LGTM, let me know if adding the check on reduce_scatter_node makes sense. There are some test errors; at least the linter failure is real. Not sure if the others are broken in trunk.

@danielvegamyhre (Contributor Author)

@pytorchbot merge -f "dr ci confirmed test failures are unrelated"

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
[Async TP] Fuse matmul-reduce-scatters when reduce scatters have multiple users, and save fused node for backward instead of reduce_scatter node (pytorch#149946)


Pull Request resolved: pytorch#149946
Approved by: https://github.com/fegin
danielvegamyhre added a commit to pytorch/torchtitan that referenced this pull request Apr 21, 2025
… now that it's supported in core (#1031)

## Summary

In pytorch/pytorch#149876 I found there was a problem with per op SAC, per layer SAC, and no AC: all of these settings saved reduce_scatter_tensor for backward, which broke the async TP pattern matching that expects the reduce-scatter node to have only 1 user (wait_tensor), not 2 (wait_tensor and output_node).

In pytorch/pytorch#149946 I addressed this by:

1) Adding new graph patterns to match on which allow reduce_scatter to have 2 users.
2) Updating the subgraph replacement logic to save the "fused matmul reduce scatter" node for backward instead of the reduce scatter node, if it detects the graph is saving reduce_scatter for backward. This allows the original matmul reduce scatter graph to be replaced and erased correctly.

Once pytorch/pytorch#149946 is landed, we can add reduce_scatter_tensor back to the op save list for SAC in torchtitan, and it won't break SAC or no AC anymore 👍
github-actions bot deleted the rs-ac branch May 2, 2025 02:18
Labels
ciflow/inductor, ciflow/mps, Merged, module: compiled autograd, compiled_autograd, module: dynamo, module: inductor, oncall: distributed, release notes: distributed (pipeline)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Async TP] Activations not cleared after backward when reduce_scatter_tensor saved for backward by per op SAC
4 participants