[inductor] [cpp] Support vectorization for score and mask in FlexAttention CPU #143638
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/143638.
As of commit f7b901d with merge base 66fb10f: ✅ you can merge normally (3 unrelated failures: broken-trunk jobs that also failed on the merge base, plus one job marked unstable due to flakiness on trunk).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @drisspg, could you please review this PR when you get time?
```python
layout=FixedLayout(
    device,
    dtype,
    size if size else [],
```
Since we create the subgraphs using scalars, we don't have any real size or stride information. Why this change?
Before this PR, we created the subgraphs using scalars. In this PR, in order to generate vectorized code, we create subgraphs using tensors, and we therefore changed this function:
https://github.com/pytorch/pytorch/blob/gh/chunyuan-w/3/head/torch/_inductor/kernel/flex_attention.py#L918-L922
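A minimal sketch of the distinction under discussion (the `make_placeholder` helper is hypothetical; the real change is in the `FixedLayout` construction linked above):

```python
import torch

# Hypothetical helper mirroring the placeholder-creation change discussed
# above; the actual code builds an Inductor FixedLayout, not a tensor.
def make_placeholder(dtype, size=None):
    # Before this PR: subgraph inputs were always 0-dim scalars (size == []).
    # With this PR: a real size can be passed through so the generated
    # kernel can operate on vectors instead of one scalar at a time.
    return torch.empty(size if size else [], dtype=dtype)

scalar_ph = make_placeholder(torch.float32)        # shape: []
vector_ph = make_placeholder(torch.float32, [16])  # shape: [16]
```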
Hi @drisspg, thanks for your reply. I have addressed all the comments. Could you please re-review to see whether it's good to land or any other changes are needed?
Sorry, I went on PTO for a few days and thought I had already approved. Thanks for the ping.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Stack from ghstack (oldest at bottom):
Description
This PR generates a vectorized kernel for the score and mask subgraphs in FlexAttention on CPU.
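For context, FlexAttention customizes attention through small user-defined callables; the functions below are illustrative examples (not taken from this PR) of the kind of score and mask subgraphs being vectorized:

```python
import torch

# Illustrative FlexAttention callables; these specific functions are
# examples of the pattern, not code from the PR.

def rel_bias_score(score, b, h, q_idx, kv_idx):
    # score_mod: elementwise transform applied to each attention score.
    return score + (q_idx - kv_idx)

def causal_mask(b, h, q_idx, kv_idx):
    # mask_mod: keep only key positions at or before the query position.
    return q_idx >= kv_idx
```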
Modification
The main changes include:
The original mask graph:
Benchmark
For `q`, `k`, `v` of shape `[1, 32, 1024, 128]`, using 40 CPU cores, we observe over 20x speedup compared with the non-vectorized version for both `is_causal=False` and `is_causal=True`.
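The causal behavior benchmarked above can be expressed as a `score_mod`; the snippet below (names of my choosing, and a tiny grid instead of the benchmark's sequence length) only illustrates the masking semantics, not the benchmark harness itself:

```python
import torch

# Sketch of a causal score_mod of the kind benchmarked above. The PR
# measures full FlexAttention runs on [1, 32, 1024, 128] inputs; here a
# tiny 4x4 score grid is enough to show what the subgraph computes.

def causal_score_mod(score, b, h, q_idx, kv_idx):
    # Mimics is_causal=True: future key positions get -inf before softmax.
    return torch.where(q_idx >= kv_idx, score, torch.tensor(float("-inf")))

S = 4
q_idx = torch.arange(S).view(S, 1)   # broadcasts over rows (queries)
kv_idx = torch.arange(S).view(1, S)  # broadcasts over columns (keys)
masked = causal_score_mod(torch.zeros(S, S), 0, 0, q_idx, kv_idx)
# masked now has -inf strictly above the diagonal.
```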
Test plan

The existing FlexAttention UTs (`test/inductor/test_flex_attention.py`, `test/inductor/test_flex_decoding.py`) cover the change in this PR.

Output code
Code before this PR was scalar:
Code after this PR is vectorized:
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov