[Async TP] Fuse matmul-reduce-scatters when reduce scatters have multiple users, and save fused node for backward instead of reduce_scatter node (#149946)
Fixes #149876
## Stack
- [previous PR in stack] #149247
## TL;DR
This PR adds support in async TP for saving the reduce-scatter result for backward, which previously broke the torchtitan AC policies: no AC, per-op SAC, and per-layer SAC.
## Context
In torchtitan's Llama3 per-op SAC policy, we want to save the output of `reduce_scatter` ops for backward, which is useful for TP. The reduce_scatter output is also saved under no AC (since all activations are saved) and per-layer SAC (since we save the activations for N full layers, which contain reduce-scatters for TP); a save list along these lines is sketched below.
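For context, a per-op SAC policy of this shape can be expressed with PyTorch's selective checkpoint API roughly as follows. This is a minimal sketch, not torchtitan's actual policy; the `_save_list` contents here are illustrative:

```python
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Ops whose outputs are saved for backward instead of recomputed. The
# reduce_scatter entry is what produces the 2-user reduce_scatter node
# discussed below.
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
}

def _policy(ctx, op, *args, **kwargs):
    return (
        CheckpointPolicy.MUST_SAVE
        if op in _save_list
        else CheckpointPolicy.PREFER_RECOMPUTE
    )

# Applied per transformer block, e.g.:
# out = checkpoint(block, x, use_reentrant=False,
#                  context_fn=lambda: create_selective_checkpoint_contexts(_policy))
```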
However, doing this causes incompatibility with Async TP for the AC policies above, for 2 reasons:
1) The graph pattern matching only matches reduce-scatter nodes with exactly 1 user, but a reduce_scatter node saved for backward has 2 users (the 2nd being the return/output node, which saves it for backward).
2) The subgraph replacement logic, which replaces the users of the `wait_tensor` after the reduce-scatter with the new fused node, has no mechanism to save the fused node for backward instead of the reduce-scatter node. This means we cannot directly replace the subgraph, since we can't erase nodes that still have users (here, the output node still uses the reduce-scatter node); see the sketch after this list.
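To see why a second user blocks erasure, here is a toy torch.fx example (stand-in ops, not the actual async TP graph): the `add` node plays the role of the reduce-scatter whose result is both consumed downstream and saved for backward via the output node.

```python
import operator
import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def forward(self, x):
        rs = x + 1    # stand-in for the reduce_scatter node
        y = rs * 2    # downstream consumer (think: wait_tensor user)
        return y, rs  # returning rs makes the output node a *second* user

gm = fx.symbolic_trace(Toy())
rs_node = next(n for n in gm.graph.nodes if n.target is operator.add)

print(len(rs_node.users))  # 2: the multiply and the output node

# Erasing a node that still has users raises, so the pass must first
# repoint the output node at the fused node.
try:
    gm.graph.erase_node(rs_node)
except RuntimeError as e:
    print(e)
```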
To fix this, we do 2 things:
1) Add additional pattern matching logic to also match reduce-scatter nodes with 2 users, so we also perform fusion when reduce-scatter is saved for backward.
2) When replacing the subgraph with the fused node, detect whether the reduce-scatter was saved for backward, and if so, save the result of the fused node for backward instead (sketched below). This lets us properly erase the subgraph and prevents the memory leak reported in #149876
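A simplified sketch of fix (2) in torch.fx terms (the helper name and structure are illustrative, not the actual pass):

```python
import torch.fx as fx

def save_fused_node_for_backward(
    graph: fx.Graph, rs_node: fx.Node, fused_node: fx.Node
) -> None:
    # If the reduce-scatter result was saved for backward, the graph's
    # output node appears among its users; repoint that use at the fused
    # node so erasing the reduce-scatter subgraph becomes legal.
    output_node = next(n for n in graph.nodes if n.op == "output")
    if output_node in rs_node.users:
        output_node.replace_input_with(rs_node, fused_node)
    # The remaining users (the wait_tensor consumers) are rewired by the
    # usual subgraph replacement; once rs_node has no users left, it can
    # be erased without leaking the saved activation.
```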
## Other changes
- Continue to throw an error if we find no candidate all-gathers or reduce-scatters for fusion (since TP should produce both), but DON'T throw an error if we don't fuse any matmul-reduce-scatters. There are valid graphs where reduce-scatters are fused in the forward graph but not the backward graph: in the backward pass the reduce-scatters' producer op is an `add`, not a mm/scaled_mm, so there is nothing to fuse. Directionally, the relaxed check looks like the sketch below.
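This sketch uses hypothetical names, not the actual inductor pass:

```python
def validate_fusion_results(
    found_all_gathers: int,
    found_reduce_scatters: int,
    fused_matmul_reduce_scatters: int,
) -> None:
    # A TP graph should always contain both collective types, so finding
    # no candidates at all still indicates a pattern-matching bug.
    if not found_all_gathers or not found_reduce_scatters:
        raise RuntimeError(
            "async TP found no all-gather/reduce-scatter fusion candidates"
        )
    # Zero fused matmul-reduce-scatters is legal, though: in the backward
    # graph the reduce-scatter's producer can be an `add` rather than a
    # mm/scaled_mm, so no fusion opportunity exists there.
```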
## Test plan
1. All unit tests are passing
2. Visualized the graphs and verified the fusion is occurring properly.
3. Verified via manual torchtitan runs that the memory leak / OOM no longer occurs.
Pull Request resolved: #149946
Approved by: https://github.com/fegin