[Async TP] all-gather-matmuls not fusing properly when rowwise scales are used #149990
Comments
I looked into this and figured out why all-gather-matmuls are not fused when using rowwise scales, but are fused correctly with tensorwise scales. The reason is the following:
How we can fix this:
Important distinction: this does not explain why the baseline/vanilla TP perf is bad for float8 rowwise, but it does explain why the delta between baseline vanilla TP and async TP is smaller for rowwise than for tensorwise or bf16. cc @tianyu-l @fegin @lessw2020
cc @drisspg as well if you have any thoughts on the potential solution outlined here: #149990 (comment)
One clarification question: if the PerTensor case communicated in bf16 and then converted, would this also cause the pattern not to match?
Yes, if the conversion to float8 happens after the all-gather, the pattern will not match.
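To make the distinction concrete, here is a rough single-process sketch (torch.cat stands in for the all-gather and both recipes are heavily simplified; this is not the actual traced graph) of why the op ordering differs between the two recipes:

```python
import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max

def fake_all_gather(shard: torch.Tensor, world_size: int = 2) -> torch.Tensor:
    # Single-process stand-in for all_gather_into_tensor (stacks along dim 0).
    return torch.cat([shard] * world_size, dim=0)

def tensorwise_order(x_shard: torch.Tensor):
    # Tensorwise recipe: the local shard is cast to fp8 *before* the collective,
    # so the graph looks like cast -> all_gather(fp8) -> scaled_mm, with the
    # all_gather feeding the matmul directly, which the async TP matcher can fuse.
    scale = FP8_MAX / x_shard.abs().amax()                  # one scale per tensor
    x_fp8_shard = (x_shard * scale).to(FP8)
    x_fp8 = fake_all_gather(x_fp8_shard)                    # fp8 bytes on the wire
    return x_fp8, scale                                     # feeds scaled_mm next

def rowwise_order(x_shard: torch.Tensor):
    # Rowwise recipe (as confirmed above): the conversion to fp8 happens *after*
    # the collective, so the graph looks like all_gather(hp) -> cast -> scaled_mm,
    # and the matcher no longer sees all_gather feeding the matmul, so no fusion.
    x = fake_all_gather(x_shard)                            # high-precision bytes on the wire
    scale = FP8_MAX / x.abs().amax(dim=-1, keepdim=True)    # one scale per row
    x_fp8 = (x * scale).to(FP8)
    return x_fp8, scale

x = torch.randn(4, 8)
tensorwise_order(x)
rowwise_order(x)
```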
I think an estimate of eng time and expected benefit would help decide if this is worth doing! Thoughts on what those might be?
Update: I prototyped a solution and discussed ideas with @fegin, but there are some issues. Supporting fp8 all-gather in TP for rowwise scales seems like it will be pretty ugly implementation-wise, due to:
Update: I have a prototype fp8 rowwise all-gather working for the llama3 debug model with 1 layer, but there is still a bug for larger models.
Update: I managed to get fp8 rowwise all-gather working for llama3 8b. Perf is flat compared to using bf16 all-gather, which is expected, since we are transferring 2 fp8 copies of the input (the same number of bytes as a single bf16 copy).
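For reference, the flat perf is consistent with simple payload math: two fp8 copies move the same number of bytes as one bf16 copy (ignoring the comparatively tiny scale tensors). A sketch with made-up sizes:

```python
# Illustrative byte counts for the all-gather payload (sizes are made up).
numel = 8 * 4096 * 4096                   # example activation element count
bf16_bytes = numel * 2                    # one bf16 copy, 2 bytes/element
fp8_bytes_two_copies = numel * 1 * 2      # row-major fp8 + col-major fp8, 1 byte/element each
assert bf16_bytes == fp8_bytes_two_copies # same bytes on the wire -> roughly flat perf
```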
I've found that when using Float8RowwiseParallel and Float8ColumnwiseParallel, there are actually no all-gathers happening in float8; it must be some other kind of collective in *.redistribute(), because we have an assertion that scales are tensorwise in the all_gather, wait_tensor, cat, and split op overrides, and that assertion is not triggered. Furthermore, looking at the graph I found that fp8 quant ops are still occurring after the all-gathers, so implementing Float8RowwiseParallel and Float8ColwiseParallel is not sufficient: we must implement PrepareFloat8ModuleInput as well, and add support for rowwise scales in the primitive op overrides (all_gather, wait_tensor, cat, split). I implemented this, following the same strategy of writing 2 fp8 outputs (row-major with rowwise scales and col-major with colwise scales), but this means Attention.forward() has 2 inputs to deal with, which I don't think we want...
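For concreteness, a minimal sketch of the two fp8 copies described above (a hypothetical helper, not the torchao API; the quantize-by-multiply convention and 2-D shapes are assumptions of the sketch):

```python
import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max

def to_fp8_two_copies(x: torch.Tensor):
    """Hypothetical helper: build the two fp8 copies of a high-precision 2-D input
    described above, i.e. row-major data with one scale per logical row, plus
    col-major data with one scale per logical column."""
    row_scale = FP8_MAX / x.abs().amax(dim=1, keepdim=True)   # shape (M, 1)
    x_fp8_row = (x * row_scale).to(FP8)                       # row-major storage

    col_scale = FP8_MAX / x.abs().amax(dim=0, keepdim=True)   # shape (1, N)
    x_fp8_col = (x * col_scale).to(FP8).t().contiguous().t()  # col-major storage
    return (x_fp8_row, row_scale), (x_fp8_col, col_scale)

(x_row, s_row), (x_col, s_col) = to_fp8_two_copies(torch.randn(128, 256))
```

The pain point described in the comment above is then that both copies (and their scales) have to flow through Attention.forward() and into the Float8Linear projections.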
Another update: I've found that when Attention.forward() has an optional 2nd input (col-major) and passes that through to the projections, which have been converted to Float8Linears that now accept this optional 2nd arg (e.g.
Latest issue: I have a PrepareModuleInput subclass for float8 rowwise, which has an "input_layout" of Shard(dim=1) and a "desired_layout" of Replicate(). So I expect an all-gather along dim=1. However, it all-gathers along dim=0. I set a breakpoint and confirmed this in the op override:

```python
@implements(
    [
        c10d_functional.all_gather_into_tensor.default,
        _c10d_functional.all_gather_into_tensor.default,
    ]
)
def allgather_fp8(aten_op, args, kwargs=None):
    """
    override funcol with FP8 handling
    """
    # _assert_tensorwise_scale(aten_op, args[0]._scale)
    torch.distributed.breakpoint()
    fp8_input = args[0]  # (8, 1024, 256)
    assert isinstance(fp8_input, Float8Tensor), (
        f"expecting a Float8Tensor for allgather but found {type(fp8_input)}"
    )
    fp8_data = fp8_input._data
    fp8_data = fp8_data.contiguous()
    # fp8_out becomes (16, 1024, 256) <- gathered along dim 0 even though gather_dim is 1?
    # input_layout = Shard(dim=1), desired_layout = Replicate()
    fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
    return Float8Tensor(
        fp8_out,
        fp8_input._scale,
        fp8_input._orig_dtype,
        fp8_input._linear_mm_config,
        fp8_input._gemm_input_role,
    )
```
Update: I have the full e2e PoC working! However, the loss is spiky / not identical to the bf16 run, so I'm looking into that.
Update: figured out the issue. input-col-major should have 0 grads, since it is only created for the rowwise all-gather and is not meaningful for the gradient flow, so I did some tricks to pass in a 0-filled DTensor, and now the loss curve looks good.
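A minimal sketch of the kind of trick described above (hypothetical, not the actual implementation): forward the col-major copy unchanged but hand back an all-zero gradient for it, so it contributes nothing to the gradient flow.

```python
import torch

class ZeroGradPassthrough(torch.autograd.Function):
    # Hypothetical: the extra col-major input only exists to enable the rowwise
    # fp8 all-gather, so its gradient is defined to be all zeros.
    @staticmethod
    def forward(ctx, x_col_major):
        return x_col_major.view_as(x_col_major)

    @staticmethod
    def backward(ctx, grad_out):
        return torch.zeros_like(grad_out)

x_col = torch.randn(8, 16, requires_grad=True)
ZeroGradPassthrough.apply(x_col).sum().backward()
assert torch.count_nonzero(x_col.grad) == 0   # the col-major input gets zero grads
```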
Another update: I started looking into why async TP is crashing w/ fp8 rowwise all-gather, and found it's due to a similar issue to #149247 where the "A" tensor used is from before a reshape, and the "A_scale" node is from after a reshape, so there's a mismatch.
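Roughly, the shape of the problem (illustrative only; this is not the actual pattern-matcher code or the exact graph): the node picked up for "A" and the node picked up for "A_scale" sit on opposite sides of a reshape, so they don't line up.

```python
import torch

# Illustrative shapes: a 3-D activation that gets flattened to 2-D for the GEMM.
x_3d = torch.randn(8, 1024, 256)                    # "A" as seen before the reshape
x_2d = x_3d.reshape(-1, 256)                        # the 2-D tensor the GEMM consumes
a_scale_2d = x_2d.abs().amax(dim=1, keepdim=True)   # "A_scale" produced after the reshape
# Pairing the pre-reshape "A" (8, 1024, 256) with the post-reshape "A_scale"
# (8192, 1) is the kind of mismatch that trips up the fusion pass.
```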
Another update: strangely, the loss curve looked good for the first 100 steps, yet I think the scales were actually wrong - I was repeating the scale for shard0 across both shards (TP degree=2) and somehow it was working. After fixing this and using the local shards for each tensor, it works in eager mode, but with compile I get NaNs...
Update (forgot to post this last week): I figured out the reason we get NaNs after some training steps, and it revealed a fundamental incompatibility between fp8 with rowwise scales in torchao and async TP. I stepped through the various collective ops with a debugger and discovered an issue with how all-gather works under the hood. This is the relevant code. To illustrate the problem I created this diagram: basically, no matter what the sharding/gather dim is defined as (e.g. Shard(dim=1)), the all-gather under the hood always gathers the shards along dim 0.
This is problematic for the case where we need to convert the input activation tensor to float8 in row-major format with scales along logical rows. As you can see in the diagram, we cannot all-gather fp8 input tensor shards when the scale dim is the same as the all-gather dim (i.e., the case of the input shard in col-major format with scales along logical columns); the torchao Float8Tensor abstraction cannot represent this. The reason we were getting NaNs is that I was concatenating the float8 shards to set Float8Tensor._data, but setting Float8Tensor._scale to the max of the per-shard scales, which is incorrect: the original scales have already been applied to the fp8 data, so when we dequantize using the wrong scales we can get NaNs (see the numeric sketch below). IMO we need to develop a solution to this that will also be compatible with MX formats + async TP, as well as being future-proof against future quantization strategies and dtypes.
Possible solutions
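To make the scale/concat mismatch described above concrete, here is a tiny single-process numeric sketch (shapes and values are made up; a quantize-by-multiply / dequantize-by-divide convention is assumed, with powers of two chosen so the fp8 values are exact):

```python
import torch

FP8 = torch.float8_e4m3fn

# Two shards with very different magnitudes, each quantized with its own scale.
shard_a, scale_a = torch.full((2, 4), 2.0),   64.0    # 2.0 * 64   = 128 in fp8
shard_b, scale_b = torch.full((2, 4), 512.0), 0.25    # 512 * 0.25 = 128 in fp8

fp8_a = (shard_a * scale_a).to(FP8)
fp8_b = (shard_b * scale_b).to(FP8)

# What the buggy path did: concatenate the already-scaled fp8 shards into one
# Float8Tensor._data, but keep a single "merged" scale (max of the two).
gathered_data = torch.cat([fp8_a, fp8_b], dim=0)
merged_scale = max(scale_a, scale_b)                  # 64, wrong for shard_b

dequant = gathered_data.to(torch.float32) / merged_scale
print(dequant[:2])   # rows from shard_a: 2.0, correct (its scale equals the merged scale)
print(dequant[2:])   # rows from shard_b: 2.0, badly wrong (should be 512.0)
# Errors like this compound through later amax/scale computations during
# training, which is how they can eventually surface as inf/NaN.
```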
@vkuzo @drisspg any thoughts on option 3 in the above comment (#149990 (comment))? I'd like to align internally with the team first before proposing it to the PTD team soon for H2 planning.
Thanks for the findings! Let's chat offline about this and summarize here.
🐛 Describe the bug
Summary
I recently implemented async TP support for fusing scaled-matmul-reduce-scatter patterns with rowwise scales (#149247), as well as support for various AC settings which had become broken (no AC, per-layer SAC, per-op SAC with reduce_scatter saved) (#149946).
When testing the performance of various configurations to ensure stability of the changes, I found that while float8 rowwise training with async TP had correct numerical accuracy, the performance was non-optimal (see benchmarks below).
After looking at the traces, I found the matmul-reduce-scatters were being fused properly, so my change was working as intended - however, the all-gather-matmul patterns were NOT being fused properly. This seems to (at least in part) explain the poor performance for async TP with rowwise scales.
Looking at the benchmarks below, we ALSO see that vanilla TP perf with rowwise scales is unexpectedly low; I will create a separate issue for this, though.
Performance Benchmarks
Llama3 70b training runs on 128 H100s with full AC, using FSDP=16, TP=8
As you can see, float8 rowwise is working but performance needs to be improved further.
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @vkuzo @lessw2020
Versions
pytorch @ HEAD