add sbgemv dispatch in torch cpu flash attention #151108
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151108
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEV: there is 1 currently active SEV; if your PR is affected, please view it below. ✅ No Failures as of commit c52849f with merge base d759a51. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@drisspg do you know who should look at this?
@pytorchbot label "ciflow/linux-aarch64"
This PR should stay pending until OpenBLAS 0.3.30 is released.
@pytorchbot label "topic: not user facing"
Changes LGTM!
Yes, this is what I'd like to do, but could someone point out the proper place to add such UTs?
Check test_linalg.py; I think we can add a relevant case there (if not already present).
+1 on the UT.
Can we make the description clearer and add a comment explaining the heuristic?
Also, as discussed above, consider adding a UT.
@parametrize("m", [32, 35, 36, 40, 64, 128]) | ||
@parametrize("k", [32, 35, 36, 40, 64, 128]) | ||
# NOTE: This is intended to cover sbgemv_ testcase in CPUBlas.cpp. | ||
def test_lowprecision_gemv_cpu(self, device, dtype, m, k): |
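The excerpt above shows only the decorators and the test signature. Purely as a hedged sketch of what such a check might verify (the actual body is not shown here; the helper name, tolerances, and transposed-case construction below are assumptions, not the PR's code):

```python
import torch

def _check_lowprecision_gemv(dtype, m, k, device="cpu"):
    # Non-transposed case: A(m, k) @ x(k), compared against a float32 reference.
    a = torch.randn(m, k, dtype=dtype, device=device)
    x = torch.randn(k, dtype=dtype, device=device)
    ref = torch.mv(a.float(), x.float()).to(dtype)
    torch.testing.assert_close(torch.mv(a, x), ref, rtol=1.6e-2, atol=1e-2)

    # Transposed case: A is the transpose of a (k, m) buffer, exercising the
    # other memory layout mentioned by the reviewer below.
    b = torch.randn(k, m, dtype=dtype, device=device)
    ref_t = torch.mv(b.t().float(), x.float()).to(dtype)
    torch.testing.assert_close(torch.mv(b.t(), x), ref_t, rtol=1.6e-2, atol=1e-2)
```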
Thanks! That's very thoughtful, I like that it covers both the transposed and non-transposed case. A minor comment.
Question: Would it make sense to split the transposed and non-transposed cases into separate tests for clarity and easier debugging?
Can you provide benchmarks of the before / after for a few relevant shapes?
An alternative approach is to enable this path in OpenBLAS. @taoye9 already has a PR for this. Linking it here for visibility:
Hi, sorry, we are keeping this PR pending for a while to further investigate which approach is best, i.e. inside OpenBLAS or inside PyTorch.
LGTM if it passes CI, but it would also be good to enable it directly for the torch.mv call as well.
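For context, a small example of the kind of direct call the reviewer refers to; bfloat16 is chosen because sbgemv is a bf16 kernel, and whether such a call currently reaches the new dispatch is not shown in this thread:

```python
import torch

# A bf16 matrix-vector product; the suggestion is that calls like this could
# also be routed through the sbgemv kernel. Shapes are arbitrary examples.
A = torch.randn(128, 64, dtype=torch.bfloat16)
x = torch.randn(64, dtype=torch.bfloat16)
y = torch.mv(A, x)
```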
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased from 3695c3d to c52849f.
Summary
This PR introduces a dispatch to the OpenBLAS sbgemv kernel inside the PyTorch CPU Flash Attention kernel when the query sequence length is 1.
Motivation
During the decoding phase of transformer models (e.g., autoregressive inference), the query tensor often has sequence length 1. Currently, this still dispatches A(m, k) * B(k, n) to the general sbgemm kernel, even though the operation is effectively a matrix-vector multiplication. This PR optimizes such cases by dispatching to sbgemv, which is better suited and shows measurable performance improvements.
Heuristic Consideration
Our heuristic ensures that the matmul is dispatched to sbgemv only when matrix A is multiplied by a vector B, which is the intended use case for GEMV operations. We also limit the dispatch to transb == NoTranspose, because when transb == Transpose the leading dimension (lda) might not be 1; in that case the sbgemv kernel has to handle non-contiguous memory, which performs poorly.
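The actual check lives in the C++ sources touched by this PR; purely as an illustration of the condition described above, with hypothetical names:

```python
def should_use_sbgemv(transb_is_transposed: bool, n: int) -> bool:
    # A(m, k) @ B(k, n) collapses to a matrix-vector product when n == 1.
    # Restrict to the non-transposed B layout: with a transposed B the vector
    # elements may be strided, and sbgemv handles non-contiguous memory poorly.
    return n == 1 and not transb_is_transposed
```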
Benchmark result
Benchmarked using torch.nn.functional.scaled_dot_product_attention on Neoverse™ V1.
Configuration:
OMP_NUM_THREADS=16
query shape: [1, 16, 1, 32]
key shape: [1, 16, 1500, 32]
value shape: [1, 16, 1500, 32]
Results: timing comparison of the sbgemm (baseline) path vs. the sbgemv (this PR) path.
Benchmark script
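The author's script is collapsed in the PR; below is a minimal sketch of a comparable benchmark, assuming bfloat16 inputs (implied by the sbgemm/sbgemv kernels) and torch.utils.benchmark for timing. It is not the exact script used for the numbers above.

```python
import torch
import torch.nn.functional as F
from torch.utils.benchmark import Timer

# Shapes from the configuration above; run with OMP_NUM_THREADS=16 set in the
# environment to match the reported setup.
q = torch.randn(1, 16, 1, 32, dtype=torch.bfloat16)     # decode-phase query (seq_len = 1)
k = torch.randn(1, 16, 1500, 32, dtype=torch.bfloat16)  # keys
v = torch.randn(1, 16, 1500, 32, dtype=torch.bfloat16)  # values

timer = Timer(
    stmt="F.scaled_dot_product_attention(q, k, v)",
    globals={"F": F, "q": q, "k": k, "v": v},
)
print(timer.blocked_autorange())
```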
cc @malfet @snadampal @milpuz01 @aditew01 @nikhil-arm @fadara01