[Async TP] More robust support for rowwise scales when fusing matmul reduce-scatter · Pull Request #149247 · pytorch/pytorch

[Async TP] More robust support for rowwise scales when fusing matmul reduce-scatter #149247


Closed
danielvegamyhre wants to merge 1 commit

Conversation

danielvegamyhre (Contributor) commented Mar 15, 2025

Part of pytorch/torchtitan#866

Context

  • Async TP needs to support the "reshape -> scaled_mm -> reshape" pattern (sketched in code after this list) because scaled_mm only supports 2D input tensors and 2D scales:

    • (a,b,c) => (a*b,c)
    • (a*b,c) @ (c,d) = (a*b,d)
    • (a*b,d) => (a,b,d)
  • Currently the implementation does not support scaled mm with rowwise scales for all cases of the reshape -> scaled_mm -> reshape pattern. The minimal example of this pattern is confirmed to work via this unit test, but more involved e2e examples in torchtitan fail silently (more context in final bullet point).

  • Previously, the "A tensor" node referenced in the async TP graph manipulation code was the 3D+ node from before the reshape, while the "A_scale" node was the 2D node from after the reshape, so the two were incompatible.

  • I previously implemented a simpler solution to this problem in [async TP] insert reshape node to handle "reshape -> scaled mm -> reshape pattern" in async TP with rowwise scales #148001, with a unit test confirming the fused node is indeed in the graph for the minimal example of the reshape->mm->reshape pattern. I also confirmed via manual e2e testing w/ torchtitan that the crash I was fixing no longer occurred. However, it turned out that due to this bug in torchtitan, async TP was failing silently and falling back to vanilla TP, hiding the fact that the original solution fixed the crash but the fusion was still not occurring for rowwise scales. Thus, a more robust solution is needed to support all cases.
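
For readers unfamiliar with the pattern, a minimal sketch of the reshape -> scaled_mm -> reshape sequence is shown below (illustrative only; it assumes an H100-class GPU and a recent PyTorch build with rowwise-scaled `_scaled_mm` support, and the shapes/scale values are arbitrary):

```python
import torch

# (a, b, c) @ (c, d) = (a, b, d), but _scaled_mm only accepts 2D inputs and 2D scales.
a, b, c, d = 2, 16, 32, 64
A = torch.randn(a, b, c, device="cuda").to(torch.float8_e4m3fn)    # 3D activation
B = torch.randn(d, c, device="cuda").to(torch.float8_e4m3fn).t()   # (c, d), column-major as _scaled_mm requires

A_2d = A.reshape(-1, c)                          # pre-mm reshape: (a*b, c)
A_scale = torch.ones(a * b, 1, device="cuda")    # rowwise scales for A: (a*b, 1), fp32
B_scale = torch.ones(1, d, device="cuda")        # rowwise scales for B: (1, d), fp32

out_2d = torch._scaled_mm(A_2d, B, scale_a=A_scale, scale_b=B_scale, out_dtype=torch.bfloat16)
out = out_2d.reshape(a, b, d)                    # post-mm reshape back to 3D for downstream ops
```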

Solution TL;DR

  • Use the 2D 'A' tensor and corresponding 2D scales as input to the fused_matmul_reduce_scatter implementation, instead of the 3D+ tensor/scales.
  • Track the "pre-mm reshape" and "post-mm reshape" separately so they can be referenced in the fused_scaled_matmul_reduce_scatter implementation: the pre-mm reshape is used to update the scatter dim, and the post-mm reshape is applied before the reduce-scatter so the returned output tensor has the expected shape.
  • Separate the fused_matmul_reduce_scatter and fused_scaled_matmul_reduce_scatter code paths, simplifying both.
  • Together with the torchtitan fix (PR [Async TP] Don't save reduce scatter output for per op SAC torchtitan#965), the rowwise-scale support implemented in this PR solves the problem of supporting rowwise scales with all types of AC.

Additional details for reviewers

To use the 2D A tensor while also supporting the "reshape -> mm -> reshape" pattern, the following other changes were needed:

  • Track the pre-mm reshape, as it affects the scatter dim used in the fused_matmul_reduce_scatter implementation.
  • Track the post-mm reshape, as it affects the output shape used in the fused_matmul_reduce_scatter implementation.
  • Based on the pre-mm reshape and the original scatter dim, calculate the new scatter dim for the 2D tensor (sketched below). This is needed because during the pipelined producer mm implementation, the scatter dim is moved to dim 0 (so the tensor can be sharded along the first dim and chunks can be indexed out of the first dim for the mm ops), then moved back to its original place before the reduce-scatter.
  • Use the tracked post-mm reshape to reshape the stacked partial 2D outputs of the mm ops into the 3D outputs needed for 1) the reduce-scatter with the original scatter dim, and 2) the expected output shape, preventing shape errors in subsequent ops.
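
A minimal sketch of that scatter-dim remapping is below; the helper name and the exact mapping rule are illustrative assumptions, not the PR's actual code:

```python
def scatter_dim_after_reshape(orig_scatter_dim: int, orig_ndim: int) -> int:
    """Map a scatter dim defined on the 3D+ 'A' tensor onto the 2D tensor
    produced by the pre-mm reshape (d0, ..., dn-1) -> (d0*...*dn-2, dn-1):
    all leading dims collapse into dim 0, and the last dim becomes dim 1."""
    return 1 if orig_scatter_dim == orig_ndim - 1 else 0

# e.g. a (a, b, c) tensor with scatter_dim=2 maps to dim 1 of the (a*b, c) tensor,
# while scatter_dim=0 or 1 both map to dim 0.
```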

Test plan

  • All existing unit tests passing.
  • Expanded unit tests for rowwise scales to test more scatter dims.
  • Added unit tests enforcing that async TP fails fast / throws an error if it fails to perform any fusions. Previously it just "failed silently" (fell back to vanilla TP without the user knowing) which has led to confusion, so this will improve the UX.
  • Compared loss curves of bf16 vs float8 w/ rowwise scales to confirm integrity of numerics
  • Confirmed via manual testing with torchtitan and inspecting the compile graph that the fusion is working as intended for:
    • bfloat16
    • float8 with tensorwise scales
    • float8 with rowwise scales

Loss curves

Loss curves are virtually identical for bf16 + vanilla TP versus float8 with rowwise scales + async TP:

loss_async_tp

Performance

Per op SAC

Performance benchmarks for torchtitan Llama3 8b training runs on 4 H100s with per op SAC, using FSDP degree=2, TP degree=2:

  • bf16 (vanilla TP): TPS 5161.5, peak memory 50.53 GB
  • bf16 (async TP): TPS 5229.5, peak memory 50.68 GB
  • float8 tensorwise (vanilla TP): TPS 5959.5, peak memory 50.47 GB
  • float8 tensorwise (async TP): TPS 5964.5, peak memory 50.47 GB
  • float8 rowwise (vanilla TP): TPS 4962.0, peak memory 50.55 GB
  • float8 rowwise (async TP): TPS 4966.5, peak memory 50.65 GB

Full AC

Llama3 70b training runs on 128 H100s with full AC, using FSDP=16, TP=8

  • bf16 (vanilla TP): TPS 598, peak memory 71.51 GB
  • bf16 (async TP): TPS 673, peak memory 71.08 GB (+12.54% TPS vs vanilla TP)
  • float8 tensorwise (vanilla TP): TPS 820, peak memory 55.26 GB
  • float8 tensorwise (async TP): TPS 950, peak memory 55.91 GB (+15.85% TPS vs vanilla TP)
  • float8 rowwise (vanilla TP): TPS 540, peak memory 71.46 GB
  • float8 rowwise (async TP): TPS 560, peak memory 70.65 GB (+3.7% TPS vs vanilla TP, but still unexpectedly lower than bf16)

As you can see, float8 rowwise is working but performance needs to be improved further.

Other changes

  • Added logging so the user will know why fusion failed, if it fails.
  • Removed the logic which inserted a reshape node targeting "A_scale" to make it 3D like the "A tensor," since it's no longer needed.

Long term plan

  • Add a scaled_matmul op in pytorch, which will natively support a 3D+ "A tensor" and allow us to simplify the async TP implementation by avoiding the reshape -> scaled_mm -> reshape pattern and the special handling for it.

Visualizing fused nodes in graphs for torchtitan training runs

Below are examples of the visualized graphs generated by torch.compile for torchtitan Llama3 8b training runs with per op SAC. These graphs provide additional evidence (beyond the new unit tests) that the implementation is working correctly.

bf16

bf16-fusion

float8 with tensorwise scales

tensorwise-node

float8 with rowwise scales

rowwise

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

pytorch-bot bot commented Mar 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149247

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 8943e6d with merge base c0af782:

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor module: inductor oncall: distributed Add this issue/PR to distributed oncall triage queue labels Mar 15, 2025
@danielvegamyhre danielvegamyhre marked this pull request as draft March 15, 2025 05:05
@danielvegamyhre danielvegamyhre changed the title [WIP] [Async TP] Support fused_scaled_mm_reduce_scatter for float8 rowwise scales [WIP] [Async TP] Fix fused_scaled_mm_reduce_scatter for float8 rowwise scales Mar 15, 2025
@danielvegamyhre danielvegamyhre changed the title [WIP] [Async TP] Fix fused_scaled_mm_reduce_scatter for float8 rowwise scales [WIP] [Async TP] Support fused_scaled_mm_reduce_scatter for float8 rowwise scales Mar 15, 2025
@pytorch-bot pytorch-bot bot added the release notes: distributed (pipeline) release notes category label Mar 16, 2025
@danielvegamyhre danielvegamyhre changed the title [WIP] [Async TP] Support fused_scaled_mm_reduce_scatter for float8 rowwise scales [Async TP] Support fused_scaled_mm_reduce_scatter for float8 rowwise scales Mar 18, 2025
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 18, 2025 23:28
@danielvegamyhre danielvegamyhre changed the title [Async TP] Support fused_scaled_mm_reduce_scatter for float8 rowwise scales [Async TP] Support rowwise scales for 3D+ input tensors in fused_scaled_matmul_reduce_scatter Mar 20, 2025
@danielvegamyhre danielvegamyhre changed the title [Async TP] Support rowwise scales for 3D+ input tensors in fused_scaled_matmul_reduce_scatter [Async TP] Support rowwise scales for 3D+ input tensors Mar 20, 2025
@danielvegamyhre danielvegamyhre changed the title [Async TP] Support rowwise scales for 3D+ input tensors [Async TP] Support rowwise scales for 3D+ input tensors when per op SAC is used Mar 20, 2025
@danielvegamyhre danielvegamyhre changed the title [Async TP] Support rowwise scales for 3D+ input tensors when per op SAC is used [Async TP] More robust support for rowwise scales for 3D+ input tensors using reshape -> mm -> reshape pattern Mar 20, 2025
@danielvegamyhre danielvegamyhre changed the title [Async TP] More robust support for rowwise scales for 3D+ input tensors using reshape -> mm -> reshape pattern [Async TP] More robust support for rowwise scales for 3D+ input tensors using reshape -> scaled_mm -> reshape pattern Mar 20, 2025
danielvegamyhre (Contributor Author) commented Mar 20, 2025

cc interested parties @vkuzo @drisspg @lessw2020 @yifuwang I'd appreciate your thoughts and/or reviews on this

I've tried to make these changes understandable via the PR description and code comments, but am happy to chat/meet to answer questions or discuss this in more detail.

danielvegamyhre added a commit to pytorch/torchtitan that referenced this pull request Mar 21, 2025
Part of #866

## Summary
- This PR fixes async TP for bf16 and float8 w/ tensorwise scales.
- Support for async TP + float8 w/ rowwise scales requires a different
change in pytorch here pytorch/pytorch#149247

## Details
Saving the reduce_scatter op output in per op SAC was breaking the graph pattern
matching used to implement part of async TP, causing it to "fail silently":
async TP would fail to identify subgraphs to replace with
`fused_scaled_mm_reduce_scatter` but would continue to run using vanilla TP
anyway, unbeknownst to the user.

See root cause analysis here
#866 (comment)

## Test plan
- Confirmed via manual testing with torchtitan and inspecting compile
graphs that the fusion occurs correctly for bf16 and float8 w/
tensorwise scales.

## Next steps
- Investigate if it's possible to only store the activations of specific
reduce-scatter -> wait-tensor ops, since this bug has shown we cannot
store them all.
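
As an illustration of what "saving the reduce_scatter output in per op SAC" means, below is a rough sketch using torch.utils.checkpoint's selective-checkpoint API; the helper names and the save list are illustrative assumptions, not torchtitan's actual policy:

```python
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Ops whose outputs the policy saves for backward instead of recomputing.
_save_list = {
    torch.ops.aten.mm.default,
    # Saving this op's output is what interacted badly with async TP's pattern matching:
    # torch.ops._c10d_functional.reduce_scatter_tensor.default,
}

def _policy(ctx, op, *args, **kwargs):
    return (
        CheckpointPolicy.MUST_SAVE
        if op in _save_list
        else CheckpointPolicy.PREFER_RECOMPUTE
    )

def per_op_sac_forward(block, *inputs):
    # Wrap a transformer block with per-op selective activation checkpointing.
    return checkpoint(
        block,
        *inputs,
        use_reentrant=False,
        context_fn=lambda: create_selective_checkpoint_contexts(_policy),
    )
```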
@kwen2501 kwen2501 self-requested a review March 21, 2025 06:30
@danielvegamyhre danielvegamyhre changed the title [Async TP] More robust support for rowwise scales for 3D+ input tensors using reshape -> scaled_mm -> reshape pattern [Async TP] More robust support for rowwise scales when fusing matmul reduce-scatter Mar 25, 2025
Comment on lines 296 to +299
A_node: torch.fx.Node
B_node: torch.fx.Node
pre_mm_reshape: Optional[torch.fx.Node]
post_mm_reshape: Optional[torch.fx.Node]
kwen2501 (Contributor) commented Mar 25, 2025:

Can you educate me what those four nodes stand for?

danielvegamyhre (Contributor Author) replied:

A_node and B_node are the nodes representing the input tensors to the matmul (i.e., C = torch.mm(A, B) or C = torch._scaled_mm(A, B, A_scale, B_scale)). pre_mm_reshape and post_mm_reshape are the reshape nodes from the "reshape -> mm -> reshape" pattern (before and after the mm, respectively); they are tracked so the fused op can adjust the scatter dim and restore the intended output shape.

"str reduce_op, int scatter_dim, str group_name, "
"str reduce_op, int orig_scatter_dim, int scatter_dim_after_maybe_reshape, str group_name, int[]? output_shape, "
Contributor commented:

Can you comment on the need of this API change?

danielvegamyhre (Contributor Author) replied:

  • The scatter dim is originally assigned based on a 3D+ output tensor shape, since the "A" tensor (left operand in the matmul) has 3+ dims (e.g., (a,b,c) @ (c,d) = (a,b,d)).
  • However, the input tensor "A" (left operand in matmul) is reshaped from 3D+ to 2D to prepare for the mm op, since mm/scaled_mm only accept 2D inputs.
  • After the reshape, the scatter dim may be invalid now (e.g., if we had a 3D tensor with scatter_dim=2, and now we reshaped it to a 2D tensor, the scatter_dim is now out of bounds). This is a problem because the scatter dim is swapped into dim 0 before being sharded along dim 0 and going through the pipelined matmul implementation.
  • Therefore, in order to avoid index out of bounds errors during this step, we need to use an updated scatter dim which is valid after the reshape (scatter_dim_after_maybe_reshape).
  • Finally, the output of the matmul, which is originally a 2D output, is reshaped back to the originally intended 3D+ output shape, then reduce-scattered along the originally intended scatter dim (orig_scatter_dim).

Let me know if that makes sense!
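
To make the two scatter dims concrete, here is a small worked shape example (values are illustrative and the reduce-scatter itself is omitted):

```python
import torch

a, b, c, d = 2, 4, 8, 16
orig_scatter_dim = 1                    # scatter dim defined on the 3D output (a, b, d)

A = torch.randn(a, b, c)
A_2d = A.reshape(-1, c)                 # pre-mm reshape: (a*b, c) = (8, 8)
scatter_dim_after_maybe_reshape = 0     # dims 0 and 1 of A collapsed into dim 0

out_2d = A_2d @ torch.randn(c, d)       # 2D mm: (a*b, d) = (8, 16)
out = out_2d.reshape(a, b, d)           # post-mm reshape back to the intended 3D shape
# The reduce-scatter then runs along orig_scatter_dim = 1, as originally intended.
```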

def _fused_matmul_reduce_scatter_impl(
mm_out_op: torch._ops.OpOverload,
A: torch.Tensor,
B: torch.Tensor,
A_scale: Optional[torch.Tensor],
Contributor commented:
Are we separating the impl into two paths? Looks like it.

danielvegamyhre (Contributor Author) replied:

Yep, that's correct. This simplifies the code for both paths, and also ensures both implementations only require args that they actually use.

final_out_shape = [*output_shape[:-1], B.shape[-1]]  # intended 3D+ output shape, with the last dim replaced by B's output-feature dim
final_out_shape[orig_scatter_dim] //= group.size()  # the reduce-scatter shards this dim across the group, so each rank holds 1/group_size of it
out = reduced_out.view(*final_out_shape)  # reshape the reduced output back to the per-rank 3D+ shape
return out
Contributor commented:

This impl looks mostly new. Do we have a test somewhere?

danielvegamyhre (Contributor Author) replied:

Yes, this test (which I added in a previous PR) tests this specific code path: https://github.com/pytorch/pytorch/pull/149247/files#diff-115f1d0852382c9b58f22640d80999d879b33618e5f6c633fc9e4d0ca9781cecR407

In this PR, I expanded the test to test all possible scatter dims.

danielvegamyhre (Contributor Author) commented:

@pytorchbot merge -f "dr ci confirmed the only failing tests are unrelated to my change and are just flaky"


pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


pytorchmergebot pushed a commit that referenced this pull request Mar 30, 2025
…iple users, and save fused node for backward instead of reduce_scatter node (#149946)

Fixes #149876

## Stack
- [previous PR in stack] #149247

## TL;DR
This PR implements support in async TP for saving the reduce-scatter result for backward, which previously would break the torchtitan AC policies: no AC, per op SAC, and per layer SAC.

## Context
In torchtitan's Llama3 per op SAC policy, we want to save the output of `reduce_scatter` ops for backward, which is useful for TP. The reduce_scatter op is also saved for No AC (since all activations are saved) and per layer SAC (since we save the activations for N full layers, which do contain reduce-scatters for TP).

However, doing this causes incompatibility with Async TP for the AC policies above, for 2 reasons:

1) The graph pattern matching specifically only matches on reduce scatter nodes with 1 user, but reduce_scatter nodes saved for backwards will have 2 users (the 2nd one being the return/output node, which saves it for backward).

2) The subgraph replacement logic which replaces the users of the `wait_tensor` after the reduce-scatter with the new fused node has no mechanism to save the fused_node for backward instead of the reduce-scatter node. This means we cannot directly replace the subgraph, since we can't delete nodes which still have users (in this case, the output node is still using the reduce-scatter node).

To fix this, we do 2 things:

1) Add additional pattern matching logic to also match reduce-scatter nodes with 2 users, so we also perform fusion when the reduce-scatter is saved for backward (see the sketch after this list).

2) When replacing the subgraph with the fused node, detect if the reduce-scatter was saved for backward, and if so, save the result of the fused node for backward instead. This enables us to properly erase the subgraph and prevent the memory leak which occurred in #149876
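
A schematic of the relaxed matching condition described in (1) is shown below; it assumes the functional-collective op name and is illustrative, not the actual inductor pass:

```python
import torch

def _is_fusible_reduce_scatter(node: "torch.fx.Node") -> bool:
    # Only consider functional reduce_scatter_tensor nodes.
    if node.target != torch.ops._c10d_functional.reduce_scatter_tensor.default:
        return False
    users = list(node.users)
    # Original condition: exactly one user (the wait_tensor).
    if len(users) == 1:
        return True
    # Relaxed condition: a second user is allowed when it is the graph output
    # node, i.e. the reduce-scatter result is saved for backward.
    return len(users) == 2 and any(u.op == "output" for u in users)
```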

## Other changes
- Continue to throw an error if we don't find any candidate all-gathers or reduce-scatters for fusion (since TP should have both) but DON'T throw an error if we don't fuse any matmul-reduce-scatters. This is because I've found there are actually valid graphs where we do fuse reduce scatters in the forward graph but not the backward graph (in the backward pass there are reduce-scatters but the producer op is an "add" not a mm/scaled_mm).

## Test plan

1. All unit tests are passing
2. Visualized the graphs and verified the fusion is occurring properly.
3. Verified via manual torchtitan runs there is no memory leak / OOM occurring anymore.

Pull Request resolved: #149946
Approved by: https://github.com/fegin
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
…reduce-scatter (pytorch#149247)
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
…iple users, and save fused node for backward instead of reduce_scatter node (pytorch#149946)
@github-actions github-actions bot deleted the scatter-dim branch May 2, 2025 02:16
pytorchmergebot pushed a commit that referenced this pull request May 15, 2025
…educe_scatter (#153595)

## Summary
- The unit test `pytest test/distributed/test_symmetric_memory.py -k test_fused_scaled_matmul_reduce_scatter_scatter` was not running for some reason when #149247 was merged, giving false green CI signals. When it was run manually recently, the test failed, highlighting a bug causing incorrect numerics when `scatter_dim=1`.
- This PR fixes the bug, which was related to how we swap dims 0 <=> scatter_dim at the beginning of the custom op (for more efficient cross-device data movement, I believe), then swap them back prior to the reduction (a toy illustration of this swap-and-restore pattern appears after the "Next steps" list below).

## Test plan
- I confirmed the unit test `pytest test/distributed/test_symmetric_memory.py -k test_fused_scaled_matmul_reduce_scatter_scatter` is now passing.
- I confirmed e2e training w/ torchtitan looks good ([logs](https://www.internalfb.com/phabricator/paste/view/P1812054188))
- I analyzed the tlparse to verify the fused_all_gather_matmul and fused_scaled_matmul_reduce_scatter both appear at least once in the post grad graphs ([tlparse](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpVbUsdG/dedicated_log_torch_trace_65oh3qj_.log/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000))

## Next steps
1. I think for async TP `fused_scaled_matmul_reduce_scatter` we may only need `scatter_dim_after_maybe_reshape` and not `orig_scatter_dim` after all. I can confirm this and refactor if it is the case.
2. This op is specifically designed for async TP, and many of the arguments don't make sense for a user trying to use this as a standalone op. IMO we should have separate standalone custom op without all the extra function args and internal logic that doesn't apply to non-async TP cases.
3. In a follow up PR I want to add shape annotations to each line (e.g. `# (B, T, H)` etc) to make this easier to debug in the future.
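
For context, a toy illustration of the swap-and-restore data movement described above (purely illustrative; the real custom op also performs the pipelined matmuls and the reduction in between):

```python
import torch

x = torch.randn(4, 6, 8)
scatter_dim = 1

x_front = x.movedim(scatter_dim, 0)       # bring the scatter dim to the front
# ... pipelined per-chunk work shards x_front along dim 0 here ...
x_back = x_front.movedim(0, scatter_dim)  # restore the layout before the reduction

assert torch.equal(x, x_back)
```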

Pull Request resolved: #153595
Approved by: https://github.com/fegin