Investigate FlexAttention performance degradation on low precision inputs #147336
Comments
cc @davidberard98 on the bank conflicts and probably extra layout transformations
@davidberard98 I tested with pytorch 2.7, which comes bundled with the newer triton 3.3.0, but the issue is still present. Any chance you have capacity to take a look now? It seems like a bug in how triton swizzles fp8 data, perhaps.
I believe the issue is that we need the V tensor to be transposed in order to do the second dot op efficiently! Reasoning: unlike fp16/bf16, fp8 WGMMA requires the B matrix fragment to be transposed in shared memory before executing the WGMMA instruction. When the V tensor is not transposed in global memory, this breaks pipelining: Triton can't issue async copies directly into shared memory (async copies need to be at least 4 bytes wide per thread); instead, it loads into registers and then does single-byte stores into shared memory, which cause a lot of bank conflicts.
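As a quick sanity check on that 4-byte constraint (an illustration, not part of the discussion above): fp8 elements are 1 byte each, so a 4-byte-per-thread async copy needs 4 contiguous fp8 elements, versus 2 for bf16.

```python
import torch

# Element sizes determine how many contiguous elements a 4-byte cp.async
# copy needs per thread.
for dt in (torch.bfloat16, torch.float8_e4m3fn):
    nbytes = torch.tensor([], dtype=dt).element_size()
    print(f"{dt}: {nbytes} byte(s)/element -> "
          f"{4 // nbytes} contiguous elements per 4-byte async copy")
```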
That makes sense, great find! I hadn't considered this, since in pytorch this constraint is enforced for fp8 GEMMs, and an error is thrown if the B tensor isn't in column-major memory layout. I guess triton is different in that it isn't a strict requirement, just non-optimal. I will look into enforcing a column-major memory layout for the V tensor here; I should be able to figure something out. Thanks for taking a look at this.
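For illustration, a minimal sketch of one way to enforce that layout at the user level; the helper name and shapes are hypothetical, and this is not necessarily how the eventual fix is implemented:

```python
import torch

def ensure_column_major_2d(v: torch.Tensor) -> torch.Tensor:
    """Return v with its last two dims laid out column-major in memory.

    The tensor is column-major in its last two dims when stride(-2) == 1,
    i.e. elements within a column sit contiguously. The transpose ->
    contiguous() -> transpose round trip rewrites the storage without
    changing the logical shape or values.
    """
    if v.stride(-2) == 1:  # already column-major in the last two dims
        return v
    return v.transpose(-2, -1).contiguous().transpose(-2, -1)

# Hypothetical V tensor: (batch, heads, seq_len, head_dim) in fp8.
v = torch.randn(4, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)
v = v.to(torch.float8_e4m3fn)
v_cm = ensure_column_major_2d(v)
assert v_cm.shape == v.shape and v_cm.stride(-2) == 1
```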
…o avoid perf degradation (#153357)

Fixes #147336

## Context
NCU analysis of the fp8 flex attention perf issue in #147336 showed an unexpected increase in shared memory access bank conflicts when loading the V tensor from HBM to SRAM. Bringing this to the attention of triton developer @davidberard98, he identified the memory layout of the tensor in HBM as the cause of the non-pipelined loads into SRAM and hence the slowdown.

To summarize: in flex attention, when performing the FP8 GEMM `softmax_scores @ V`, the right operand V must be in column-major memory layout. However, the `tl.load` of V blocks from HBM to SRAM cannot be pipelined if the V tensor isn't already column-major in HBM, leading to substantial performance degradation. This is because triton does not perform async copies with the `cp.async` PTX instruction if the number of contiguous bytes is less than 4 (see [here](https://github.com/triton-lang/triton/blob/81f93f2c8ec7d20a1f8184def767edeaebeb6812/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp#L403)). I.e., when loading 4 bytes of contiguous data from a tensor stored row-major in HBM, we have to perform 4 separate non-contiguous writes to SRAM to place those bytes in their new locations in the col-major layout in SRAM. Thus the load is not a candidate for pipelining with cp.async; it instead moves data to registers and then performs a series of single-byte stores.

## Fix summary
- To fix this, we enforce memory layouts for Q, K, V in FlexAttention when fp8 is being used, ensuring they each exist in HBM in the memory layout needed to facilitate pipelined loads into SRAM ahead of the FP8 GEMMs.

## Benchmarks
Rerunning the repro, the fp8 runtime is reduced from 120% of the bf16 runtime to 76%.

Before fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 19:07:33,402 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 19:07:35,885 - flex_bench - INFO - bf16: 424.87228804347734 us
2025-05-11 19:07:35,893 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 19:07:37,319 - flex_bench - INFO - fp8e4m3: 515.714000000001 us
```

After fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 17:34:38,223 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 17:34:41,157 - flex_bench - INFO - bf16: 423.4662032967036 us
2025-05-11 17:34:41,167 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 17:34:42,917 - flex_bench - INFO - fp8e4m3: 326.3694803493453 us
```

Pull Request resolved: #153357
Approved by: https://github.com/ngimel, https://github.com/davidberard98
Creating this issue to track my work investigating the root cause of the unexpected slowdowns observed in flex attention with low-precision input tensors.
TL;DR
The current investigation points to a large increase in shared memory access bank conflicts as the root cause. Evidence so far suggests the loading of fp8 V blocks into SRAM is the problem.
Repro script
As a first step, I wrote this repro script, which runs benchmarks and optionally produces traces for bf16 and fp8 dtypes.
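For context, here is a minimal sketch of the kind of benchmark the script runs. The shapes, iteration counts, and CUDA-event timing are assumptions, and the actual profile_flex.py may differ; it also assumes fp8 Q/K/V are passed directly to compiled flex_attention, as in the repro.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 4, 16, 4096, 64  # hypothetical shapes

def bench(dtype, iters=10):
    # fp8 tensors are created by casting from bf16, since randn cannot
    # produce fp8 directly.
    q, k, v = (
        torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16).to(dtype)
        for _ in range(3)
    )
    compiled = torch.compile(flex_attention)
    for _ in range(3):  # warmup (includes compilation)
        compiled(q, k, v)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        compiled(q, k, v)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters * 1e3  # microseconds per call

print(f"bf16:    {bench(torch.bfloat16):.1f} us")
print(f"fp8e4m3: {bench(torch.float8_e4m3fn):.1f} us")
```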
Benchmark
Initial benchmarks show the flex attention forward pass takes roughly 1.39x longer with fp8 inputs than with bf16 inputs.
Triton kernel analysis
The main difference between the triton kernels generated by inductor for "compiled_flex" and "compiled_scale_flex" is the presence of the following lines of code, which implement the score mod func. Nothing here looks problematic to me.
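For reference, at the Python level the score mod in question is just a pointwise scaling of the attention scores; below is a sketch of what such a score_mod looks like with the FlexAttention API (the scale value and shapes are assumptions):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

scale = 0.125  # hypothetical scale factor

def scale_mod(score, b, h, q_idx, kv_idx):
    # Pointwise modification applied to each attention score before softmax;
    # this is the logic that shows up as extra lines in the generated kernel.
    return score * scale

q, k, v = (
    torch.randn(4, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
)
compiled_scale_flex = torch.compile(flex_attention)
out = compiled_scale_flex(q, k, v, score_mod=scale_mod)
```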
NCU
We can use `ncu` to analyze the specific kernel which implements flex attention:

Speed of light bf16

Speed of light fp8
Uncoalesced shared memory access
Importantly, in the NCU output for fp8 we get a warning about uncoalesced shared memory accesses causing excessive wavefronts. It seems likely this is related to the observed slowdown:
Next I generated some profiles for bf16 and fp8 to analyze in the NCU UI:
TORCH_LOGS="output_code" TORCH_LOGS_OUT="compile_logs/fp8_log.txt" ncu --set detailed -k regex:triton_tem_.* -o profiles/fp8-prof python3 profile_flex.py --fp8
Here I also observed that the fp8 profile has uncoalesced shared memory access warnings which are not present in the bf16 profile:
Diving deeper, we can see the exact line of triton code where this is occurring:
Looking at the sampling counts, we can see the majority are flagged as "short scoreboard." Per the NVIDIA docs, this is usually related to bank conflicts in shared memory load/store operations.
To confirm this, I collected metric counts measuring the number of shared memory load/store bank conflicts for bf16 vs fp8. I observed orders of magnitude more conflicts with fp8 than with bf16, for both load and store operations:
Load and store conflicts bf16
Load and store conflicts fp8
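For reference, such per-kernel counts can be collected with an ncu metrics query along these lines (the metric names and exact invocation shown here are assumptions, mirroring the profiling command above):

ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum,l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum -k regex:triton_tem_.* python3 profile_flex.py --fp8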
cc @chauhang @penguinwu @bertmaher @int3 @davidberard98 @nmacchioni @chenyang78 @embg @peterbell10 @aakhundov @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng