This repository documents a 16-step optimization journey targeting NVIDIA's Ampere architecture (A100, RTX 3090/4090), ultimately achieving performance parity with the industry-standard implementation.
The final iteration (Kernel 16) achieves:
- 99.2% of the official Flash Attention 2 performance on A100.
- 102.9% of official performance on RTX 3090.
- Benchmarks measured at sequence length 4096.
The kernels are designed with a specific focus on high-performance deep learning primitives:
- Algorithm: Flash Attention 2 (Forward Pass, Non-Causal).
- Hardware: Optimized for CUDA Compute Capability 8.x (Ampere).
- Data Types: 16-bit (BF16/FP16) I/O with FP32 precision for softmax accumulation.
- Constraints: Head dimension of 128; sequence lengths divisible by block sizes (64-128).
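To make the data-type contract concrete, here is a minimal NumPy sketch of the reference computation these kernels implement: 16-bit inputs and outputs, with scores and softmax accumulated in FP32. This is illustrative only; the CUDA kernels tile and fuse this computation rather than materializing the full score matrix.

```python
import numpy as np

def attention_fp32_accum(q, k, v, scale=None):
    """Non-causal attention with 16-bit I/O and FP32 softmax accumulation.

    q, k, v: (seq_len, d_head) arrays stored in float16.
    Sketch of the numerics only; the kernels never build the full matrix.
    """
    scale = scale or 1.0 / np.sqrt(q.shape[-1])
    # Upcast scores to FP32 before the softmax, as the kernels do.
    s = (q.astype(np.float32) @ k.astype(np.float32).T) * scale
    s -= s.max(axis=-1, keepdims=True)   # subtract row max for stability
    p = np.exp(s)
    p /= p.sum(axis=-1, keepdims=True)
    o = p @ v.astype(np.float32)         # FP32 accumulation of the output
    return o.astype(np.float16)          # 16-bit output
```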
- `src/`: Source code for the final Kernel 16.
- `previous_kernels/`: Evolution of the project (Kernels 1–15).
- `py/`: Python packages for `flash_helpers` and benchmarking.
To build the flash_attention CUDA kernels and the accompanying Python utility suite:
```shell
# Build CUDA kernels
pip install --no-build-isolation .

# Build Python utilities and configuration helpers
pip install ./py
```
Testing requires an Ampere-class GPU. This script validates the kernels against the PyTorch reference implementation.
```shell
python py/flash_helpers/test/test.py
```
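The validation amounts to an elementwise error check of the kernel output against the PyTorch reference. A minimal sketch of such a check (the function names and tolerance here are illustrative, not the actual test harness):

```python
import numpy as np

def max_abs_err(out, ref):
    """Largest elementwise deviation between kernel and reference outputs,
    computed in FP64 so the comparison itself adds no rounding error."""
    return float(np.max(np.abs(out.astype(np.float64) - ref.astype(np.float64))))

# Illustrative tolerance for 16-bit outputs against an FP32 reference.
ATOL = 2e-2

def check(out, ref):
    err = max_abs_err(out, ref)
    assert err <= ATOL, f"max abs error {err:.4f} exceeds {ATOL}"
```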
Performance is reported as a percentage of the TFLOPs achieved by the official Flash Attention 2 implementation.
| Kernel Iteration | A100, seq_len = 4096 | A100, harm. mean* | RTX 3090, seq_len = 4096 | RTX 3090, harm. mean* |
|---|---|---|---|---|
| 0. Reference (TFLOPs) | 186.4 | 174.0 | 67.29 | 66.2 |
| 1. Base Implementation | 15.8% | 16.6% | 49.5% | 49.8% |
| 2. Swizzling | 72.6% | 72.4% | 98.3% | 98.6% |
| 3. Eagerly Loading K & V Blocks | 77.6% | 79.9% | 99.4% | 100.0% |
| 4. Interleaving On-Chip LD/ST with Computation | 77.6% | 80.0% | 100.0% | 100.4% |
| 5. Double Buffering Shared Memory to Register File Loads | 76.8% | 79.1% | 99.7% | 100.3% |
| 6. Improving FP32 Throughput | 78.1% | 80.4% | 99.9% | 100.4% |
| 7. Auto-Tuning | 80.3% | 82.3% | 101.5% | 101.8% |
| 8. Reducing IADD3, LOP3, and SHF Instructions | 87.8% | 88.9% | 101.7% | 101.2% |
| 9. Reducing IMAD.MOV.U32 and MOV Instructions | 95.3% | 96.3% | 97.5% | 97.4% |
| 10. Removing CS2R Instructions + Optimizing the Initial Softmax Iteration | 93.9% | 95.0% | 102.9% | 102.3% |
| 11. Encoded Swizzling from the RF to SMEM | 95.2% | 96.7% | 102.8% | 102.3% |
| 12. Miscellaneous Code Changes | 95.3% | 97.0% | 102.8% | 102.3% |
| 13. Iterating Backwards | 97.6% | 98.8% | 101.5% | 101.2% |
| 14. Cache Configuration | 97.7% | 99.1% | 101.5% | 101.2% |
| 15. Tiling Along d_head | 97.9% | 99.5% | 101.5% | 101.3% |
| 16. Static GMEM Stride | 99.2% | 100.4% | 100.9% | 100.7% |
\* The harmonic mean is taken over sequence lengths 512, 1024, 2048, 4096, 8192, and 16384.
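The "harm. mean*" columns aggregate per-length TFLOPs with the harmonic mean, which weights the slower (lower-TFLOPs) sequence lengths more heavily than an arithmetic mean would. A sketch of the aggregation (the benchmark script may differ in detail):

```python
def harmonic_mean(values):
    """Harmonic mean of a list of positive throughput figures (TFLOPs)."""
    return len(values) / sum(1.0 / v for v in values)
```

Each percentage in the table is then the kernel's (harmonic-mean or single-length) TFLOPs divided by the corresponding figure from the reference row.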