[Async TP] Fuse matmul-reduce-scatters when reduce scatters have multiple users, and save fused node for backward instead of reduce_scatter node by danielvegamyhre · Pull Request #149875 · pytorch/pytorch · GitHub

[Async TP] Fuse matmul-reduce-scatters when reduce scatters have multiple users, and save fused node for backward instead of reduce_scatter node #149875

Closed
wants to merge 18 commits into from

@danielvegamyhre danielvegamyhre commented Mar 24, 2025

Fixes #149876

Stack

TL;DR

This PR adds support in async TP for saving the reduce-scatter result for backward. Previously, saving it broke async TP under these torchtitan activation checkpointing (AC) policies: no AC, per op selective AC (SAC), and per layer SAC.

Context

In torchtitan's Llama3 per op SAC policy, we want to save the output of reduce_scatter ops for backward, which is useful for TP. The reduce_scatter output is also saved for no AC (since all activations are saved) and per layer SAC (since we save the activations for N full layers, which do contain reduce-scatters for TP).

However, doing this causes incompatibility with Async TP for the AC policies above, for 2 reasons:

  1. The graph pattern matching only matches reduce-scatter nodes with 1 user, but reduce-scatter nodes saved for backward have 2 users (the 2nd one being the return/output node, which saves it for backward).

  2. The subgraph replacement logic which replaces the users of the wait_tensor after the reduce-scatter with the new fused node has no mechanism to save the fused_node for backward instead of the reduce-scatter node. This means we cannot directly replace the subgraph, since we can't delete nodes which still have users (in this case, the output node is still using the reduce-scatter node).
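The two failure modes above can be reproduced on a toy graph. This is a minimal sketch, not the Inductor code: `Node`, `erase`, and `matches_single_user` are hypothetical stand-ins for the FX graph machinery, and the reduce-scatter and its `wait_tensor` are collapsed into a single node for brevity.

```python
class Node:
    """Toy graph node (hypothetical stand-in for an FX node)."""

    def __init__(self, op, *inputs):
        self.op = op
        self.inputs = list(inputs)
        self.users = []
        for inp in self.inputs:
            inp.users.append(self)


def erase(node):
    """Delete a node, refusing if anything still consumes it."""
    if node.users:
        raise RuntimeError(f"cannot erase {node.op}: {len(node.users)} user(s) remain")
    for inp in node.inputs:
        inp.users.remove(node)


def matches_single_user(node):
    """The old, too-strict match: a reduce-scatter with exactly one user."""
    return node.op == "reduce_scatter" and len(node.users) == 1


# Forward graph: mm -> reduce_scatter -> add, with reduce_scatter also returned.
a, b, residual = Node("placeholder"), Node("placeholder"), Node("placeholder")
mm = Node("mm", a, b)
reduce_scatter = Node("reduce_scatter", mm)
add = Node("add", reduce_scatter, residual)
output = Node("output", add, reduce_scatter)  # second user: saved for backward

# Reason 1: the saved reduce-scatter has two users, so the match fails.
matched = matches_single_user(reduce_scatter)

# Reason 2: rewiring only the compute user leaves the output node attached,
# so the reduce-scatter cannot be erased.
fused = Node("fused_matmul_reduce_scatter", a, b)
add.inputs[0] = fused
reduce_scatter.users.remove(add)
fused.users.append(add)

try:
    erase(reduce_scatter)
    erase_error = None
except RuntimeError as exc:
    erase_error = str(exc)

print(matched)      # False
print(erase_error)  # cannot erase reduce_scatter: 1 user(s) remain
```

In this toy model the reduce-scatter stays alive only because the output node still references it, which is the situation described in reason 2.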

To fix this, we do 2 things:

  1. Add additional pattern matching logic to also match reduce-scatter nodes with 2 users, so we also perform fusion when reduce-scatter is saved for backward.

  2. When replacing the subgraph with the fused node, detect if the reduce-scatter was saved for backward, and if so, save the result of the fused node for backward instead. This enables us to properly erase the subgraph and prevent the memory leak reported in #149876 ([Async TP] Activations not cleared after backward when reduce_scatter_tensor saved for backward by per op SAC).
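The two-part fix can be sketched on a toy graph as follows. This is an illustration under stated assumptions, not the code in this PR: `Node`, `replace_all_uses`, `erase`, `is_fusion_candidate`, and `fuse` are hypothetical names, and the reduce-scatter and its `wait_tensor` are collapsed into one node.

```python
class Node:
    """Toy graph node (hypothetical stand-in for an FX node)."""

    def __init__(self, op, *inputs):
        self.op = op
        self.inputs = list(inputs)
        self.users = []
        for inp in self.inputs:
            inp.users.append(self)


def replace_all_uses(old, new):
    """Point every user of `old` (including the output node) at `new`."""
    for user in list(old.users):
        user.inputs = [new if inp is old else inp for inp in user.inputs]
        old.users.remove(user)
        new.users.append(user)


def erase(node):
    """Delete a node, refusing if anything still consumes it."""
    if node.users:
        raise RuntimeError(f"cannot erase {node.op}: users remain")
    for inp in node.inputs:
        inp.users.remove(node)


def is_fusion_candidate(rs, output):
    """Fix 1: accept 1 user, or 2 users when one of them is the output node."""
    if rs.op != "reduce_scatter" or rs.inputs[0].op != "mm":
        return False
    if len(rs.users) == 1:
        return True
    return len(rs.users) == 2 and output in rs.users


def fuse(rs, output):
    """Fix 2: the fused node takes over every use, so the old subgraph can go."""
    assert is_fusion_candidate(rs, output)
    mm = rs.inputs[0]
    fused = Node("fused_matmul_reduce_scatter", *mm.inputs)
    replace_all_uses(rs, fused)  # the output node now saves `fused` for backward
    erase(rs)
    if not mm.users:
        erase(mm)
    return fused


# Forward graph where the reduce-scatter result is also saved for backward.
a, b, residual = Node("placeholder"), Node("placeholder"), Node("placeholder")
mm = Node("mm", a, b)
reduce_scatter = Node("reduce_scatter", mm)
add = Node("add", reduce_scatter, residual)
output = Node("output", add, reduce_scatter)

fused = fuse(reduce_scatter, output)
print([n.op for n in output.inputs])  # ['add', 'fused_matmul_reduce_scatter']
```

Because the output node is rewired along with the compute user, nothing keeps the original reduce-scatter alive after fusion in this sketch.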

Other changes

  • Continue to throw an error if we don't find any candidate all-gathers or reduce-scatters for fusion (since TP should have both) but DON'T throw an error if we don't fuse any matmul-reduce-scatters. This is because I've found there are actually valid graphs where we do fuse reduce scatters in the forward graph but not the backward graph (in the backward pass there are reduce-scatters but the producer op is an "add" not a mm/scaled_mm).
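The error policy described above might look like the following sketch. The function name and arguments are hypothetical, not taken from the PR. It also assumes one reading of the first condition: the error is raised only when neither kind of candidate is found. If the intent is that either kind being absent is an error, the `and` becomes an `or`.

```python
def check_fusion_results(all_gathers, reduce_scatters, fused_reduce_scatters):
    """Hypothetical post-pass check for one graph (forward or backward).

    Raises if no fusion candidates were found at all, but tolerates a graph
    in which candidates exist and none of the reduce-scatters were fused.
    """
    # Assumption: "no candidates" means both lists are empty.
    if not all_gathers and not reduce_scatters:
        raise AssertionError(
            "async TP found no matching all-gather/reduce-scatter patterns for fusion"
        )
    # Zero fused matmul-reduce-scatters is valid, e.g. a backward graph whose
    # reduce-scatter producer is an "add" rather than a mm/scaled_mm.
    return len(fused_reduce_scatters)


# Forward graph: one reduce-scatter fed by a matmul, and it was fused.
forward_fused = check_fusion_results(["ag0"], ["rs0"], ["rs0"])

# Backward graph: a reduce-scatter exists, but its producer is an add.
backward_fused = check_fusion_results(["ag1"], ["rs1"], [])

print(forward_fused, backward_fused)  # 1 0
```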

Test plan

  1. All unit tests are passing
  2. Visualized the graphs and verified the fusion is occurring properly.
  3. Verified via manual torchtitan runs there is no memory leak / OOM occurring anymore.

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

pytorch-bot bot commented Mar 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149875

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@danielvegamyhre danielvegamyhre marked this pull request as draft March 24, 2025 19:57
@danielvegamyhre danielvegamyhre changed the title from "[WIP] [Async TP] Fuse matmul-reduce-scatters when reduce scatters have multiple users" to "[Async TP] Fuse matmul-reduce-scatters when reduce scatters have multiple users, and save fused node for backward instead of reduce_scatter node" Mar 25, 2025
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 25, 2025 01:55
@pytorch-bot pytorch-bot bot added the "oncall: distributed" and "release notes: distributed (pipeline)" labels Mar 25, 2025
@github-actions github-actions bot deleted the users branch April 27, 2025 02:22
Labels
ciflow/inductor, module: inductor, oncall: distributed, release notes: distributed (pipeline)