Flash Attention from Scratch: Ampere Optimization

This repository documents a 16-step optimization journey targeting NVIDIA's Ampere architecture (A100, RTX 3090/4090), ultimately achieving performance parity with the industry-standard implementation.

🚀 Performance Highlights

The final iteration (Kernel 16) achieves:

  • 99.2% of the official Flash Attention 2 performance on A100.
  • 102.9% of official performance on RTX 3090.
  • Benchmarks measured at sequence length 4096.
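The percentages above are relative to the official kernel's measured TFLOP/s. For context, the forward-pass FLOP count of non-causal attention is conventionally estimated as 4·N²·d per head (two N×N×d matmuls, with softmax FLOPs ignored). A minimal sketch of that accounting — the function name and example shapes are illustrative, not from this repository:

```python
def attention_forward_tflops(batch, heads, seq_len, d_head, seconds):
    """Estimate achieved TFLOP/s for a non-causal attention forward pass.

    Uses the conventional 4 * N^2 * d FLOPs per head: 2*N*N*d for Q @ K^T
    plus 2*N*N*d for P @ V; softmax FLOPs are ignored, as is customary.
    """
    flops = 4 * batch * heads * seq_len ** 2 * d_head
    return flops / seconds / 1e12

# Illustrative call at the benchmarked shape (seq_len 4096, d_head 128);
# batch, heads, and runtime are made-up numbers, not measured data.
tflops = attention_forward_tflops(batch=8, heads=16, seq_len=4096, d_head=128, seconds=0.01)
```

Dividing such a figure by the official kernel's TFLOP/s at the same shape gives the percentages reported below.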

🛠️ Features & Scope

The kernels are designed with a specific focus on high-performance deep learning primitives:

  • Algorithm: Flash Attention 2 (Forward Pass, Non-Causal).
  • Hardware: Optimized for CUDA Compute Capability 8.x (Ampere).
  • Data Types: 16-bit (BF16/FP16) I/O with FP32 precision for softmax accumulation.
  • Constraints: Head dimension of 128; sequence lengths divisible by block sizes (64-128).
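The scope above can be summarized by a small reference implementation: 16-bit inputs and outputs, with the score matrix, softmax, and output accumulation kept in FP32. This NumPy sketch mirrors those constraints for a single head; it is an illustration of the algorithm's numerics, not the repository's kernel:

```python
import numpy as np

def reference_attention(q, k, v):
    """Non-causal attention forward pass: float16 I/O, float32 accumulation.

    q, k, v: float16 arrays of shape (seq_len, d_head).
    Returns a float16 array of shape (seq_len, d_head).
    """
    d_head = q.shape[-1]
    # Promote to FP32 before the matmuls, matching the FP32 softmax path.
    s = q.astype(np.float32) @ k.astype(np.float32).T / np.sqrt(d_head)
    s -= s.max(axis=-1, keepdims=True)   # numerically stable softmax
    p = np.exp(s)
    p /= p.sum(axis=-1, keepdims=True)
    o = p @ v.astype(np.float32)         # FP32 accumulation
    return o.astype(np.float16)          # 16-bit output

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 128)).astype(np.float16) for _ in range(3))
out = reference_attention(q, k, v)
```

The optimized kernels additionally assume head dimension 128 and sequence lengths divisible by the block sizes; the reference above has no such restrictions.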

📂 Project Structure

  • src/: Source code for the final Kernel 16.
  • previous_kernels/: Evolution of the project (Kernels 1–15).
  • py/: Python packages for flash_helpers and benchmarking.

📦 Installation

To build the flash_attention CUDA kernels and the accompanying Python utility suite:

# Build CUDA kernels
pip install --no-build-isolation .

# Build Python utilities and configuration helpers
pip install ./py

🧪 Testing

Testing requires an Ampere-class GPU. This script validates the kernels against the PyTorch reference implementation.

python py/flash_helpers/test/test.py

📈 Optimization Roadmap

Performance is expressed as a percentage of the official Flash Attention 2 TFLOP/s throughput.

| Kernel Iteration | A100, seq_len = 4096 | A100, harm. mean* | RTX 3090, seq_len = 4096 | RTX 3090, harm. mean* |
| --- | --- | --- | --- | --- |
| 0. Reference (TFLOPs) | 186.4 | 174.0 | 67.29 | 66.2 |
| 1. Base Implementation | 15.8% | 16.6% | 49.5% | 49.8% |
| 2. Swizzling | 72.6% | 72.4% | 98.3% | 98.6% |
| 3. Eagerly Loading K & V Blocks | 77.6% | 79.9% | 99.4% | 100.0% |
| 4. Interleaving On-Chip LD/ST with Computation | 77.6% | 80.0% | 100.0% | 100.4% |
| 5. Double Buffering Shared Memory to Register File Loads | 76.8% | 79.1% | 99.7% | 100.3% |
| 6. Improving FP32 Throughput | 78.1% | 80.4% | 99.9% | 100.4% |
| 7. Auto-Tuning | 80.3% | 82.3% | 101.5% | 101.8% |
| 8. Reducing IADD3, LOP3, and SHF Instructions | 87.8% | 88.9% | 101.7% | 101.2% |
| 9. Reducing IMAD.MOV.U32 and MOV Instructions | 95.3% | 96.3% | 97.5% | 97.4% |
| 10. Removing CSRZ Instructions + Optimizing Initial Softmax Iteration | 93.9% | 95.0% | 102.9% | 102.3% |
| 11. Encoded Swizzling from the RF to SMEM | 95.2% | 96.7% | 102.8% | 102.3% |
| 12. Miscellaneous Code Changes | 95.3% | 97.0% | 102.8% | 102.3% |
| 13. Iterating Backwards | 97.6% | 98.8% | 101.5% | 101.2% |
| 14. Cache Configuration | 97.7% | 99.1% | 101.5% | 101.2% |
| 15. Tiling along d_head | 97.9% | 99.5% | 101.5% | 101.3% |
| 16. Static GMEM Stride | 99.2% | 100.4% | 100.9% | 100.7% |

*The harmonic mean is taken over sequence lengths 512, 1024, 2048, 4096, 8192, and 16384.
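The harmonic mean is the appropriate average here because throughput figures are rates: it weights each sequence length by the time it takes rather than by its speed, so one fast configuration cannot mask a slow one. A minimal sketch (the throughput values are illustrative, not measured data):

```python
def harmonic_mean(values):
    """Harmonic mean n / sum(1/x): suited to averaging rates like TFLOP/s."""
    return len(values) / sum(1.0 / v for v in values)

# Hypothetical TFLOP/s at sequence lengths 512 ... 16384 (illustrative only)
tflops_by_seq_len = [150.0, 165.0, 178.0, 186.4, 188.0, 189.0]
hmean = harmonic_mean(tflops_by_seq_len)
```

Note that the harmonic mean is always at most the arithmetic mean, so a low outlier at any sequence length pulls the aggregate down sharply.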
