add sbgemv dispatch in torch cpu flash attention by taoye9 · Pull Request #151108 · pytorch/pytorch · GitHub

add sbgemv dispatch in torch cpu flash attention #151108


Open · taoye9 wants to merge 2 commits into main from spda_gemv

Conversation

@taoye9 taoye9 commented Apr 11, 2025

Summary

This PR introduces a dispatch to the OpenBLAS sbgemv kernel in the PyTorch CPU Flash Attention kernel when the query sequence length is 1.

Motivation

During the decoding phase of transformer models (e.g., autoregressive inference), the query tensor often has sequence length = 1. Currently, this leads to dispatching A(m, k) * B(k, n) to the general sbgemm kernel even when the operation is effectively a matrix-vector multiplication. This PR optimizes such cases by dispatching to sbgemv, which is better suited and shows measurable performance improvements.

Heuristic Consideration

Our heuristic ensures that the matmul is dispatched to sbgemv only when matrix A is multiplied by a vector B, which is the intended use case for GEMV operations. We also limit the dispatch to transb == NoTranspose, because when transb == Transpose the leading dimension (lda) might not be 1, which would force the sbgemv kernel to operate on non-contiguous memory, where it performs poorly.
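As a rough illustration, the dispatch condition can be sketched as follows. This is hypothetical Python pseudocode, not the actual CPUBlas dispatch code, and all names are illustrative:

def should_dispatch_to_sbgemv(transb, n):
    # A(m, k) * B(k, n) degenerates to a matrix-vector product when n == 1,
    # e.g. a query of sequence length 1 during decoding.
    if n != 1:
        return False
    # Restrict to the non-transposed B case: with transb == Transpose the
    # vector's elements may not be contiguous in memory, and sbgemv performs
    # poorly on non-contiguous input.
    return transb == "NoTranspose"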

Benchmark result

Benchmarked using torch.nn.functional.scaled_dot_product_attention on Neoverse™ V1.

Configuration:

  • OMP_NUM_THREADS=16
  • Tensor shapes:
    • Query: [1, 16, 1, 32]
    • Key: [1, 16, 1500, 32]
    • Value: [1, 16, 1500, 32]
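With these shapes, each head's q · kᵀ multiply is (1 × 32) · (32 × 1500) and the subsequent attn · v multiply is (1 × 1500) · (1500 × 32), so both degenerate to vector-matrix products; this is exactly the case the new sbgemv dispatch targets.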

Results:

Kernel   Latency (µs)   Speedup
sbgemm   121.700        (baseline)
sbgemv   104.663        ~16%

Benchmark script

import torch
import time
import numpy as np
from torch.profiler import profile, record_function, ProfilerActivity

class SimpleAttentionModel(torch.nn.Module):
    def __init__(self, query, key, value):
        super().__init__()
        self.query = query
        self.key = key
        self.value = value

    def forward(self, attn_mask=None):
        return torch.nn.functional.scaled_dot_product_attention(
            self.query,
            self.key,
            self.value,
            attn_mask=attn_mask)


# Benchmark mirroring the shapes used by BertSdpaSelfAttention.
def bench_sdpa(batch_size=1, num_attention_heads=16, sequence_length=142,
               query_sequence_length=142, hidden_size=1024, precision=torch.float32):
    with torch.no_grad():
        attention_head_size = hidden_size // num_attention_heads

        query = torch.rand(size=(batch_size, num_attention_heads, query_sequence_length, attention_head_size), dtype=precision)
        key = torch.rand(size=(batch_size, num_attention_heads, sequence_length, attention_head_size), dtype=precision)
        value = torch.rand(size=(batch_size, num_attention_heads, sequence_length, attention_head_size), dtype=precision)

        model = SimpleAttentionModel(query, key, value)
        model.eval()

        # Warm-up runs.
        for _ in range(100):
            model()

        # Timed runs; time_ns deltas are converted to microseconds.
        times = []
        n_iters = 10000
        for _ in range(n_iters):
            s = time.time_ns()
            model()
            times.append((time.time_ns() - s) / 1e3)

        print(f"Min Times = {np.min(times)} us")
        print(f"Mean Times = {np.mean(times)} us")


if __name__ == "__main__":
    batch_size = 1
    num_attention_heads = 16
    sequence_length = 1500
    query_sequence_length = 1
    hidden_size = 512

    print("BF16 mode:")
    with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
        with record_function("model_inference"):
            bench_sdpa(batch_size=batch_size,
                       num_attention_heads=num_attention_heads,
                       sequence_length=sequence_length,
                       query_sequence_length=query_sequence_length,
                       hidden_size=hidden_size,
                       precision=torch.bfloat16)
    profile_data = prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total")
    print(profile_data)

cc @malfet @snadampal @milpuz01 @aditew01 @nikhil-arm @fadara01

pytorch-bot bot commented Apr 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151108

Note: Links to docs will display an error until the docs builds have been completed.


✅ No Failures

As of commit c52849f with merge base d759a51:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla bot commented Apr 11, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

@soulitzer soulitzer requested a review from drisspg April 11, 2025 23:08
@soulitzer soulitzer added the "triaged" label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Apr 11, 2025
@soulitzer
Contributor

@drisspg do you know who should look at this?

@aditew01 aditew01 added the "module: arm" label (Related to ARM architecture builds of PyTorch. Includes Apple M1) Apr 14, 2025
@fadara01
Collaborator

@pytorchbot label "ciflow/linux-aarch64"

@pytorch-bot pytorch-bot bot added the "ciflow/linux-aarch64" label (linux aarch64 CI workflow) Apr 14, 2025
@taoye9
Author
taoye9 commented Apr 14, 2025

This PR should remain pending until OpenBLAS 0.3.30 is released.

@nikhil-arm nikhil-arm requested a review from aditew01 April 14, 2025 10:16
@fadara01
Collaborator

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the "topic: not user facing" label (topic category) Apr 14, 2025
@aditew01
Collaborator

Changes LGTM!
Can we make sure there's a UT in place for this, or that it will be covered by existing test cases?

@taoye9
Author
taoye9 commented Apr 14, 2025

Changes LGTM! Can we make sure there's a UT in place for this, or that it will be covered by existing test cases?

Yes, that's what I'd like to do, but could someone point out the proper place to add such UTs?

@aditew01
Collaborator
aditew01 commented Apr 14, 2025

Yes, that's what I'd like to do, but could someone point out the proper place to add such UTs?

Check test_linalg.py. I think we can add a relevant case there (if not already present).

@drisspg
Contributor
drisspg commented Apr 14, 2025

+1 on UT

Collaborator
@aditew01 aditew01 left a comment


Can we make the description clearer and add a comment explaining the heuristic?
Also, as discussed above, possibly add a UT.

@pytorch-bot pytorch-bot bot removed the "ciflow/linux-aarch64" label Apr 15, 2025
@parametrize("m", [32, 35, 36, 40, 64, 128])
@parametrize("k", [32, 35, 36, 40, 64, 128])
# NOTE: This is intended to cover sbgemv_ testcase in CPUBlas.cpp.
def test_lowprecision_gemv_cpu(self, device, dtype, m, k):
Collaborator


Thanks! That's very thoughtful, I like that it covers both the transposed and non-transposed case. A minor comment.

Question: Would it make sense to split the transposed and non-transposed cases into separate tests for clarity and easier debugging?
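For context, here is a minimal sketch of what the body of such a test might look like, assuming the parametrized dtype is torch.bfloat16; the tolerances and the use of torch.mv as the entry point are assumptions for illustration, not the code under review:

a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
x = torch.rand((k,), dtype=torch.bfloat16, device=device)
# Non-transposed case: A(m, k) @ x should take the sbgemv path.
torch.testing.assert_close(
    torch.mv(a, x),
    torch.mv(a.float(), x.float()).to(torch.bfloat16),
    atol=1e-2, rtol=1e-2)
# Transposed case: A.t() @ y exercises the transb == Transpose branch.
y = torch.rand((m,), dtype=torch.bfloat16, device=device)
torch.testing.assert_close(
    torch.mv(a.t(), y),
    torch.mv(a.t().float(), y.float()).to(torch.bfloat16),
    atol=1e-2, rtol=1e-2)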

Collaborator
@lezcano lezcano left a comment


Can you provide benchmarks of the before / after for a few relevant shapes?

@aditew01
Collaborator

An alternative approach is to enable this path in OpenBLAS. @taoye9 already has a PR for that; linking it here for visibility:
OpenMathLib/OpenBLAS#5260

@taoye9
Author
taoye9 commented May 13, 2025

Can you provide benchmarks of the before / after for a few relevant shapes?

Hi, sorry, we have been holding this PR for a while to further investigate which approach is best, i.e. implementing this inside OpenBLAS or inside PyTorch.

@malfet malfet added the "ciflow/trunk" label (Trigger trunk jobs on your pull request) May 14, 2025
Contributor
@malfet malfet left a comment


LGTM if it passes CI, but it would also be good to enable this directly for the torch.mv call as well.
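For reference, a bfloat16 torch.mv call that such an extension could cover; the shapes here are arbitrary:

import torch

mat = torch.rand(128, 64, dtype=torch.bfloat16)  # A(m, k)
vec = torch.rand(64, dtype=torch.bfloat16)       # x(k)
out = torch.mv(mat, vec)                         # bf16 matrix-vector product on CPU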

@fadara01
Collaborator

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased spda_gemv onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout spda_gemv && git pull --rebase)

@pytorch-bot pytorch-bot bot removed the "ciflow/trunk" label May 14, 2025