Investigate FlexAttention performance degradation on low precision inputs #147336

Closed
danielvegamyhre opened this issue Feb 17, 2025 · 4 comments
Assignees: danielvegamyhre, davidberard98
Labels: module: flex attention, module: higher order operators, module: pt2-dispatcher, oncall: pt2, triaged, upstream triton

danielvegamyhre (Contributor) commented Feb 17, 2025

Creating this issue to track my investigation into the root cause of the unexpected slowdowns observed in flex attention with low precision input tensors.

TL;DR

The investigation so far points to a huge increase in shared memory bank conflicts as the root cause. Evidence so far suggests the loading of fp8 V blocks into SRAM is the problem.

Repro script

As a first step, I wrote this repro script, which runs benchmarks and optionally produces traces for the bf16 and fp8 dtypes.
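
For context, a minimal sketch of the kind of benchmark the script runs is below (the shapes, iteration counts, and timing harness are assumptions; the actual script is the one linked above):

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    B, H, S, D = 8, 8, 8192, 64  # hypothetical problem size

    def bench(dtype):
        # fp8 tensors can't be sampled directly, so sample in bf16 and cast.
        q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16).to(dtype)
                   for _ in range(3))
        compiled = torch.compile(flex_attention)
        for _ in range(3):  # warmup + compile
            compiled(q, k, v)
        torch.cuda.synchronize()
        start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
        start.record()
        for _ in range(10):
            compiled(q, k, v)
        end.record()
        torch.cuda.synchronize()
        print(f"{dtype}: {start.elapsed_time(end) / 10 * 1000:.1f} us")

    bench(torch.bfloat16)
    bench(torch.float8_e4m3fn)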

Benchmark

Initial benchmarks show the flex attention forward pass takes ~1.39x longer with fp8 inputs than with bf16 inputs.

$ python3 profile_flex.py --fp8 --bf16
2025-02-16 21:51:55,038 - flex_bench - INFO - Running benchmark: bf16
2025-02-16 21:51:56,765 - flex_bench - INFO - bf16: 441.3840833333334 us
2025-02-16 21:51:56,772 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-02-16 21:51:57,373 - flex_bench - INFO - fp8e4m3: 615.4808518518514 us

Triton kernel analysis

The main difference between the triton kernels generated by inductor for "compiled_flex" and "compiled_scale_flex" is the presence of the following lines, which implement the score mod function. Nothing here looks problematic to me.

    tmp0 = (qk).to(tl.float32)
    tmp1 = tmp0 * tl.load(in_ptr8 + 0)
    tmp2 = tmp1 * tl.load(in_ptr9 + 0)
    post_mod_scores = tmp2
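
For reference, a score mod along these lines would lower to the code above; this is a sketch with assumed scale names, not the exact function from the repro script:

    import torch

    # Hypothetical dequantization scales for the fp8 Q and K tensors.
    scale_q = torch.tensor(1.0, device="cuda")
    scale_k = torch.tensor(1.0, device="cuda")

    def score_mod(score, b, h, q_idx, kv_idx):
        # Matches the generated code: the score is upcast to fp32, then
        # multiplied by the two scalar scales (loaded from in_ptr8/in_ptr9).
        return score * scale_q * scale_k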

NCU

We can use ncu to analyze the specific kernel which implements flex attention:

ncu --set detailed -k regex:triton_tem_.* python3 profile_flex.py --bf16
ncu --set detailed -k regex:triton_tem_.* python3 profile_flex.py --fp8 

Speed of light bf16

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: GPU Speed Of Light Throughput
    ----------------------- ----------- ------------
    Metric Name             Metric Unit Metric Value
    ----------------------- ----------- ------------
    DRAM Frequency                  Ghz         1.59
    SM Frequency                    Ghz         1.24
    Elapsed Cycles                cycle      751,814
    Memory Throughput                 %        43.38
    DRAM Throughput                   %        17.69
    Duration                         us       602.69
    L1/TEX Cache Throughput           %        45.31
    L2 Cache Throughput               %        21.36
    SM Active Cycles              cycle   719,559.25
    Compute (SM) Throughput           %        35.59
    ----------------------- ----------- ------------

Speed of light fp8

    Section: GPU Speed Of Light Throughput
    ----------------------- ----------- ------------
    Metric Name             Metric Unit Metric Value
    ----------------------- ----------- ------------
    DRAM Frequency                  Ghz         1.59
    SM Frequency                    Ghz         1.23
    Elapsed Cycles                cycle    1,056,196
    Memory Throughput                 %        72.38
    DRAM Throughput                   %         8.70
    Duration                         us       853.86
    L1/TEX Cache Throughput           %        74.56
    L2 Cache Throughput               %         9.74
    SM Active Cycles              cycle 1,022,350.08
    Compute (SM) Throughput           %        27.49
    ----------------------- ----------- ------------

Uncoalesced shared memory access

Importantly, in the NCU output for fp8 we get a warning about uncoalesced shared memory accesses causing excessive wavefronts. This seems likely to be related to the observed slowdown:

    OPT   Est. Speedup: 60.51%
          This kernel has uncoalesced shared accesses resulting in a total of 58720256 excessive wavefronts (63% of the
          total 92856320 wavefronts). Check the L1 Wavefronts Shared Excessive table for the primary source locations.
          The CUDA Best Practices Guide
          (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#shared-memory-in-matrix-multiplication-c-ab)
          has an example on optimizing shared memory accesses.

Next I generated some profiles for bf16 and fp8 to analyze in the NCU UI:

TORCH_LOGS="output_code" TORCH_LOGS_OUT="compile_logs/fp8_log.txt" ncu --set detailed -k regex:triton_tem_.* -o profiles/fp8-prof python3 profile_flex.py --fp8

Here I also observed that the fp8 profile has uncoalesced shared access warnings that are not present in the bf16 profile:

[Screenshot: NCU summary showing uncoalesced shared access warnings for the fp8 kernel]

Diving deeper, we can see the exact line of triton code where this is occurring:

[Screenshot: NCU source view highlighting the triton line with the uncoalesced shared accesses]

Looking at the sampling counts, the majority are flagged as "short scoreboard". Per the NVIDIA docs, this stall reason usually indicates bank conflicts in shared memory load/store operations.

[Screenshot: NCU warp state sampling showing "short scoreboard" as the dominant stall reason]

To confirm this, I collected metric counts measuring the number of shared memory load/store bank conflicts for bf16 vs fp8 (command sketched below). I observed orders of magnitude more conflicts for fp8 than bf16, for both load and store operations:
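
The counts were collected with a command along these lines (a sketch; the exact invocation may have differed):

ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum,l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum -k regex:triton_tem_.* python3 profile_flex.py --bf16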

Load and store conflicts bf16

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                  111,863
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                  104,116
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                  114,396
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                  113,613
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                  106,008
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                  102,859
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                  101,981
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                  104,583
    -------------------------------------------------------- ----------- ------------

Load and store conflicts fp8

 triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                    6,467
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum               58,782,390
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                    5,698
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum               58,771,364
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                    6,234
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum               58,783,926
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                    5,518
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum               58,800,274
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                    7,216
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum               58,776,341
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                    7,586
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum               58,750,044
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                    5,236
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum               58,797,745
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                    6,156
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum               58,800,346
    -------------------------------------------------------- ----------- ------------

cc @chauhang @penguinwu @bertmaher @int3 @davidberard98 @nmacchioni @chenyang78 @embg @peterbell10 @aakhundov @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng

@danielvegamyhre self-assigned this Feb 17, 2025
@pytorch-bot added the module: higher order operators, oncall: pt2, and module: pt2-dispatcher labels Feb 17, 2025
@danielvegamyhre removed the oncall: pt2, module: pt2-dispatcher, and module: higher order operators labels Feb 17, 2025
drisspg (Contributor) commented Feb 18, 2025

cc @davidberard98 on the bank conflicts and probably extra layout transformations

@colesbury added the triaged label Feb 18, 2025
@davidberard98 added the upstream triton label and removed the triaged label Feb 18, 2025
@pytorch-bot added the module: pt2-dispatcher label Feb 24, 2025
@jansel added the triaged label Mar 29, 2025
danielvegamyhre (Contributor, Author) commented May 1, 2025

@davidberard98 I tested with pytorch 2.7, which comes bundled with the newer triton 3.3.0, but the issue is still present. Any chance you have capacity to take a look now? It seems like it could be a bug in how triton swizzles fp8 data.

(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ conda list | grep torch
pytorch-triton            3.3.0+git96316ce5          pypi_0    pypi
torch                     2.8.0.dev20250421+cu126          pypi_0    pypi
torchaudio                2.6.0.dev20250421+cu126          pypi_0    pypi
torchvision               0.22.0.dev20250421+cu126          pypi_0    pypi
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ python profile_flex.py --bf16 --fp8
2025-05-05 16:34:20,966 - flex_bench - INFO - Running benchmark: bf16
2025-05-05 16:34:23,856 - flex_bench - INFO - bf16: 424.6423101604277 us
2025-05-05 16:34:23,868 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-05 16:34:25,329 - flex_bench - INFO - fp8e4m3: 515.4672531645572 us

@davidberard98 self-assigned this May 7, 2025
davidberard98 (Contributor) commented:

I believe the issue is that we need the V tensor to be transposed in order to do the second dot op efficiently!

Reasoning: Unlike fp16/bf16, fp8 WGMMA requires the B matrix fragment to be transposed in shared memory before executing the WGMMA instruction.

When the V tensor is not transposed in global memory, this breaks pipelining: Triton can't issue async copies directly into shared memory (as async copies need to be at least 4 bytes wide per thread); instead, it loads into registers and then does single-byte stores into shared memory, which cause a lot of bank conflicts.

danielvegamyhre (Contributor, Author) commented May 7, 2025

That makes sense, great find! I hadn't considered this, since in pytorch, fp8 GEMMs enforce this constraint and throw an error if the B tensor isn't in column-major memory layout. I guess triton is different in that it isn't a strict requirement, just non-optimal.

I will look into enforcing column-major memory layout for the V tensor here; I should be able to figure something out. Thanks for taking a look at this.
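
For illustration, one way to get a column-major V in global memory without changing its logical shape is the stride trick below (a sketch, not necessarily the change that landed):

    import torch

    # Hypothetical fp8 V tensor of shape (B, H, S, D); a fresh tensor is
    # row-major in its last two dims.
    v = torch.randn(8, 8, 8192, 64, device="cuda", dtype=torch.bfloat16)
    v = v.to(torch.float8_e4m3fn)

    # Transpose, materialize contiguously, transpose back: the logical shape
    # is unchanged but the last two dims are now column-major.
    v_col_major = v.transpose(-2, -1).contiguous().transpose(-2, -1)

    print(v.stride()[-2:])            # (64, 1)   -- row-major inner dims
    print(v_col_major.stride()[-2:])  # (1, 8192) -- column-major inner dims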

pytorchmergebot pushed a commit that referenced this issue May 16, 2025
…o avoid perf degradation (#153357)

Fixes #147336

## Context

NCU analysis of the fp8 flex attention perf issue in #147336 showed an unexpected increase in shared memory access bank conflicts when loading the V tensor from HBM to SRAM.

After I brought this to the attention of triton developer @davidberard98, he identified the memory layout of the tensor in HBM as the cause of non-pipelined loads into SRAM, which caused the slowdown.

To summarize:

In flex attention, when performing the FP8 GEMM `softmax_scores @ V`, the right operand V must be in column-major memory layout. However, the `tl.load` of V blocks from HBM to SRAM cannot be pipelined if the V tensor isn't already column-major in HBM, leading to substantial performance degradation.

This is because triton does not perform async copies with the `cp.async` PTX instruction if the number of contiguous bytes is less than 4 (see [here](https://github.com/triton-lang/triton/blob/81f93f2c8ec7d20a1f8184def767edeaebeb6812/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp#L403)).

i.e., when loading 4 contiguous bytes from a tensor stored row-major in HBM, we would have to perform 4 separate non-contiguous writes to place those bytes in their new locations in the column-major layout in SRAM. Thus the load is not a candidate for pipelining with cp.async; it instead moves data to registers and then performs a series of single-byte stores.

## Fix summary
- To fix this, we should enforce memory layouts for Q, K, and V in FlexAttention when fp8 is used, ensuring each tensor exists in HBM in the layout needed to facilitate pipelined loads into SRAM ahead of the FP8 GEMMs.

## Benchmarks
Rerunning the repro, the fp8 runtime drops from ~121% of the bf16 runtime to ~77%.

Before fix:

```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 19:07:33,402 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 19:07:35,885 - flex_bench - INFO - bf16: 424.87228804347734 us
2025-05-11 19:07:35,893 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 19:07:37,319 - flex_bench - INFO - fp8e4m3: 515.714000000001 us
```

After fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 17:34:38,223 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 17:34:41,157 - flex_bench - INFO - bf16: 423.4662032967036 us
2025-05-11 17:34:41,167 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 17:34:42,917 - flex_bench - INFO - fp8e4m3: 326.3694803493453 us
```

Pull Request resolved: #153357
Approved by: https://github.com/ngimel, https://github.com/davidberard98