[Async TP] all-gather-matmuls not fusing properly when rowwise scales are used · Issue #149990 · pytorch/pytorch · GitHub


[Async TP] all-gather-matmuls not fusing properly when rowwise scales are used #149990


Open
danielvegamyhre opened this issue Mar 25, 2025 · 19 comments
Assignees
Labels
oncall: distributed Add this issue/PR to distributed oncall triage queue triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@danielvegamyhre
Contributor
danielvegamyhre commented Mar 25, 2025

🐛 Describe the bug

Summary

I recently implemented async TP support for fusing scaled-matmul-reduce-scatter patterns with rowwise scales (#149247), as well as support for various AC settings that had become broken: no AC, per-layer SAC, and per-op SAC with reduce_scatter saved (#149946).

When testing the performance of various configurations to ensure stability of the changes, I found that while float8 rowwise training with async TP had correct numerical accuracy, the performance was non-optimal (see benchmarks below).

After looking at the traces, I found that the matmul-reduce-scatter patterns were being fused properly, so my change was working as intended; however, the all-gather-matmul patterns were NOT being fused properly. This seems to (at least in part) explain the poor performance of async TP with rowwise scales.

Looking at the benchmarks below, we ALSO see that vanilla TP perf with rowwise scales is unexpectedly low. I will create a separate issue for that, though.

Performance Benchmarks

Llama3 70b training runs on 128 H100s with full AC, using FSDP=16, TP=8

  • bf16 (vanilla TP): 598 TPS, peak memory 71.51 GB
  • bf16 (async TP): 673 TPS, peak memory 71.08 GB (+12.54% TPS vs vanilla TP)
  • float8 tensorwise (vanilla TP): 820 TPS, peak memory 55.26 GB
  • float8 tensorwise (async TP): 950 TPS, peak memory 55.91 GB (+15.85% TPS vs vanilla TP)
  • float8 rowwise (vanilla TP): 540 TPS, peak memory 71.46 GB
  • float8 rowwise (async TP): 560 TPS, peak memory 70.65 GB (+3.7% TPS vs vanilla TP, but still unexpectedly lower than bf16)

As you can see, float8 rowwise is working but performance needs to be improved further.

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @vkuzo @lessw2020

Versions

pytorch @ HEAD

@danielvegamyhre danielvegamyhre self-assigned this Mar 25, 2025
@danielvegamyhre
Contributor Author

I'm working on this, but cc-ing distributed folks to ensure they have visibility on this as well: @tianyu-l @kwen2501

@malfet malfet added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Mar 26, 2025
@danielvegamyhre danielvegamyhre changed the title [Async TP] all-gather-matmuls not fusing properly when rowwise scales are used Poor performance using dynamic float8 quantization with rowwise scales for training with both TP and async TP Mar 27, 2025
@danielvegamyhre danielvegamyhre changed the title Poor performance using dynamic float8 quantization with rowwise scales for training with both TP and async TP [Async TP] all-gather-matmuls not fusing properly when rowwise scales are used Mar 27, 2025
@danielvegamyhre
Contributor Author
danielvegamyhre commented Mar 29, 2025

I looked into this and figured out why all-gather-matmuls are not fused when using rowwise scales, but are fused correctly with tensorwise scales. The reason is the following:

  • When using float8 with scaled_mm, the all-gather-matmul pattern matcher should match subgraphs of the "all-gather -> reshape -> scaled mm -> reshape" pattern.

  • For TP with tensorwise scaling, the TP all-gathers happen in float8, so after the all-gather the graph goes directly into the reshape -> scaled mm -> reshape pattern, without any conversion to float8 first. This matches the graph pattern. (See diagram 1 below.)

  • However, for TP with rowwise scaling, the TP all-gathers happen in high precision, so float8 conversion ops are needed between the all-gather and the reshape -> scaled mm -> reshape pattern. This does not match the graph pattern. (See diagram 2 below; a rough sketch of the two op sequences follows this list.)
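To make the mismatch concrete, here is a rough sketch of the two traced op sequences. This is my own illustration, not the actual inductor pattern or torchao code: to_fp8_rowwise is a simplified stand-in for torchao's dynamic rowwise cast, w_fp8_t / w_scale / w_scale_t are assumed to already be a column-major fp8 weight with matching scales, and in practice the fp8 all-gather in the tensorwise path goes through torchao's Float8Tensor op overrides.

import torch
import torch.distributed._functional_collectives as funcol

def to_fp8_rowwise(x):
    # Simplified stand-in for torchao's dynamic rowwise cast.
    scale = x.abs().amax(dim=-1, keepdim=True).float() / torch.finfo(torch.float8_e4m3fn).max
    scale = scale.clamp(min=1e-12)  # guard against all-zero rows
    return (x / scale).to(torch.float8_e4m3fn), scale

def tp_matmul_tensorwise(x_fp8, x_scale, w_fp8_t, w_scale, group):
    # Tensorwise: the all-gather is already in float8, so the traced graph is exactly
    # "all-gather -> reshape -> scaled_mm -> reshape" and async TP can fuse it.
    ag = funcol.all_gather_tensor(x_fp8, gather_dim=1, group=group)
    B, S, K = ag.shape
    out = torch._scaled_mm(ag.reshape(-1, K), w_fp8_t, scale_a=x_scale, scale_b=w_scale,
                           out_dtype=torch.bfloat16)
    return out.reshape(B, S, -1)

def tp_matmul_rowwise(x_bf16, w_fp8_t, w_scale_t, group):
    # Rowwise: the all-gather happens in bf16, so quantization ops sit between the
    # collective and scaled_mm, and the fixed pattern above no longer matches.
    ag = funcol.all_gather_tensor(x_bf16, gather_dim=1, group=group)
    B, S, K = ag.shape
    x_fp8, x_scale = to_fp8_rowwise(ag.reshape(-1, K))  # <-- extra ops break the match
    out = torch._scaled_mm(x_fp8, w_fp8_t, scale_a=x_scale, scale_b=w_scale_t,
                           out_dtype=torch.bfloat16)
    return out.reshape(B, S, -1)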

How we can fix this:

  • We can't simply update the pattern matching to include the float8 conversion ops, since that would require implementing float8 conversion in the underlying symmetric memory based "fused all-gather-scaled_mm" implementation in core, which is not feasible / not a good idea IMO.

  • The float8 rowwise/columnwise sharding primitives here currently assume tensorwise scales. However, one option we could look into is supporting rowwise scaling for these. My understanding is that there's no benefit to doing the all-gather in float8 with rowwise scaling, since the overhead (float8 conversion, plus M scales per tensor instead of one, so many more bytes sent over the network) >= just sending the data in bf16.

  • So, while all-gather in float8 for rowwise scales may not provide a direct performance benefit, it will provide an indirect performance benefit by enabling async TP to fuse all-gather-matmuls as well, instead of only fusing matmul-reduce-scatters. cc @vkuzo for thoughts on this

Important distinction: this does not explain why the baseline/vanilla TP perf is bad for float8 rowwise, but it does explain why the delta between baseline vanilla TP and async TP is smaller for rowwise than for tensorwise or bf16.

cc @tianyu-l @fegin @lessw2020

Tensorwise

[Diagram 1: tensorwise scaling, where the all-gather happens in float8 and the all-gather -> reshape -> scaled mm -> reshape pattern matches]

Rowwise

[Diagram 2: rowwise scaling, where float8 conversion ops sit between the all-gather and the reshape -> scaled mm -> reshape pattern, breaking the match]

@danielvegamyhre
Contributor Author

cc @drisspg as well if you have any thoughts on the potential solution outlined here: #149990 (comment)

@drisspg
Contributor
drisspg commented Mar 31, 2025

One clarification question: if the PerTensor case were communicated in bf16 and then converted, would that also prevent the pattern from matching?

@danielvegamyhre
Contributor Author

One clarification question: if the PerTensor case were communicated in bf16 and then converted, would that also prevent the pattern from matching?

Yes, if the conversion to float8 happens after the all-gather, the pattern will not match.

@vkuzo
Contributor
vkuzo commented Apr 2, 2025

However, one option we could look into is supporting rowwise scaling for these. My understanding is that there's no benefit to doing the all-gather in float8 with rowwise scaling, since the overhead (float8 conversion, plus M scales per tensor instead of one, so many more bytes sent over the network) >= just sending the data in bf16.

I think an estimate of eng time and expected benefit would help decide if this is worth doing! Thoughts on what those might be?

@fduwjj fduwjj added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Apr 23, 2025
@danielvegamyhre
Contributor Author
danielvegamyhre commented Apr 29, 2025

Update: I prototyped a solution and discussed ideas with @fegin, but there are some issues. Supporting fp8 all-gather in TP for rowwise scales seems like it will be pretty ugly, implementation-wise, due to:

  1. Needing 2 different "input_hp" tensors (row-major with scales along logical rows, and col-major with scales along logical columns). The prototype is still not fully working, but in theory this should be possible.

  2. grad_output cannot be converted with rowwise scales ahead of time, because in backward() we need both grad_output and grad_output_t, each in row-major format with scales along its logical rows, as the left operands of the float8 GEMMs computing grad_input and grad_weight, respectively. However, grad_output and grad_output_t will have different scales if both are computed along axis=-1. Unlike the forward() case, we can't simply write 2 outputs: we have control over how many inputs we feed into forward(), but we can't arbitrarily control the inputs to backward(). (A shape sketch follows below.)
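A small shape sketch of point 2, using illustrative dimensions only (this is not torchao's actual backward code):

import torch

# Hypothetical shapes: grad_output is (M, N), input is (M, K), weight is (N, K).
M, N, K = 8, 16, 32
grad_output = torch.randn(M, N, dtype=torch.bfloat16)
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

# grad_input = grad_output @ weight: the left operand is grad_output (M, N), so its
# rowwise scales are one per row -> shape (M, 1), amax taken along axis=-1.
scales_for_grad_input = grad_output.abs().amax(dim=-1, keepdim=True).float() / E4M3_MAX

# grad_weight = grad_output.t() @ input: the left operand is grad_output.t() (N, M),
# so its rowwise scales are one per row of the transpose -> shape (N, 1), i.e. the
# amax is taken along axis=0 of the original grad_output.
scales_for_grad_weight = grad_output.abs().amax(dim=0, keepdim=True).t().float() / E4M3_MAX

# The two scale sets differ, so a single fp8 copy of grad_output quantized ahead of
# time cannot serve as the left operand of both GEMMs.
print(scales_for_grad_input.shape, scales_for_grad_weight.shape)  # (M, 1) and (N, 1)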

@danielvegamyhre
Contributor Author

Update: I have a prototype fp8 rowwise all-gather working for the llama3 debug model with 1 layer, but there is still a bug for larger models.

@danielvegamyhre
Contributor Author

Update: I managed to get fp8 rowwise all-gather working for Llama3 8b. Perf is flat compared to using bf16 all-gather, which is expected, since we are transferring 2 fp8 input tensors over the network instead of 1 bf16 input tensor (rough byte math below). However, this should allow us to integrate with async TP.
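Rough byte math behind "perf is flat", with illustrative dimensions and ignoring protocol overhead:

# Bytes all-gathered per (M, K) activation shard; M and K are illustrative only.
M, K = 8192, 4096

bf16_bytes = 2 * M * K                # one bf16 copy
fp8_tensorwise_bytes = M * K + 4      # one fp8 copy + a single fp32 scale
# Rowwise needs both a row-major copy (scales along rows) and a col-major copy
# (scales along columns), so the traffic is back to roughly bf16 levels:
fp8_rowwise_bytes = 2 * M * K + 4 * (M + K)

print(bf16_bytes, fp8_tensorwise_bytes, fp8_rowwise_bytes)
# 67108864 33554436 67158016 -> fp8 rowwise ~= bf16, which is why perf is flat.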

@danielvegamyhre
Contributor Author
danielvegamyhre commented Apr 30, 2025

I've found that when using Float8RowwiseParallel and Float8ColwiseParallel, there are actually no all-gathers happening in float8; it must be some other kind of collective in *.redistribute(), because we have an assertion that scales are tensorwise in the all_gather, wait_tensor, cat, and split op overrides, and that assertion is not triggered.

Furthermore, looking at the graph I found that fp8 quant ops are still occurring after the all-gathers, so implementing Float8RowwiseParallel and Float8ColwiseParallel is not sufficient: we must implement PrepareFloat8ModuleInput as well, and add support for rowwise scales in the primitive op overrides (all_gather, wait_tensor, cat, split).

I implemented this, following the same strategy of writing 2 fp8 outputs (row-major with rowwise scales and col-major with colwise scales), but this means that Attention.forward() has 2 inputs to deal with, which I don't think we want...

@danielvegamyhre
Contributor Author
danielvegamyhre commented Apr 30, 2025

Another update: I've found that when Attention.forward() has an optional 2nd input (col-major) and passes it through to the projections, which have been converted to Float8Linears that now accept this optional 2nd arg (e.g. self.wq(input_row_major, input_col_major), etc.), a pre-forward module hook is for some reason converting the 2nd arg from a DTensor(Float8Tensor) to None. This is causing a problem.

@danielvegamyhre
Contributor Author

Latest issue: I have a PrepareModuleInput subclass for float8 rowwise which has an "input_layout" of Shard(dim=1) and a "desired_layout" of Replicate(), so I expect an all-gather along dim=1. However, it all-gathers along dim=0. I set a breakpoint and confirmed this in the op override:

@implements(
    [
        c10d_functional.all_gather_into_tensor.default,
        _c10d_functional.all_gather_into_tensor.default,
    ]
)
def allgather_fp8(aten_op, args, kwargs=None):
    """
    override funcol with FP8 handling
    """
    #_assert_tensorwise_scale(aten_op, args[0]._scale)
    torch.distributed.breakpoint()
    fp8_input = args[0] # (8, 1024, 256)
    assert isinstance(fp8_input, Float8Tensor), (
        f"expecting a Float8Tensor for allgather but found {type(fp8_input)}"
    )

    fp8_data = fp8_input._data
    fp8_data = fp8_data.contiguous()

    # fp8_out becomes (16, 1024, 256) <- gathered along dim 0 even though gather_dim is 1?
    # input_layout = Shard(dim=1), desired_layout=Replicate()
    fp8_out = aten_op(fp8_data, *args[1:], **kwargs) 
    return Float8Tensor(
        fp8_out,
        fp8_input._scale,
        fp8_input._orig_dtype,
        fp8_input._linear_mm_config,
        fp8_input._gemm_input_role,
    )

@danielvegamyhre
Contributor Author

Update: I have the full e2e PoC working! However, the loss is spiky / not identical to the bf16 run, so I'm looking into that.

@danielvegamyhre
Contributor Author
danielvegamyhre commented May 1, 2025

Update: figured out the issue. The col-major input should have 0 grads, since it is only created for the rowwise all-gather and is not meaningful for the gradient flow, so I did some tricks to pass in a 0-filled DTensor (a rough sketch of the idea is below), and now the loss curve looks good.
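For reference, a hedged sketch of one way such a trick could look; this is my guess at the approach, not the actual change. A custom autograd function passes the col-major input through in forward and returns a zero gradient for it in backward:

import torch

class PassthroughZeroGrad(torch.autograd.Function):
    """Forward the auxiliary (col-major) input unchanged; give it zero grads."""

    @staticmethod
    def forward(ctx, aux):
        return aux

    @staticmethod
    def backward(ctx, grad_output):
        # The col-major copy exists only to feed the rowwise fp8 all-gather,
        # so it contributes nothing to the gradient flow.
        return torch.zeros_like(grad_output)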

@danielvegamyhre
Contributor Author

Another update: I started looking into why async TP is crashing with fp8 rowwise all-gather, and found it's due to an issue similar to the one in #149247, where the "A" tensor used is from before a reshape and the "A_scale" node is from after a reshape, so there's a mismatch.

@danielvegamyhre
Contributor Author

Another update: strangely, the loss curve looked good for the first 100 steps, yet I think the scales were actually wrong; I was repeating the scale for shard0 across both shards (TP degree=2) and somehow it was working. After fixing this and using the local shards for each tensor, it works in eager mode, but with compile I get NaNs...

@danielvegamyhre
Contributor Author
danielvegamyhre commented May 8, 2025

Another update: strangely, the loss curve looked good for the first 100 steps, yet I think the scales were actually wrong; I was repeating the scale for shard0 across both shards (TP degree=2) and somehow it was working. After fixing this and using the local shards for each tensor, it works in eager mode, but with compile I get NaNs...

Update (forgot to post this last week): I figured out the reason we get NaNs after some training steps, and it revealed a fundamental incompatibility between fp8 with rowwise scales in torchao and async TP. I stepped through the various collective ops with a debugger and discovered an issue with how all-gather works under the hood. This is the relevant code.

To illustrate the problem I created this diagram:

[Diagram: the all-gather always gathers along dim=0, then chunks and re-concatenates to recover the requested gather dim]

Basically, no matter what the sharding/gather dim is defined as (e.g. Shard(dim=1) => Replicate()), it does the following (a standalone sketch of this decomposition follows the list):

  1. all-gather along dim=0 (presumably due to how NCCL works, or perhaps it is just more efficient for collectives to write to a row-major layout; I know async TP also forces collectives to be along dim 0, so I assume these collective primitives are more efficient this way)
  2. if gather dim != 0, then:
    • chunk along dim=0
    • cat along the gather dim
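A standalone sketch of that decomposition, emulating the observed behavior on local tensors (this is not the actual funcol source):

import torch

def emulate_all_gather(shards, gather_dim):
    """Emulate how a Shard(dim=gather_dim) -> Replicate() all-gather behaves."""
    group_size = len(shards)
    # Step 1: the collective itself always stacks shards along dim 0; this is the
    # tensor the all_gather_into_tensor op override sees at the breakpoint above.
    gathered_dim0 = torch.cat(shards, dim=0)
    if gather_dim == 0:
        return gathered_dim0
    # Step 2: a non-zero gather dim is recovered afterwards via chunk + cat.
    return torch.cat(torch.chunk(gathered_dim0, group_size, dim=0), dim=gather_dim)

# e.g. two (8, 1024, 256) shards with Shard(dim=1) -> Replicate():
shards = [torch.randn(8, 1024, 256) for _ in range(2)]
print(emulate_all_gather(shards, gather_dim=1).shape)  # torch.Size([8, 2048, 256])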

This is problematic for the case where we need to convert the input activation tensor to float8 in row-major format with scales along logical rows (for output = input @ weight.t()) AND convert to float8 in col-major format with scales along logical columns (for grad_weight = grad_output_t @ input).

As you can see in the diagram, we cannot all-gather fp8 input tensor shards when the scale dim is the same as the all-gather dim (i.e., the case of the input shard in col-major format with scales along logical columns). The torchao Float8Tensor abstraction cannot represent this.

The reason we were getting NaNs is that I was concatenating the float8 shards to set Float8Tensor._data, but setting Float8Tensor._scale to the max of the shard scales. This is incorrect, since the original scales have already been applied to the fp8 data, so when we dequantize using the wrong scales we can get NaNs (see the numeric example below).
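A tiny numeric example, with synthetic values, of why attaching the max of the shard scales to the concatenated fp8 data breaks dequantization:

import torch

fp8 = torch.float8_e4m3fn
E4M3_MAX = torch.finfo(fp8).max

shard0 = torch.tensor([[1.0, 2.0]])      # quantized with a small scale
shard1 = torch.tensor([[100.0, 200.0]])  # quantized with a ~100x larger scale

scale0 = shard0.abs().amax() / E4M3_MAX
scale1 = shard1.abs().amax() / E4M3_MAX
data0 = (shard0 / scale0).to(fp8)
data1 = (shard1 / scale1).to(fp8)

gathered = torch.cat([data0, data1], dim=0)
wrong_scale = torch.max(scale0, scale1)  # single scale attached to the gathered data

# shard0's row dequantizes ~100x too large; with real activations this kind of
# error compounds over steps and eventually blows up to inf/NaN.
print(gathered.float() * wrong_scale)                               # [[100., 200.], [100., 200.]]
print(gathered.float() * torch.stack([scale0, scale1]).view(2, 1))  # [[1., 2.], [100., 200.]]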

IMO we need to develop a solution to this that will also be compatible with MX formats + async TP, as well as future-proof against new quantization strategies and dtypes.

Possible solutions

  1. Extend Float8Tensor or make a new tensor subclass with a more flexible way to represent "scale => subtensor associated with that scale." This is more of a workaround IMO and would add additional complexity, debuggability issues, etc.
  2. Add new async TP pattern matching to include the dynamic quant ops, and directly implement them in the custom ops which replace the target subgraph. We can't import torchao in PT core and have that be a dependency, so this would mean replicating the work we've done in torchao and maintaining parity between the two (not good).
  3. Preferred option: have the inductor post-grad pass, and potentially the custom symmetric memory ops, for async TP with low-precision training (float8, MX) live in torchao. This is my preferred option because:
    • We can include the dynamic quant ops in the pattern matching without issue, since we can directly use the quantization code in torchao in the custom ops replacing the target subgraph.
    • We won't be forced to try to make float8/MX/etc subgraphs fit the same pattern matching as bf16 without breaking anything, which is very limiting and/or not possible in some cases (like this one).
      • For example, the only reason fp8 tensorwise training is compatible with async TP is that we can do the all-gather in fp8 and thus the inductor pattern matches; if we constrain async TP to only be compatible with dynamic quantization strategies that can be all-gathered in low precision, that limits both composability and peak achievable performance.
    • Still a consistent story for the user of using torchao for "all things quantization and sparsity"
    • There is already precedent for doing this in torchao: [PT2E][X86] Migrate fusion passes in Inductor to torchao ao#2140

@danielvegamyhre
Contributor Author

@vkuzo @drisspg any thoughts on option 3 in the above comment (#149990 (comment))? I'd like to align internally with the team first before proposing it to the PTD team for H2 planning.

@vkuzo
Contributor
vkuzo commented May 12, 2025

thanks for the findings! let's chat offline about this and summarize here
