You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[async TP] insert reshape node to handle "reshape -> scaled mm -> reshape pattern" in async TP with rowwise scales (#148001)
Fixespytorch/torchtitan#864
## Summary
While testing torchtitan with float8 training with rowwise scaling + async TP, a [bug](pytorch/torchtitan#864) was discovered. The symptom was the scaling factor dims did not match the dims of the tensor the scales were to be applied to.
My [root cause analysis](pytorch/torchtitan#864 (comment)) determined the reason is that when async TP graph manipulation constructs the `fused_scaled_matmul_reduce_scatter` op, it does not yet handle the "reshape -> scaled mm -> reshape" pattern used in torchao [here](https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122-L124) - specifically when row-wise scales are being used.
## TL;DR of root cause
- When a Float8Tensor is reshaped, the scale is reshaped along with it so the dimensions are aligned.
- In the graph manipulation logic of the micropipeline TP post grad pass, the scaled_mm `A tensor` node is referencing the tensor _before_ to the reshape op, but referencing the `A_scale` node _after_ the reshape op.
## Example
- Concrete example:
- `A tensor` is a Float8Tensor with shape (1,8192,2048) and scale of shape (1,8192,1) when a matmul op is called in torchao [here](https://github.com/pytorch/ao/blob/8706d3f3b087b876d625c720e98236c265c0ba98/torchao/float8/float8_linear.py#L70). Torchao does a reshape -> scaled mm -> reshape [here](https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122). When a Float8Tensor is reshaped, its scale is reshaped along with it [here](https://github.com/pytorch/ao/blob/8706d3f3b087b876d625c720e98236c265c0ba98/torchao/float8/float8_ops.py#L152). So the first reshape makes the "A tensor" (1,8192,2048) => (8192,2048) and the scale (1,8192,1) => (8192,1).
- During post grad pass in async TP:
- `A_node` has shape (1,8192,2048) (tensor from before this [reshape](https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122))
- `A_scale` has shape (8192,1) (due to reshape op above, which caused the scale to be reshaped from (1,8192,1) => (8192,1)).
## Solution
**Note:** the compiler inserts a `reciprocal` op after the reshape, so we can't simply use the node before the reshape as the `A_scale_node`, otherwise it will affect the numerics.
- Short-term solution: if the specific pattern showne below is detected, insert a reshape node after the reciprocal, to reshape the reciprocal output back to the originals shape before the reshape.
- reshape is just a view, so there should be no impact on performance
```
Before:
reshape (a,bc,) to (a*b,c) -> reciprocal
After:
reshape (a,bc,) to (a*b,c) -> reciprocal -> reshape (a*b,c) to (a,b,c)
```
- Long-term solution: implement a `torch._scaled_matmul` which can support 3D+ `A tensor`
## Test plan
- Added unit test which exercises this new path
- Manually tested with torchtitan with float8 rowwise + async TP
Pull Request resolved: #148001
Approved by: https://github.com/yifuwang
0 commit comments