🐛 Describe the bug
Summary
I recently implemented async TP support for fusing scaled-matmul reduce-scatter patterns with rowwise scales (#149247), as well as restoring support for various AC settings that had become broken (no AC, per-layer SAC, per-op SAC with reduce_scatter saved) (#149946).
When testing the performance of various configurations to ensure the stability of the changes, I found that while float8 rowwise training with async TP had correct numerical accuracy and some performance gains over the vanilla TP baseline, the gains were smaller than those for bfloat16 and float8 tensorwise.
After looking at the traces, I found that the matmul-reduce-scatter patterns were being fused properly, so my change was working as intended; however, the all-gather-matmul patterns were NOT being fused. This appears to be the root cause.
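For context, async TP is enabled roughly as sketched below (a minimal sketch, assuming the private `torch._inductor.config._micro_pipeline_tp` flag and the `torch.distributed._symmetric_memory.enable_symm_mem_for_group` helper; exact names may differ at HEAD). When the pass works, the fused patterns show up in profiler traces as `symm_mem` ops, which is how I distinguished the fused matmul-reduce-scatter from the unfused all-gather-matmul:

```python
# Minimal sketch (not the exact production code) of enabling async TP on the TP group.
# Assumes the private inductor flag `_micro_pipeline_tp` and the symmetric-memory
# helper `enable_symm_mem_for_group`; exact names/locations may differ at HEAD.
import torch
from torch.distributed._symmetric_memory import enable_symm_mem_for_group


def enable_async_tp(tp_group: torch.distributed.ProcessGroup) -> None:
    # Turn on the inductor micro-pipeline TP pass, which rewrites
    # all-gather -> matmul and matmul -> reduce-scatter patterns into fused ops.
    torch._inductor.config._micro_pipeline_tp = True
    # The fused ops require symmetric memory to be enabled for the TP group.
    enable_symm_mem_for_group(tp_group.group_name)


# In a profiler trace of the compiled model, the fused patterns appear as
# symm_mem ops (op names assumed), e.g.:
#   torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter  # fused correctly in my traces
#   torch.ops.symm_mem.fused_all_gather_scaled_matmul      # missing for rowwise scales
```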
Looking at the benchmarks below, we ALSO see that the vanilla TP baseline perf for rowwise scales is unexpectedly low. I will create a separate issue for that, though.
Performance Benchmarks
Llama3 70B training runs on 128 H100s with full AC, using FSDP=16, TP=8:
- bf16 (vanilla TP): 598 TPS, peak memory 71.51 GB
- bf16 (async TP): 673 TPS, peak memory 71.08 GB (+12.54% TPS vs vanilla TP)
- float8 tensorwise (vanilla TP): 820 TPS, peak memory 55.26 GB
- float8 tensorwise (async TP): 950 TPS, peak memory 55.91 GB (+15.85% TPS vs vanilla TP)
- float8 rowwise (vanilla TP): 540 TPS, peak memory 71.46 GB
- float8 rowwise (async TP): 560 TPS, peak memory 70.65 GB (+3.7% TPS vs vanilla TP, but still unexpectedly lower than bf16)
As you can see, float8 rowwise with async TP is working, but its performance needs further improvement.
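For reference, the percentage gains quoted above are the relative TPS increase of async TP over the vanilla TP baseline for the same recipe; a quick check of the arithmetic:

```python
# Relative TPS gain of async TP over the vanilla TP baseline, per recipe.
def tps_gain_pct(vanilla_tps: float, async_tps: float) -> float:
    return (async_tps - vanilla_tps) / vanilla_tps * 100

print(f"bf16:              {tps_gain_pct(598, 673):.2f}%")  # ~12.54%
print(f"float8 tensorwise: {tps_gain_pct(820, 950):.2f}%")  # ~15.85%
print(f"float8 rowwise:    {tps_gain_pct(540, 560):.2f}%")  # ~3.70%
```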
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @vkuzo @lessw2020
Versions
pytorch @ HEAD