Implement fast exp for AVX2 and AVX512 for the flash attention #151441
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151441
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 5dfa3c6 with merge base 159e2f9.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
Thanks for your optimization! We need to do more validation before the PR lands, such as the three dynamo suites and the LZ models, which I will follow up on. cc @mingfeima @leslie-fang-intel
- Implements "Fast Exponential Computation on SIMD Architectures" by A. Cristiano I. Malossi, Yves Ineichen, Costas Bekas, and Alessandro Curioni (see the sketch after the benchmark tables below).
- AVX2 and AVX512, float only; up to 20% faster for mixed-precision flash attention than the current implementation.
- The other types keep the legacy implementation.

Precision: 1 ULP, only valid in hybrid mode fp32 -> fp16, due to the cast during the store operation in flash attention.

Benchmark: Xeon 6972P, results in TOPS, Python forward pass of flash attention.

num heads 16, head dimension 64

| Seq. L. | PT   | fexp |
|---------|------|------|
| 512     | 0.8  | 1.3  |
| 1024    | 1.7  | 1.7  |
| 2048    | 6    | 6.1  |
| 4096    | 16   | 16.8 |
| 8192    | 30.6 | 32.3 |
| 16384   | 40   | 40.8 |
| 32768   | 44.9 | 51.4 |
| 65536   | 45.8 | 54.4 |

num heads 16, head dimension 128

| Seq. L. | PT   | fexp |
|---------|------|------|
| 512     | 2.5  | 4.1  |
| 1024    | 3.3  | 4    |
| 2048    | 11.4 | 10.5 |
| 4096    | 27.4 | 28.4 |
| 8192    | 44.4 | 46   |
| 16384   | 64.2 | 68.1 |
| 32768   | 77.8 | 83   |
| 65536   | 82.1 | 88.1 |

num heads 16, head dimension 256

| Seq. L. | PT    | fexp  |
|---------|-------|-------|
| 512     | 1.7   | 3.4   |
| 1024    | 4.2   | 6.5   |
| 2048    | 14.6  | 16.1  |
| 4096    | 30.1  | 31.1  |
| 8192    | 60    | 62    |
| 16384   | 83.3  | 87.3  |
| 32768   | 98.7  | 106   |
| 65536   | 102.2 | 107.1 |
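To make the approach concrete for reviewers, here is a minimal, self-contained sketch of the bit trick behind the paper (an editor's illustration, not the code in this PR): exp(x) = 2^(x/ln 2), and an IEEE-754 float's value is essentially 2 raised to its exponent field, so scaling x into the exponent bits and reinterpreting the integer as a float yields an approximate exp. The names `fast_exp_avx2`/`fast_exp_avx512` and the constants are illustrative; this crude form has a few percent relative error, and the paper (and presumably this PR) add a mantissa correction to reach the precision quoted above.

```cpp
// Minimal sketch of the fast-exp bit trick (illustrative only, not the PR code).
// exp(x) = 2^(x / ln 2); an IEEE-754 float with integer bit pattern i represents
// roughly 2^((i - 127*2^23) / 2^23), so computing i = x * (2^23 / ln 2) + bias
// and reinterpreting it as a float approximates exp(x).
// Build: g++ -O2 -mavx2 -mfma fast_exp_sketch.cpp  (add -mavx512f for the 512-bit variant)
#include <immintrin.h>
#include <cmath>
#include <cstdio>

constexpr float kScale = 12102203.0f;    // ~ 2^23 / ln(2)
constexpr float kBias  = 1064866805.0f;  // ~ 127 * 2^23 minus an error-balancing offset

static inline __m256 fast_exp_avx2(__m256 x) {
  // i = x * kScale + kBias, rounded to int32, then bit-cast back to float.
  __m256 t = _mm256_fmadd_ps(x, _mm256_set1_ps(kScale), _mm256_set1_ps(kBias));
  return _mm256_castsi256_ps(_mm256_cvtps_epi32(t));
}

#if defined(__AVX512F__)
static inline __m512 fast_exp_avx512(__m512 x) {
  __m512 t = _mm512_fmadd_ps(x, _mm512_set1_ps(kScale), _mm512_set1_ps(kBias));
  return _mm512_castsi512_ps(_mm512_cvtps_epi32(t));
}
#endif

int main() {
  // Spot check against std::exp on inputs typical of softmax after the running-max
  // subtraction in flash attention (arguments are <= 0).
  alignas(32) float in[8] = {0.f, -0.5f, -1.f, -2.f, -4.f, -8.f, -16.f, -32.f};
  alignas(32) float out[8];
  _mm256_store_ps(out, fast_exp_avx2(_mm256_load_ps(in)));
  for (int i = 0; i < 8; ++i) {
    std::printf("x=%6.1f  fast=%.6e  exact=%.6e\n", in[i], out[i], std::exp(in[i]));
  }
  return 0;
}
```

The printed comparison shows the bare trick tracking `std::exp` to a few percent; the production kernel tightens this with a polynomial refinement of the mantissa, which is where the precision/speed control mentioned in the paper comes from.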
- Retriggered the CI by amending the commit message, due to the modification of 09e8ff9 (4 days ago).
The validation for the dynamo suites is done, and we do not see any obvious accuracy/perf change for any dtype, including bf16/fp16/fp32. Next, we will check accuracy/perf on Stable Diffusion, Llama3.1-8B, and ViT. cc @timocafe
The model validation is still WIP, due to limited machine availability and the large dataset needed for the accuracy checks.
The model validation result is ready.
Thanks for your work! @timocafe cc @mingfeima
Implement fexp for avx2 and avx512
Cristiano et al. propose a clever exp that exploits the IEEE-754 representation with fine control of the precision, which is especially useful for the mixed-precision computation in flash attention.

Reference: A. Cristiano I. Malossi, Yves Ineichen, Costas Bekas, and Alessandro Curioni, "Fast Exponential Computation on SIMD Architectures".

AVX2 and AVX512, float only; up to 20% faster for mixed-precision flash attention than the current implementation.
Precision
1 ULP, only valid in hybrid mode fp32 -> fp16, due to the cast during the store operation in flash attention.
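A rough sanity check of why the 1 ULP claim is tied to the fp16 store (editor's reasoning, not part of the PR): fp16 keeps only 10 explicit mantissa bits, so as long as the fp32 fast-exp result has relative error below roughly the fp16 unit roundoff of $2^{-11}$ (an assumption about the tuned approximation, not a measured figure), the downcast absorbs the approximation error:

$$
\left|\frac{\widetilde{\exp}(x)-\exp(x)}{\exp(x)}\right| \lesssim 2^{-11}
\;\Longrightarrow\;
\bigl|\mathrm{fp16}\bigl(\widetilde{\exp}(x)\bigr)-\mathrm{fp16}\bigl(\exp(x)\bigr)\bigr| \le 1~\mathrm{ULP}_{\mathrm{fp16}}.
$$

With a pure fp32 output there is no such cast, so the same approximation would not meet a 1 ULP bound in fp32 terms.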
Benchmark
Machine: Xeon 6972P, results in TOPS, Python forward pass of flash attention, num heads 16, head dimensions 64, 128, and 256 (see the tables in the PR description above).
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168