Description
Creating this issue to track my work investigating the root cause of unexpected slowdowns observed in flex attention using low precision input tensors.
TL;DR
The investigation so far points to a large increase in shared memory bank conflicts as the root cause. The evidence suggests the problem lies in loading fp8 V blocks into SRAM.
Repro script
As a first step I wrote this repro script which runs benchmarks and optionally produces traces, for bf16 and fp8 dtypes.
Benchmark
Initial benchmarks show that the flex attention forward pass takes roughly 1.39x longer with fp8 inputs than with bf16 inputs (about 615 us versus 441 us).
$ python3 profile_flex.py --fp8 --bf16
2025-02-16 21:51:55,038 - flex_bench - INFO - Running benchmark: bf16
2025-02-16 21:51:56,765 - flex_bench - INFO - bf16: 441.3840833333334 us
2025-02-16 21:51:56,772 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-02-16 21:51:57,373 - flex_bench - INFO - fp8e4m3: 615.4808518518514 us
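For reference, the slowdown figure can be recomputed directly from the two timings in the log above. This is only arithmetic on the logged values; the variable names are mine.

```python
# Timings copied from the benchmark log above (microseconds).
bf16_us = 441.3840833333334
fp8_us = 615.4808518518514

# Relative slowdown of fp8 versus bf16, and the absolute extra time.
slowdown = fp8_us / bf16_us
extra_us = fp8_us - bf16_us

print(f"fp8 / bf16 = {slowdown:.2f}x")
print(f"extra time per forward pass = {extra_us:.1f} us")
```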
Triton kernel analysis
The main difference between the Triton kernels generated by Inductor for "compiled_flex" and "compiled_scale_flex" is the following lines of code, which implement the score mod function. Nothing here looks problematic to me.
tmp0 = (qk).to(tl.float32)
tmp1 = tmp0 * tl.load(in_ptr8 + 0)
tmp2 = tmp1 * tl.load(in_ptr9 + 0)
post_mod_scores = tmp2
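As a rough illustration, a score mod of the following shape would lower to the three generated lines above. This is a sketch and not the actual function from the repro script: `make_scaled_score_mod`, `q_scale` and `k_scale` are hypothetical names, and plain floats stand in for tensors so that it runs without a GPU. The `(score, b, h, q_idx, kv_idx)` signature follows the flex attention `score_mod` convention.

```python
def make_scaled_score_mod(q_scale, k_scale):
    """Build a score_mod that rescales each attention score by two scalars.

    q_scale and k_scale are hypothetical names. In the generated kernel the
    two factors are read with tl.load(in_ptr8 + 0) and tl.load(in_ptr9 + 0).
    """
    def score_mod(score, b, h, q_idx, kv_idx):
        # Mirrors: tmp1 = tmp0 * <first scalar>; tmp2 = tmp1 * <second scalar>
        return score * q_scale * k_scale

    return score_mod


# Floats stand in for tensors here, purely for illustration.
score_mod = make_scaled_score_mod(q_scale=0.5, k_scale=0.25)
result = score_mod(8.0, 0, 0, 0, 0)
print(result)
```

The point is only that the score mod is two scalar multiplies per score, which is consistent with it not being the source of the slowdown.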
NCU
We can use ncu to analyze the specific kernel which implements flex attention:
ncu --set detailed -k regex:triton_tem_.* python3 profile_flex.py --bf16
ncu --set detailed -k regex:triton_tem_.* python3 profile_flex.py --fp8
Speed of light bf16
triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: GPU Speed Of Light Throughput
----------------------- ----------- ------------
Metric Name Metric Unit Metric Value
----------------------- ----------- ------------
DRAM Frequency Ghz 1.59
SM Frequency Ghz 1.24
Elapsed Cycles cycle 751,814
Memory Throughput % 43.38
DRAM Throughput % 17.69
Duration us 602.69
L1/TEX Cache Throughput % 45.31
L2 Cache Throughput % 21.36
SM Active Cycles cycle 719,559.25
Compute (SM) Throughput % 35.59
----------------------- ----------- ------------
Speed of light fp8
Section: GPU Speed Of Light Throughput
----------------------- ----------- ------------
Metric Name Metric Unit Metric Value
----------------------- ----------- ------------
DRAM Frequency Ghz 1.59
SM Frequency Ghz 1.23
Elapsed Cycles cycle 1,056,196
Memory Throughput % 72.38
DRAM Throughput % 8.70
Duration us 853.86
L1/TEX Cache Throughput % 74.56
L2 Cache Throughput % 9.74
SM Active Cycles cycle 1,022,350.08
Compute (SM) Throughput % 27.49
----------------------- ----------- ------------
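To make the two tables easier to compare, here is a small sketch that computes the fp8/bf16 ratio for each metric. The values are copied from the tables above; the dictionary layout and names are mine.

```python
# Values copied from the two "GPU Speed Of Light Throughput" tables above.
bf16 = {
    "Elapsed Cycles": 751_814,
    "Duration (us)": 602.69,
    "Memory Throughput (%)": 43.38,
    "DRAM Throughput (%)": 17.69,
    "L1/TEX Cache Throughput (%)": 45.31,
    "L2 Cache Throughput (%)": 21.36,
    "Compute (SM) Throughput (%)": 35.59,
}
fp8 = {
    "Elapsed Cycles": 1_056_196,
    "Duration (us)": 853.86,
    "Memory Throughput (%)": 72.38,
    "DRAM Throughput (%)": 8.70,
    "L1/TEX Cache Throughput (%)": 74.56,
    "L2 Cache Throughput (%)": 9.74,
    "Compute (SM) Throughput (%)": 27.49,
}

# Ratio > 1 means the metric is higher for fp8 than for bf16.
ratios = {name: fp8[name] / bf16[name] for name in bf16}
for name, ratio in ratios.items():
    print(f"{name:30s} fp8/bf16 = {ratio:.2f}")
```

The pattern to look for is that L1/TEX throughput rises while DRAM and L2 throughput fall, which fits a bottleneck in the L1/shared memory path rather than in global memory traffic.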
Uncoalesced shared memory access
Importantly, in the NCU output for fp8 we get a warning about uncoalesced shared memory accesses causing excessive wavefronts. It seems likely this is related to the observed slowdown:
OPT Est. Speedup: 60.51%
This kernel has uncoalesced shared accesses resulting in a total of 58720256 excessive wavefronts (63% of the
total 92856320 wavefronts). Check the L1 Wavefronts Shared Excessive table for the primary source locations.
The CUDA Best Practices Guide
(https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#shared-memory-in-matrix-multiplication-c
-ab) has an example on optimizing shared memory accesses.
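The numbers in the warning can be unpacked with a little arithmetic. This sketch assumes that "excessive" wavefronts are those beyond what conflict-free accesses would have needed, so total minus excessive is the ideal count; the variable names are mine.

```python
# Values copied from the NCU warning above.
excessive_wavefronts = 58_720_256
total_wavefronts = 92_856_320

# Wavefronts that would remain if no shared access needed extra passes.
ideal_wavefronts = total_wavefronts - excessive_wavefronts

excessive_fraction = excessive_wavefronts / total_wavefronts
inflation = total_wavefronts / ideal_wavefronts

print(f"ideal wavefronts:   {ideal_wavefronts:,}")
print(f"excessive fraction: {excessive_fraction:.1%}")
print(f"shared wavefronts are {inflation:.2f}x the ideal count")
```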
Next I generated some profiles for bf16 and fp8 to analyze in the NCU UI:
TORCH_LOGS="output_code" TORCH_LOGS_OUT="compile_logs/fp8_log.txt" ncu --set detailed -k regex:triton_tem_.* -o profiles/fp8-prof python3 profile_flex.py --fp8
Here I also observed the fp8 profile has uncoalesced shared access warnings which are not present in the bf16 profile:
Diving deeper, we can see the exact line of triton code where this is occurring:
Looking at the sampling counts, we can see the majority are flagged as "short scoreboard" stalls. According to the NVIDIA docs, this stall reason usually points to shared memory load/store operations, with bank conflicts being a common cause.
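For anyone less familiar with bank conflicts, here is a simplified model of how they arise, assuming 32 banks that are each 4 bytes wide. The access patterns below are illustrative only and are not taken from the flex attention kernel; `conflict_degree` is a hypothetical helper.

```python
from collections import defaultdict

NUM_BANKS = 32   # assumed number of shared memory banks
BANK_WIDTH = 4   # assumed bytes per bank word


def conflict_degree(byte_addresses):
    """Return the worst-case number of serialized passes for one warp access.

    Threads touching the same 32-bit word are served together, so only
    distinct words that map to the same bank count as a conflict.
    """
    words_per_bank = defaultdict(set)
    for addr in byte_addresses:
        word = addr // BANK_WIDTH
        words_per_bank[word % NUM_BANKS].add(word)
    return max(len(words) for words in words_per_bank.values())


warp = range(32)
unit_stride_4b = conflict_degree([4 * t for t in warp])    # consecutive 4-byte words
stride_128b = conflict_degree([128 * t for t in warp])     # every thread hits bank 0
unit_stride_1b = conflict_degree([t for t in warp])        # consecutive 1-byte elements

print(unit_stride_4b, stride_128b, unit_stride_1b)
```

In this model a layout that places each thread's element 128 bytes apart serializes the whole warp, while contiguous accesses do not conflict. Whether the fp8 kernel actually hits such a pattern is what the profiler counters need to show.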
To confirm this, I collected counters for shared memory load and store bank conflicts for bf16 vs fp8. I observed orders of magnitude more conflicts in fp8 than in bf16: stores go from roughly 0.1M to roughly 58.8M conflicts per kernel launch, and loads go from 0 to several thousand:
Load and store conflicts bf16
triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 0
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 111,863
-------------------------------------------------------- ----------- ------------
triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 0
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 104,116
-------------------------------------------------------- ----------- ------------
triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 0
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 114,396
-------------------------------------------------------- ----------- ------------
triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 0
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 113,613
-------------------------------------------------------- ----------- ------------
triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 0
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 106,008
-------------------------------------------------------- ----------- ------------
triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 0
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 102,859
-------------------------------------------------------- ----------- ------------
triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 0
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 101,981
-------------------------------------------------------- ----------- ------------
triton_tem_fused_2 (8, 256, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 0
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 104,583
-------------------------------------------------------- ----------- ------------
Load and store conflicts fp8
triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 6,467
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 58,782,390
-------------------------------------------------------- ----------- ------------
triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 5,698
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 58,771,364
-------------------------------------------------------- ----------- ------------
triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 6,234
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 58,783,926
-------------------------------------------------------- ----------- ------------
triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 5,518
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 58,800,274
-------------------------------------------------------- ----------- ------------
triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 7,216
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 58,776,341
-------------------------------------------------------- ----------- ------------
triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 7,586
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 58,750,044
-------------------------------------------------------- ----------- ------------
triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 5,236
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 58,797,745
-------------------------------------------------------- ----------- ------------
triton_tem_fused__to_copy_mul_2 (8, 256, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 9.0
Section: Command line profiler metrics
-------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------------------------------- ----------- ------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 6,156
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 58,800,346
-------------------------------------------------------- ----------- ------------
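The per-launch counters above can be summarized as follows. The values are copied from the eight bf16 and eight fp8 tables; the list names are mine.

```python
from statistics import mean

# l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum per kernel launch.
bf16_st = [111_863, 104_116, 114_396, 113_613, 106_008, 102_859, 101_981, 104_583]
fp8_st = [58_782_390, 58_771_364, 58_783_926, 58_800_274,
          58_776_341, 58_750_044, 58_797_745, 58_800_346]

# l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum per kernel launch.
bf16_ld = [0] * 8
fp8_ld = [6_467, 5_698, 6_234, 5_518, 7_216, 7_586, 5_236, 6_156]

bf16_st_mean = mean(bf16_st)
fp8_st_mean = mean(fp8_st)
st_ratio = fp8_st_mean / bf16_st_mean

print(f"mean store conflicts bf16: {bf16_st_mean:,.0f}")
print(f"mean store conflicts fp8:  {fp8_st_mean:,.0f}")
print(f"fp8 / bf16 store conflicts: {st_ratio:.0f}x")
print(f"mean load conflicts fp8:   {mean(fp8_ld):,.0f} (bf16: {mean(bf16_ld):,.0f})")
```

The mean fp8 store conflict count is close to the 58,720,256 excessive wavefronts reported in the NCU warning, which suggests the two numbers are measuring the same underlying problem.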
cc @chauhang @penguinwu @bertmaher @int3 @davidberard98 @nmacchioni @chenyang78 @embg @peterbell10 @aakhundov @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng