[async TP] insert reshape node to handle "reshape -> scaled mm -> reshape pattern" in async TP with rowwise scales by danielvegamyhre · Pull Request #148001 · pytorch/pytorch · GitHub

[async TP] insert reshape node to handle "reshape -> scaled mm -> reshape pattern" in async TP with rowwise scales #148001

Closed
wants to merge 1 commit

Conversation

danielvegamyhre
Contributor
@danielvegamyhre danielvegamyhre commented Feb 26, 2025

Fixes pytorch/torchtitan#864

Summary

While testing float8 training with rowwise scaling + async TP in torchtitan, a bug was discovered: the scaling factor dims did not match the dims of the tensor the scales were to be applied to.

My root cause analysis determined the reason is that when async TP graph manipulation constructs the fused_scaled_matmul_reduce_scatter op, it does not yet handle the "reshape -> scaled mm -> reshape" pattern used in torchao here - specifically when row-wise scales are being used.

TL;DR of root cause

  • When a Float8Tensor is reshaped, the scale is reshaped along with it so the dimensions are aligned.
  • In the graph manipulation logic of the micropipeline TP post grad pass, the scaled_mm A tensor node references the tensor from before the reshape op, but the A_scale node from after the reshape op.

Example

  • Concrete example:
    • A tensor is a Float8Tensor with shape (1,8192,2048) and scale of shape (1,8192,1) when a matmul op is called in torchao here. Torchao does a reshape -> scaled mm -> reshape here. When a Float8Tensor is reshaped, its scale is reshaped along with it here. So the first reshape makes the "A tensor" (1,8192,2048) => (8192,2048) and the scale (1,8192,1) => (8192,1).
    • During post grad pass in async TP:
      • A_node has shape (1,8192,2048) (tensor from before this reshape)
      • A_scale has shape (8192,1) (due to reshape op above, which caused the scale to be reshaped from (1,8192,1) => (8192,1)).
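A minimal sketch of the mismatch (toy code using the shapes from the example above; illustrative only, not the pass itself):

```python
# When torchao reshapes the Float8Tensor for the 2D scaled mm, its rowwise
# scale is reshaped along with it, but the async TP pass keeps referencing
# the pre-reshape (3D) A tensor.
import torch

A = torch.randn(1, 8192, 2048)        # A_node used by the post grad pass (3D)
A_scale = torch.rand(1, 8192, 1)      # rowwise scale aligned with the 3D A

A_2d = A.reshape(-1, 2048)            # (8192, 2048), what the 2D scaled mm sees
A_scale_2d = A_scale.reshape(-1, 1)   # (8192, 1), reshaped alongside A

# The pass pairs the 3D A_node with the 2D A_scale_2d node -> dim mismatch.
assert A.shape[:-1] != A_scale_2d.shape[:-1]
```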

Solution

Note: the compiler inserts a reciprocal op after the reshape, so we can't simply use the node before the reshape as the A_scale_node, otherwise it will affect the numerics.

  • Short-term solution: if the specific pattern shown below is detected, insert a reshape node after the reciprocal, to reshape the reciprocal output back to the original shape from before the reshape (a sketch of the graph edit follows this list).
    • reshape is just a view, so there should be no impact on performance
Before:
    reshape (a,b,c) to (a*b,c) -> reciprocal

After:
    reshape (a,b,c) to (a*b,c) -> reciprocal -> reshape (a*b,c) to (a,b,c)
  • Long-term solution: implement a torch._scaled_matmul which can support 3D+ A tensor
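The sketch below (hypothetical helper and names, using torch.fx; not the actual post grad pass code) shows the kind of graph edit described above: once the "reshape -> reciprocal" chain producing A_scale is matched, a reshape back to the pre-reshape scale shape is inserted after the reciprocal and downstream users are repointed to it.

```python
import torch
from torch.fx import Graph, Node

def restore_scale_shape(graph: Graph, reciprocal_node: Node, orig_shape: list[int]) -> Node:
    # Insert a reshape right after the reciprocal so the scale regains its
    # original (3D) shape, matching the 3D A_node used by
    # fused_scaled_matmul_reduce_scatter.
    with graph.inserting_after(reciprocal_node):
        reshaped = graph.call_function(
            torch.ops.aten.reshape.default,
            args=(reciprocal_node, orig_shape),
        )
    # Repoint every downstream user of the reciprocal to the reshaped scale,
    # except the new reshape node itself (to avoid a cycle). Numerics are
    # unchanged: reshape is just a view, only the shape metadata differs.
    reciprocal_node.replace_all_uses_with(
        reshaped, delete_user_cb=lambda user: user is not reshaped
    )
    return reshaped
```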

Test plan

  • Added unit test which exercises this new path
  • Manually tested with torchtitan with float8 rowwise + async TP

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

pytorch-bot bot commented Feb 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148001

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2a841bd with merge base 1a68837:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@danielvegamyhre danielvegamyhre changed the title Insert reshape op to handle reshape -> scaled mm -> reshape pattern in async TP + float8 rowwise Insert reshape op to handle reshape -> scaled mm -> reshape pattern in async TP with float8 rowwise scales Feb 26, 2025
@danielvegamyhre danielvegamyhre changed the title Insert reshape op to handle reshape -> scaled mm -> reshape pattern in async TP with float8 rowwise scales Insert reshape op to handle reshape -> scaled mm -> reshape pattern in async TP + float8 rowwise scales Feb 26, 2025
@pytorch-bot pytorch-bot bot added the oncall: distributed and release notes: distributed (pipeline) labels Feb 26, 2025
@danielvegamyhre danielvegamyhre changed the title Insert reshape op to handle reshape -> scaled mm -> reshape pattern in async TP + float8 rowwise scales Insert reshape node to handle reshape -> scaled mm -> reshape pattern in async TP + float8 rowwise scales Feb 26, 2025
@danielvegamyhre danielvegamyhre changed the title Insert reshape node to handle reshape -> scaled mm -> reshape pattern in async TP + float8 rowwise scales [async TP] insert reshape node to handle reshape -> scaled mm -> reshape pattern in async TP + float8 rowwise scales Feb 26, 2025
@danielvegamyhre danielvegamyhre changed the title [async TP] insert reshape node to handle reshape -> scaled mm -> reshape pattern in async TP + float8 rowwise scales [async TP] insert reshape node to handle "reshape -> scaled mm -> reshape pattern" in async TP + float8 rowwise scales Feb 26, 2025
@danielvegamyhre danielvegamyhre changed the title [async TP] insert reshape node to handle "reshape -> scaled mm -> reshape pattern" in async TP + float8 rowwise scales [async TP] insert reshape node to handle "reshape -> scaled mm -> reshape pattern" in async TP with rowwise scales Feb 26, 2025
@vkuzo
Contributor
vkuzo commented Feb 27, 2025

In the graph manipulation logic of the micropipeline TP post grad pass, the scaled_mm A tensor node references the tensor from before the reshape op, but the A_scale node from after the reshape op.

just curious, why is that? If this is changed to always look before the reshape, would it simplify things?

@danielvegamyhre
Contributor Author
danielvegamyhre commented Feb 27, 2025

In the graph manipulation logic of the micropipeline TP post grad pass, the scaled_mm A tensor node references the tensor from before the reshape op, but the A_scale node from after the reshape op.

just curious, why is that?

It's because the "fused scaled mm reduce scatter" implementation is specifically designed for the original 3D+ "A tensor." For example, the scatter_dim refers to the original tensor shape and is not updated through the reshapes, so using the 2D tensor would break the collectives. So the graph manipulation code specifically references the 3D+ A_node from before the reshape.

For the scale, though, this code was previously designed only to handle scalar tensors for tensorwise scaling (as can be seen in the test code, which only exercises scalar-valued scales), so it just uses the A_scale_node directly without regard for its shape.
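For illustration (toy shapes, nothing from the actual code), the difference between the two scaling granularities:

```python
# Tensorwise scaling uses a scalar scale, so its shape can never disagree with
# anything; rowwise scaling carries one scale per row, so any reshape of the
# data must be mirrored on the scale.
import torch

tensorwise_scale = torch.tensor(0.5)     # shape (), safe to reuse as-is
rowwise_scale = torch.rand(1, 8192, 1)   # shape tied to the data layout
print(tensorwise_scale.shape, rowwise_scale.reshape(-1, 1).shape)
```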

If this is changed to always look before the reshape, would it simplify things?

Unfortunately no, because the compiler inserts a reciprocal op after the reshape, so using the node before the reshape will affect the numerics. All the downstream nodes in the graph depend on this scale having gone through the reciprocal op, so we can't exclude it from the graph.
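A toy numeric illustration of that point (assumed values, not taken from the real graph):

```python
# The node feeding scaled_mm is reciprocal(reshape(scale)); reusing the node
# from before the reshape would skip the reciprocal and feed 2.0 instead of 0.5.
import torch

scale = torch.full((1, 4, 1), 2.0)                           # node before the reshape
scale_seen_by_mm = torch.reciprocal(scale.reshape(4, 1))     # what scaled_mm consumes
print(scale[0, 0, 0].item(), scale_seen_by_mm[0, 0].item())  # 2.0 vs 0.5
```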

If you're interested, here are some details on other solutions I tried which don't work for various reasons:

Other solutions tried:

  1. Use 3D+ tensor node and 3D+ scale node from before reshapes (PR https://github.com/pytorch/pytorch/pull/147794): tried it and it won’t work because the scale node from before the reshape has not gone through the reciprocal inserted by the compiler between the reshape and the scaled_mm, so this affects the numerics. More generally, this method will “cut out” any tensor ops that occur between the reshape and the scaled mm, leading to unexpected behavior.

  2. Use 2D tensor node and 2D scale node from after the reshapes: tried it and it won’t work because the scatter_dim is based on the original 3D+ tensor shape, so scatter_dim will need to be kept in sync with arbitrary reshapes, which would be complicated / not feasible.

  3. Use 2D tensor node and 2D scale node and plumb through the output shape of the "C" tensor to reshape it here in the fused matmul reduce scatter implementation: tried it and it would require a large rewrite of the symmetric memory code, which includes complicated low-level pipelining code that assumes the "A" tensor retains its original 3D+ shape and metadata; it's not as simple as just calling C.view(…).

@yifuwang
Collaborator

Very impressive investigation and nice fix! Thank you!

@danielvegamyhre
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Feb 27, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@danielvegamyhre
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 job has failed: trunk / macos-py3-arm64 / test (default, 2, 3, macos-m1-stable)

Details for Dev Infra team · Raised by workflow job

@danielvegamyhre
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@davidberard98
Contributor

@pytorchbot revert -m 'looks like another lint error' -c ghfirst

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Mar 1, 2025
…m -> reshape pattern" in async TP with rowwise scales (#148001)"

This reverts commit b8efebe.

Reverted #148001 on behalf of https://github.com/davidberard98 due to looks like another lint error ([comment](#148001 (comment)))
@pytorchmergebot
Collaborator

@danielvegamyhre your PR has been successfully reverted.

@danielvegamyhre
Contributor Author
danielvegamyhre commented Mar 1, 2025

@pytorchbot revert -m 'looks like another lint error' -c ghfirst

@davidberard98 any idea why/how the linter is passing locally and in CI, but not when merging? I rebased to include the ruff update

@davidberard98
Contributor

@pytorchbot rebase

@davidberard98
Contributor

Not sure… maybe you can try checking out the commit on main that was failing before the revert?

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased swap onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout swap && git pull --rebase)

@danielvegamyhre
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
…hape pattern" in async TP with rowwise scales (pytorch#148001)

majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
…m -> reshape pattern" in async TP with rowwise scales (pytorch#148001)"

This reverts commit 6e037ac.

Reverted pytorch#148001 on behalf of https://github.com/wdvr due to lint error ([comment](pytorch#148001 (comment)))
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
…hape pattern" in async TP with rowwise scales (pytorch#148001)

majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
…m -> reshape pattern" in async TP with rowwise scales (pytorch#148001)"

This reverts commit b8efebe.

Reverted pytorch#148001 on behalf of https://github.com/davidberard98 due to looks like another lint error ([comment](pytorch#148001 (comment)))
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
…hape pattern" in async TP with rowwise scales (pytorch#148001)

pytorchmergebot pushed a commit that referenced this pull request Mar 27, 2025
…reduce-scatter (#149247)

Part of pytorch/torchtitan#866

## Context
- Async TP needs to support the "reshape -> scaled_mm -> reshape" pattern because scaled mm only supports 2D input tensors and 2D scales (see the sketch after this list).
    - (a,b,c) => (a*b,c)
    - (a\*b,c) @ (c,d) = (a\*b,d)
    - (a\*b,d) => (a,b,d)

- Currently the implementation does not support scaled mm with rowwise scales **for all cases** of the reshape -> scaled_mm -> reshape pattern. The minimal example of this pattern is confirmed to work via this [unit test](https://github.com/pytorch/pytorch/blob/00a2c68f67adbd38847845016fd1ab9275cefbab/test/distributed/tensor/parallel/test_micro_pipeline_tp.py#L406), but more involved e2e examples in torchtitan fail silently (more context in final bullet point).
- Previously, the "A tensor" **node** referenced in the async TP graph manipulation code was the 3D+ node from before the reshape, but the "A_scale" node was the 2D node from after the reshape, so they were incompatible.
- I previously implemented a simpler solution to this problem in #148001, with a [unit test](https://github.com/pytorch/pytorch/pull/148001/files#diff-115f1d0852382c9b58f22640d80999d879b33618e5f6c633fc9e4d0ca9781cecR406) confirming the fused node is indeed in the graph for the minimal example of the reshape->mm->reshape pattern. I also confirmed via manual e2e testing w/ torchtitan that the crash I was fixing no longer occurred. However, it turns out that due to this [bug in torchtitan](pytorch/torchtitan#866), async TP was failing silently and falling back to vanilla TP, hiding the fact that the original solution fixed the crash but the fusion would not occur for rowwise scales. Thus, a more robust solution is needed to support all cases.
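A conceptual sketch of the pattern (plain matmul stands in for the 2D-only scaled mm; shapes are illustrative):

```python
import torch

a = torch.randn(2, 8, 16)             # (a, b, c)
w = torch.randn(16, 32)               # (c, d)
a_scale = torch.rand(2, 8, 1)         # rowwise scale for `a`

a_2d = a.reshape(-1, 16)              # (a, b, c) => (a*b, c)
a_scale_2d = a_scale.reshape(-1, 1)   # scale reshaped alongside `a`
out_2d = (a_2d * a_scale_2d) @ w      # (a*b, c) @ (c, d) = (a*b, d)
out = out_2d.reshape(2, 8, 32)        # (a*b, d) => (a, b, d)
```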

## Solution TL;DR
- Use the 2D 'A' tensor and corresponding 2D scales as input to the fused_matmul_reduce_scatter implementation, instead of the 3D+ tensor/scales.
- Track the "pre mm reshape" and "post mm reshape" separately, to be referenced in the `fused_scaled_matmul_reduce_scatter` implementation, to update the scatter dim through the pre-mm reshape, and apply the post-mm reshape before applying the reduce scatter and returning the output tensor.
- Separate the `fused_matmul_reduce_scatter` and the `fused_scaled_matmul_reduce_scatter` code paths, to simplify them both.
- Together, the torchtitan bug fix (PR pytorch/torchtitan#965) and the rowwise-scale support implemented in this PR solve the problem of supporting rowwise scales with all types of AC.

## Additional details for reviewers
To use the 2D A tensor while also supporting the "reshape -> mm -> reshape" pattern, the following other changes were needed:
- Track the pre-mm reshape, as it will affect the scatter dim used in the fused_matmul_reduce_scatter implementation.
- Track the post-mm reshape, as it will affect the output shape used in the fused_matmul_reduce_scatter implementation.
- Based on the pre-mm reshape and the original scatter dim, calculate the new scatter dim for the 2D tensor (see the sketch after this list). This is needed because during the pipelined producer mm implementation, the scatter dim is moved to dim 0 (so it can be sharded along the first dim and then get chunks to do mm ops on by indexing into the first dim), then moved back to its original place before the reduce-scatter.
- Use the tracked post-mm reshape to reshape the stacked partial 2D outputs of the mm ops into 3D outputs needed for 1) the reduce-scatter w/ the original scatter dim, and 2) the expected output shape to prevent shape errors with subsequent ops.
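A hypothetical helper (illustrative only, not the actual implementation) for the scatter-dim remapping mentioned above, assuming the pre-mm reshape collapses all leading dims into dim 0:

```python
def remap_scatter_dim(orig_ndim: int, scatter_dim: int) -> int:
    # Leading dims of the 3D+ tensor fold into dim 0 of the 2D view;
    # the last dim is preserved as dim 1.
    return 0 if scatter_dim < orig_ndim - 1 else 1

assert remap_scatter_dim(orig_ndim=3, scatter_dim=1) == 0   # (a,b,c): dim 1 -> dim 0
assert remap_scatter_dim(orig_ndim=3, scatter_dim=2) == 1   # last dim stays last
```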

## Test plan
- All existing unit tests passing.
- Expanded unit tests for rowwise scales to test more scatter dims
- Added unit tests enforcing that async TP fails fast / throws an error if it fails to perform any fusions. Previously it just "failed silently" (fell back to vanilla TP without the user knowing) which has led to confusion, so this will improve the UX.
- Compared loss curves of bf16 vs float8 w/ rowwise scales to confirm integrity of numerics
- Confirmed via manual testing with torchtitan and inspecting the compile graph that the fusion is working as intended for:
    - bfloat16
    - float8 with tensorwise scales
    - float8 with rowwise scales

## Loss curves

Loss curves are virtually identical for bf16 + vanilla TP versus float8 with rowwise scales + async TP:

<img width="1017" alt="loss_async_tp" src="https://github.com/user-attachments/assets/4995db78-7012-490f-a370-f4fecc289a22" />

## Performance

#### Per op SAC
Performance benchmarks for torchtitan Llama3 8b training runs on 4 H100s with per op SAC, using FSDP degree=2, TP degree=2:
- bf16 (vanilla TP): 5161.5 TPS, peak memory 50.53 GB
- bf16 (async TP): 5229.5 TPS, peak memory 50.68 GB
- float8 tensorwise (vanilla TP): 5959.5 TPS, peak memory 50.47 GB
- float8 tensorwise (async TP): 5964.5 TPS, peak memory 50.47 GB
- float8 rowwise (vanilla TP): 4962.0 TPS, peak memory 50.55 GB
- float8 rowwise (async TP): 4966.5 TPS, peak memory 50.65 GB

#### Full AC
Llama3 70b training runs on 128 H100s with full AC, using FSDP=16, TP=8:
- bf16 (vanilla TP): 598 TPS, peak memory 71.51 GB
- bf16 (async TP): 673 TPS, peak memory 71.08 GB (+12.54% TPS vs vanilla TP)
- float8 tensorwise (vanilla TP): 820 TPS, peak memory 55.26 GB
- float8 tensorwise (async TP): 950 TPS, peak memory 55.91 GB (+15.85% TPS vs vanilla TP)
- float8 rowwise (vanilla TP): 540 TPS, peak memory 71.46 GB
- float8 rowwise (async TP): 560 TPS, peak memory 70.65 GB (+3.7% TPS vs vanilla TP, but still unexpectedly lower than bf16)

As you can see, float8 rowwise is working but performance needs to be improved further.

## Other changes
- Added logging so the user will know why fusion failed if it does.
- Removed logic which inserted a reshape node targeting "A scale" to get it to be in 3D like the "A tensor", since it's no longer needed.

## Long term plan
- Add a `scaled_matmul` op in pytorch, which will natively support a 3D+ "A tensor" and allow us to simplify the async TP implementation by avoiding the reshape -> scaled_mm -> reshape pattern and the special handling for it.

## Visualizing fused nodes in graphs for torchtitan training runs

Below are examples of the visualized graph generated by torch compile for torchtitan llama3 8b training runs with per op SAC. These graphs provide additional evidence (beyond the new unit tests added) that the implementation is working correctly.

### bf16

<img width="900" alt="bf16-fusion" src="https://github.com/user-attachments/assets/a3bed917-28eb-4a56-8d6e-2d2bf498385c" />

### float8 with tensorwise scales

<img width="900" alt="tensorwise-node" src="https://github.com/user-attachments/assets/b212ec4a-1899-44de-a4de-18c74e1de68a" />

### float8 with rowwise scales

<img width="900" alt="rowwise" src="https://github.com/user-attachments/assets/ed3354a3-894b-4ec9-86d0-f80364bf3d83" />

Pull Request resolved: #149247
Approved by: https://github.com/kwen2501
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
…reduce-scatter (pytorch#149247)

Labels
ci-no-td, ciflow/inductor, ciflow/trunk, Merged, module: inductor, oncall: distributed, release notes: distributed (pipeline), Reverted
Development

Successfully merging this pull request may close these issues.

[Float8] Unable to run asyncTP + Float8 row with 'full' AC active, leading dims mismatch
7 participants