Investigate FlexAttention performance degradation on low precision inputs · Issue #147336 · pytorch/pytorch · GitHub

@danielvegamyhre

Description

Creating this issue to track my work investigating the root cause of unexpected slowdowns observed in flex attention using low precision input tensors.

TL;DR

The investigation so far points to a large increase in shared memory bank conflicts as the root cause. The evidence suggests that loading fp8 V blocks into shared memory (SRAM) is the problem.

Repro script

As a first step, I wrote a repro script (profile_flex.py) that runs benchmarks for the bf16 and fp8 dtypes and optionally produces traces.

Benchmark

Initial benchmarks show that the flex attention forward pass takes roughly 1.39x longer with fp8 inputs than with bf16 inputs.

$ python3 profile_flex.py --fp8 --bf16
2025-02-16 21:51:55,038 - flex_bench - INFO - Running benchmark: bf16
2025-02-16 21:51:56,765 - flex_bench - INFO - bf16: 441.3840833333334 us
2025-02-16 21:51:56,772 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-02-16 21:51:57,373 - flex_bench - INFO - fp8e4m3: 615.4808518518514 us
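As a quick sanity check on the ~1.39x figure, the slowdown can be recomputed directly from the two result lines above. The helper below is a hypothetical sketch that assumes the log format shown (`<name>: <time> us`); it is not part of profile_flex.py.

```python
import re

# Result lines copied from the benchmark output above.
LOG = """\
2025-02-16 21:51:56,765 - flex_bench - INFO - bf16: 441.3840833333334 us
2025-02-16 21:51:57,373 - flex_bench - INFO - fp8e4m3: 615.4808518518514 us
"""

# Hypothetical pattern: assumes each result line ends with "<name>: <float> us".
RESULT_RE = re.compile(r"INFO - (\w+): ([0-9.]+) us$")


def parse_timings(log: str) -> dict[str, float]:
    """Map benchmark name -> runtime in microseconds."""
    timings = {}
    for line in log.splitlines():
        match = RESULT_RE.search(line)
        if match:
            timings[match.group(1)] = float(match.group(2))
    return timings


timings = parse_timings(LOG)
slowdown = timings["fp8e4m3"] / timings["bf16"]
print(f"fp8 / bf16 runtime: {slowdown:.2f}x")  # fp8 / bf16 runtime: 1.39x
```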

Triton kernel analysis

The main difference between the Triton kernels that Inductor generates for "compiled_flex" and "compiled_scale_flex" is the following lines, which implement the score_mod function. Nothing here looks problematic to me.

    tmp0 = (qk).to(tl.float32)
    tmp1 = tmp0 * tl.load(in_ptr8 + 0)
    tmp2 = tmp1 * tl.load(in_ptr9 + 0)
    post_mod_scores = tmp2
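The two scalar loads look like per-tensor dequantization scales, although that is an assumption: the generated code only shows `in_ptr8` and `in_ptr9`. If Q and K were stored as q ≈ s_q·q8 and k ≈ s_k·k8, the dot product of the dequantized vectors equals s_q·s_k times the dot product of the raw values, so multiplying the scores after the QK matmul is equivalent to dequantizing first. This pure-Python sketch checks that identity with made-up numbers; `q_scale` and `k_scale` are hypothetical names.

```python
import math


def dot(a, b):
    """Plain dot product of two equal-length sequences."""
    return sum(x * y for x, y in zip(a, b))


# Made-up low-precision values and hypothetical per-tensor scales.
q_raw = [1.0, -2.0, 3.0, 0.5]
k_raw = [0.25, 4.0, -1.0, 2.0]
q_scale = 0.125
k_scale = 0.5

# Option 1: dequantize Q and K first, then take the dot product.
dequant_first = dot([q_scale * x for x in q_raw], [k_scale * y for y in k_raw])

# Option 2: dot product on the raw values, then scale the score,
# mirroring tmp0 -> tmp1 -> tmp2 in the generated kernel above.
qk = dot(q_raw, k_raw)
post_mod_score = float(qk) * q_scale * k_scale

assert math.isclose(dequant_first, post_mod_score)
```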

NCU

We can use ncu to analyze the specific kernel that implements flex attention:

ncu --set detailed -k 'regex:triton_tem_.*' python3 profile_flex.py --bf16
ncu --set detailed -k 'regex:triton_tem_.*' python3 profile_flex.py --fp8
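When driving ncu from a script, passing the arguments as a list sidesteps shell quoting of the `regex:` kernel filter. A minimal sketch, assuming `ncu` is on `PATH`; `build_ncu_cmd` is a hypothetical helper, while `--set`, `-k`, `--metrics` and `-o` are standard Nsight Compute CLI options. The metric names are the bank-conflict counters reported later in this issue.

```python
import shlex

BANK_CONFLICT_METRICS = [
    "l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum",
    "l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum",
]


def build_ncu_cmd(dtype_flag, metrics=None, section_set="detailed", report=None):
    """Return an ncu argv list that profiles only kernels matching triton_tem_*.

    Hypothetical helper: the script name and flags mirror the commands in this issue.
    """
    cmd = ["ncu"]
    if metrics:
        # Request specific metrics (e.g. bank-conflict counters) by name.
        cmd += ["--metrics", ",".join(metrics)]
    else:
        # Use a predefined section set, as in the commands above.
        cmd += ["--set", section_set]
    # No shell is involved, so the regex needs no quoting here.
    cmd += ["-k", "regex:triton_tem_.*"]
    if report:
        cmd += ["-o", report]
    cmd += ["python3", "profile_flex.py", dtype_flag]
    return cmd


cmd = build_ncu_cmd("--fp8", metrics=BANK_CONFLICT_METRICS)
print(shlex.join(cmd))  # shell-safe form, for copy-pasting into a terminal
# To run it directly (needs a GPU and Nsight Compute installed):
#   import subprocess; subprocess.run(cmd, check=True)
```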

Speed of light bf16

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: GPU Speed Of Light Throughput
    ----------------------- ----------- ------------
    Metric Name             Metric Unit Metric Value
    ----------------------- ----------- ------------
    DRAM Frequency                  Ghz         1.59
    SM Frequency                    Ghz         1.24
    Elapsed Cycles                cycle      751,814
    Memory Throughput                 %        43.38
    DRAM Throughput                   %        17.69
    Duration                         us       602.69
    L1/TEX Cache Throughput           %        45.31
    L2 Cache Throughput               %        21.36
    SM Active Cycles              cycle   719,559.25
    Compute (SM) Throughput           %        35.59
    ----------------------- ----------- ------------

Speed of light fp8

    Section: GPU Speed Of Light Throughput
    ----------------------- ----------- ------------
    Metric Name             Metric Unit Metric Value
    ----------------------- ----------- ------------
    DRAM Frequency                  Ghz         1.59
    SM Frequency                    Ghz         1.23
    Elapsed Cycles                cycle    1,056,196
    Memory Throughput                 %        72.38
    DRAM Throughput                   %         8.70
    Duration                         us       853.86
    L1/TEX Cache Throughput           %        74.56
    L2 Cache Throughput               %         9.74
    SM Active Cycles              cycle 1,022,350.08
    Compute (SM) Throughput           %        27.49
    ----------------------- ----------- ------------
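Putting the two Speed of Light tables side by side makes the shift easier to see: the fp8 kernel takes about 1.4x as many cycles, and its L1/TEX throughput rises while DRAM and L2 throughput fall. That pattern is consistent with a kernel bottlenecked inside the SM's L1/shared-memory path rather than on DRAM bandwidth. The values are copied from the tables above; the script only computes ratios.

```python
# Selected Speed of Light metrics copied from the ncu output above: (bf16, fp8).
SOL = {
    "Elapsed Cycles (cycle)": (751_814, 1_056_196),
    "Duration (us)": (602.69, 853.86),
    "Memory Throughput (%)": (43.38, 72.38),
    "DRAM Throughput (%)": (17.69, 8.70),
    "L1/TEX Cache Throughput (%)": (45.31, 74.56),
    "L2 Cache Throughput (%)": (21.36, 9.74),
    "Compute (SM) Throughput (%)": (35.59, 27.49),
}

# Ratio > 1 means the metric is higher for fp8 than for bf16.
ratios = {name: fp8 / bf16 for name, (bf16, fp8) in SOL.items()}

for name, (bf16, fp8) in SOL.items():
    print(f"{name:<30} bf16={bf16:>12} fp8={fp8:>12} fp8/bf16={ratios[name]:.2f}")
```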

Uncoalesced shared memory access

Importantly, the NCU output for fp8 includes a warning that uncoalesced shared memory accesses are causing excessive wavefronts. This seems likely to be related to the observed slowdown:

    OPT   Est. Speedup: 60.51%
          This kernel has uncoalesced shared accesses resulting in a total of 58720256 excessive wavefronts (63% of the
          total 92856320 wavefronts). Check the L1 Wavefronts Shared Excessive table for the primary source locations.
          The CUDA Best Practices Guide
          (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#shared-memory-in-matrix-multiplication-c-ab)
          has an example on optimizing shared memory accesses.
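To make "excessive wavefronts" concrete: shared memory is divided into 32 banks, each 4 bytes wide, and a warp-wide request is split into as many wavefronts as the largest number of distinct 4-byte words that map to a single bank (threads touching the same word do not conflict). The sketch below implements that simplified model; it ignores details such as 64/128-bit accesses, and the access patterns are illustrative rather than taken from the generated kernel.

```python
from collections import defaultdict

NUM_BANKS = 32
BANK_WIDTH_BYTES = 4


def wavefronts(byte_addresses):
    """Wavefronts for one warp-wide shared memory request (simplified model).

    Each 4-byte word lives in bank (word_index % 32). Threads that touch the same
    word are served together, so the cost is the largest number of distinct
    words that land in any single bank.
    """
    words_per_bank = defaultdict(set)
    for addr in byte_addresses:
        word = addr // BANK_WIDTH_BYTES
        words_per_bank[word % NUM_BANKS].add(word)
    return max(len(words) for words in words_per_bank.values())


def excessive_wavefronts(byte_addresses):
    """Wavefronts beyond the minimum possible for the same set of words."""
    distinct_words = {addr // BANK_WIDTH_BYTES for addr in byte_addresses}
    ideal = -(-len(distinct_words) // NUM_BANKS)  # ceil division
    return wavefronts(byte_addresses) - ideal


lanes = range(32)
# Each lane writes one 4-byte element at consecutive addresses: conflict-free.
contiguous_fp32 = [4 * t for t in lanes]
# Each lane writes one 1-byte element down a column of a row-major tile with a
# 128-byte row pitch (hypothetical layout): every lane hits bank 0.
column_fp8 = [128 * t for t in lanes]

print(wavefronts(contiguous_fp32), excessive_wavefronts(contiguous_fp32))  # 1 0
print(wavefronts(column_fp8), excessive_wavefronts(column_fp8))  # 32 31
```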

Next, I generated profiles for bf16 and fp8 to analyze in the NCU UI. For example, for fp8:

TORCH_LOGS="output_code" TORCH_LOGS_OUT="compile_logs/fp8_log.txt" ncu --set detailed -k 'regex:triton_tem_.*' -o profiles/fp8-prof python3 profile_flex.py --fp8

Here I also observed that the fp8 profile has uncoalesced shared access warnings that are not present in the bf16 profile:

[Screenshot: NCU UI showing the uncoalesced shared access warnings in the fp8 profile]

Diving deeper, we can see the exact line of Triton code where this occurs:

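[Screenshot: NCU source view highlighting the Triton line responsible for the uncoalesced shared accesses]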

Looking at the sampling counts, we can see that the majority of samples are flagged as "short scoreboard". According to the NVIDIA docs, this stall reason usually points to shared memory operations and is commonly related to bank conflicts in shared memory loads and stores.

[Screenshot: NVIDIA documentation describing the "short scoreboard" stall reason]

To confirm this, I collected metrics counting shared memory load and store bank conflicts for bf16 vs. fp8. The fp8 kernel shows orders of magnitude more conflicts than bf16 for both loads and stores (about 58.8M vs. 102K-114K store conflicts per kernel, and about 5K-7.6K vs. 0 load conflicts):

Load and store conflicts bf16

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                  111,863
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                  104,116
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                  114,396
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                  113,613
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                  106,008
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                  102,859
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                  101,981
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                  104,583
    -------------------------------------------------------- ----------- ------------

Load and store conflicts fp8

  triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                    6,467
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum               58,782,390
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                    5,698
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum               58,771,364
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                    6,234
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum               58,783,926
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                    5,518
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum               58,800,274
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                    7,216
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum               58,776,341
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                    7,586
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum               58,750,044
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                    5,236
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum               58,797,745
    -------------------------------------------------------- ----------- ------------

  triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                    6,156
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum               58,800,346
    -------------------------------------------------------- ----------- ------------
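Averaging the eight kernel invocations in each set puts the gap in numbers: roughly 107K store bank conflicts per bf16 kernel versus roughly 58.8M per fp8 kernel, about 547x more. The mean fp8 store-conflict count is also within about 0.1% of the 58,720,256 excessive wavefronts in the earlier NCU warning, which fits the reading that conflicting shared memory stores account for most of that excess (an interpretation, not something NCU states directly). The script only recomputes these figures from the tables.

```python
from statistics import mean

# Per-invocation shared memory bank conflict counts copied from the tables above.
BF16_LD = [0] * 8
BF16_ST = [111_863, 104_116, 114_396, 113_613, 106_008, 102_859, 101_981, 104_583]
FP8_LD = [6_467, 5_698, 6_234, 5_518, 7_216, 7_586, 5_236, 6_156]
FP8_ST = [58_782_390, 58_771_364, 58_783_926, 58_800_274,
          58_776_341, 58_750_044, 58_797_745, 58_800_346]

# From the fp8 "uncoalesced shared accesses" warning.
EXCESSIVE_WAVEFRONTS = 58_720_256
TOTAL_WAVEFRONTS = 92_856_320

store_ratio = mean(FP8_ST) / mean(BF16_ST)
excess_share = EXCESSIVE_WAVEFRONTS / TOTAL_WAVEFRONTS

print(f"mean bf16 store conflicts:  {mean(BF16_ST):,.0f}")
print(f"mean fp8 store conflicts:   {mean(FP8_ST):,.0f}")
print(f"fp8 / bf16 store conflicts: {store_ratio:.0f}x")
print(f"mean fp8 load conflicts:    {mean(FP8_LD):,.0f} (bf16: {mean(BF16_LD):.0f})")
print(f"excessive share of wavefronts: {excess_share:.0%}")
```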

cc @chauhang @penguinwu @bertmaher @int3 @davidberard98 @nmacchioni @chenyang78 @embg @peterbell10 @aakhundov @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng

Labels

module: flex attention, module: higher order operators, module: pt2-dispatcher, oncall: pt2, triaged, upstream triton
