Implement fast exp for AVX2 and AVX512 for the flash attention by timocafe · Pull Request #151441 · pytorch/pytorch · GitHub

Implement fast exp for AVX2 and AVX512 for the flash attention #151441


Open

timocafe wants to merge 2 commits into main

Conversation

@timocafe commented Apr 16, 2025

Implement fexp for AVX2 and AVX512

Cristiano et al. propose a clever exp that exploits the IEEE-754 representation with fine control of the precision, which is especially useful for the mixed-precision computation of flash attention (a sketch of the trick follows the list below).

  • Implements "Fast Exponential Computation on SIMD Architectures" by A. Cristiano I. Malossi, Yves Ineichen, Costas Bekas, and Alessandro Curioni.
  • AVX2 and AVX512, float only; up to 20% faster than the current implementation for mixed-precision flash attention.
  • All other types fall back to the legacy implementation.
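
As background, here is a minimal sketch of the exponent-bit trick behind this kind of fast exp, in the spirit of Malossi et al.: write exp(x) = 2^t with t = x·log2(e), put the integer part of t directly into the IEEE-754 exponent field, and correct the fractional part with a small polynomial. The constants, the degree-2 polynomial, and the function names below are illustrative assumptions, not the exact kernel added by this PR; the PR's AVX2/AVX512 code vectorizes the same kind of idea with intrinsics, and its constants and polynomial may differ.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <immintrin.h>  // AVX2 + FMA intrinsics; build with -mavx2 -mfma

// Scalar sketch: exp(x) = 2^n * 2^f, with the integer n written straight
// into the IEEE-754 exponent bits and 2^f approximated by a polynomial.
// No clamping of out-of-range inputs; illustration only.
static inline float fast_exp_sketch(float x) {
  const float log2e = 1.442695041f;           // 1 / ln(2)
  float t = x * log2e;                        // exp(x) = 2^t
  float n = std::floor(t);                    // integer part
  float f = t - n;                            // fractional part in [0, 1)
  // Coarse degree-2 approximation of 2^f, exact at f = 0 and f = 1.
  float p = 1.0f + f * (0.6565f + f * 0.3435f);
  // Build 2^n by writing (n + 127) into the exponent field.
  int32_t bits = (static_cast<int32_t>(n) + 127) << 23;
  float two_n;
  std::memcpy(&two_n, &bits, sizeof(two_n));
  return two_n * p;
}

// AVX2 sketch of the same computation, eight floats at a time.
// An AVX512 variant would be analogous with _mm512_* intrinsics.
static inline __m256 fast_exp_avx2_sketch(__m256 x) {
  const __m256 log2e = _mm256_set1_ps(1.442695041f);
  __m256 t = _mm256_mul_ps(x, log2e);
  __m256 n = _mm256_floor_ps(t);
  __m256 f = _mm256_sub_ps(t, n);
  // p = 1 + f * (0.6565 + f * 0.3435)
  __m256 p = _mm256_fmadd_ps(_mm256_set1_ps(0.3435f), f, _mm256_set1_ps(0.6565f));
  p = _mm256_fmadd_ps(p, f, _mm256_set1_ps(1.0f));
  // 2^n: shift (n + 127) into the exponent field of each lane.
  __m256i e = _mm256_slli_epi32(
      _mm256_add_epi32(_mm256_cvtps_epi32(n), _mm256_set1_epi32(127)), 23);
  return _mm256_mul_ps(_mm256_castsi256_ps(e), p);
}

int main() {
  float in[8] = {-2.0f, -1.0f, -0.5f, 0.0f, 0.5f, 1.0f, 2.0f, 3.0f};
  float out[8];
  _mm256_storeu_ps(out, fast_exp_avx2_sketch(_mm256_loadu_ps(in)));
  for (int i = 0; i < 8; ++i)
    std::printf("x=%5.2f  scalar=%10.6f  avx2=%10.6f  std=%10.6f\n",
                in[i], fast_exp_sketch(in[i]), out[i], std::exp(in[i]));
  return 0;
}
```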

Precision

1 ULP; this holds only in hybrid mode (fp32 -> fp16), because of the cast during the store operation in flash attention. An illustration follows below.
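
To make the precision note concrete, here is a small illustration (not part of the PR) of why a 1 ULP fp32 error is typically absorbed by the fp32 -> fp16 rounding on store: fp16 keeps only 10 mantissa bits, so neighbouring fp32 values usually round to the same fp16 bit pattern. The input value 0.73 is an arbitrary assumption.

```cpp
#include <cmath>
#include <cstdio>
#include <immintrin.h>  // _cvtss_sh requires F16C; build with -mf16c

int main() {
  float ref  = std::exp(0.73f);            // reference fp32 result
  float fast = std::nextafter(ref, 3.0f);  // pretend the fast exp is 1 ULP away

  // Round both fp32 values to fp16, as the flash attention store does.
  unsigned short h_ref  = _cvtss_sh(ref,  _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
  unsigned short h_fast = _cvtss_sh(fast, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);

  std::printf("fp32 bits differ by 1 ULP, fp16 bits equal: %s\n",
              h_ref == h_fast ? "yes" : "no");  // "yes" for most inputs
  return 0;
}
```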

Benchmark

Machine: Xeon 6972P. Results in TOPs (tera-operations per second), from a Python forward-pass flash attention benchmark; a sketch of the measurement loop follows the tables below.

num_heads = 16, head dimension = 64

| Seq. len. | PT   | fexp |
|-----------|------|------|
| 512       | 0.8  | 1.3  |
| 1024      | 1.7  | 1.7  |
| 2048      | 6    | 6.1  |
| 4096      | 16   | 16.8 |
| 8192      | 30.6 | 32.3 |
| 16384     | 40   | 40.8 |
| 32768     | 44.9 | 51.4 |
| 65536     | 45.8 | 54.4 |

num_heads = 16, head dimension = 128

| Seq. len. | PT   | fexp |
|-----------|------|------|
| 512       | 2.5  | 4.1  |
| 1024      | 3.3  | 4    |
| 2048      | 11.4 | 10.5 |
| 4096      | 27.4 | 28.4 |
| 8192      | 44.4 | 46   |
| 16384     | 64.2 | 68.1 |
| 32768     | 77.8 | 83   |
| 65536     | 82.1 | 88.1 |

num_heads = 16, head dimension = 256

| Seq. len. | PT    | fexp  |
|-----------|-------|-------|
| 512       | 1.7   | 3.4   |
| 1024      | 4.2   | 6.5   |
| 2048      | 14.6  | 16.1  |
| 4096      | 30.1  | 31.1  |
| 8192      | 60    | 62    |
| 16384     | 83.3  | 87.3  |
| 32768     | 98.7  | 106   |
| 65536     | 102.2 | 107.1 |
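
The Python benchmark script itself is not shown in the PR. Purely to illustrate what such a forward-pass measurement looks like, here is a hedged libtorch (C++) sketch using `at::scaled_dot_product_attention`, the C++ counterpart of the Python API. The batch size of 1, the iteration count, and the 4·B·H·S²·D operation count used to convert time into TOPs are assumptions about the methodology, not facts from the PR.

```cpp
// Link against libtorch (torch, torch_cpu, c10); this mirrors a Python
// call to torch.nn.functional.scaled_dot_product_attention on CPU.
#include <torch/torch.h>
#include <chrono>
#include <cstdio>

int main() {
  torch::NoGradGuard no_grad;
  const int64_t B = 1, H = 16, D = 64;  // head count / head dim as in the first table
  for (int64_t S : {512, 1024, 2048, 4096}) {
    auto opts = torch::TensorOptions().dtype(torch::kHalf);  // mixed-precision path
    auto q = torch::randn({B, H, S, D}, opts);
    auto k = torch::randn({B, H, S, D}, opts);
    auto v = torch::randn({B, H, S, D}, opts);

    auto out = at::scaled_dot_product_attention(q, k, v);  // warm-up
    const int iters = 10;
    auto t0 = std::chrono::steady_clock::now();
    for (int i = 0; i < iters; ++i)
      out = at::scaled_dot_product_attention(q, k, v);
    auto t1 = std::chrono::steady_clock::now();
    double sec = std::chrono::duration<double>(t1 - t0).count() / iters;

    // Rough op count: two S x S x D matmuls per head, 2 ops per multiply-add.
    double ops = 4.0 * double(B) * double(H) * double(S) * double(S) * double(D);
    std::printf("S=%6lld  %.2f TOPs\n", (long long)S, ops / sec / 1e12);
  }
  return 0;
}
```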

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168

pytorch-bot bot commented Apr 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151441

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5dfa3c6 with merge base 159e2f9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Apr 16, 2025
@timocafe (Author)

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Apr 16, 2025
@drisspg added the module: sdpa All things related to torch.nn.functional.scaled_dot_product_attention label Apr 16, 2025
@Valentine233 (Collaborator)

Thanks for your optimization! We need to do more validation before the PR lands, such as the three dynamo suites and the LZ models, which I will follow up on. cc @mingfeima @leslie-fang-intel

@albanD albanD requested a review from drisspg April 17, 2025 13:39
@albanD added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Apr 17, 2025
@mingfeima mingfeima moved this to In Progress in PyTorch Intel Apr 18, 2025
- Retrigger the CI by amending the commit message, due to the modification of 09e8ff9 four days ago.
@Valentine233 (Collaborator) commented Apr 24, 2025

The validation for the dynamo suites is done, and we do not see any obvious accuracy/perf changes for any dtype, including bf16/fp16/fp32. Next, we will check the accuracy/perf on Stable Diffusion, Llama3.1-8b and ViT. cc @timocafe

@Valentine233 (Collaborator)
The model validation is still WIP, due to limited machine availability and the large dataset required for the accuracy runs.

@Valentine233 (Collaborator) commented May 14, 2025

The model validation result is ready.

  • Accuracy: all good.
  • Performance: for Stable Diffusion v2.1, we see an improvement of 5% for BF16 and 4% for FP16. No other obvious impacts.

Thanks for your work! @timocafe cc @mingfeima

Labels
module: cpu CPU specific problem (e.g., perf, algorithm) · module: sdpa All things related to torch.nn.functional.scaled_dot_product_attention · open source · topic: not user facing topic category · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Projects
Status: In Progress
6 participants