[MPS] fix attention enable_gqa crash on mps (#150067) · pytorch/pytorch@9b4f085 · GitHub
Commit 9b4f085
pytorchbot and Isalia20 authored
[MPS] fix attention enable_gqa crash on mps (#150067)
[MPS] fix attention enable_gqa crash on mps (#149147)

Fixes #149132

Pull Request resolved: #149147
Approved by: https://github.com/malfet

(cherry picked from commit dd6e9df)

Co-authored-by: Isalia20 <irakli.salia854@gmail.com>
1 parent d29e4c8 commit 9b4f085
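For context, here is a minimal, self-contained sketch of the user-facing behavior this commit fixes. It is illustrative, not part of the change: it assumes an MPS-capable macOS machine and a PyTorch build with enable_gqa support, and borrows its shapes from the new test below.

import torch
import torch.nn.functional as F

# 32 query heads vs. 16 key/value heads, as in the new test.
q = torch.randn(2, 32, 7, 23, device="mps")
k = torch.randn(2, 16, 17, 23, device="mps")
v = torch.randn(2, 16, 17, 23, device="mps")

# Before this commit this call crashed on MPS; with the fix it matches
# the CPU reference.
y = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), enable_gqa=True)
torch.testing.assert_close(y.cpu(), y_ref, rtol=1e-4, atol=1e-4)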

File tree

3 files changed: +48 −1 lines changed

aten/src/ATen/native/mps/operations/Attention.mm

Lines changed: 3 additions & 1 deletion
@@ -44,7 +44,8 @@
     TORCH_CHECK(!attn_mask.has_value(),
                 "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
   }
-
+  TORCH_CHECK(query.size(-3) == key.size(-3) && key.size(-3) == value.size(-3),
+              "number of heads in query/key/value should match");
   TORCH_CHECK(dropout_p == 0.0, "_scaled_dot_product_attention_math_for_mps: dropout_p != 0.0 is not supported");
   TORCH_CHECK(macOS15_0_plus || (query.is_contiguous() && key.is_contiguous() && value.is_contiguous()),
               "_scaled_dot_product_attention_math_for_mps: query, key, and value must be contiguous");
@@ -55,6 +56,7 @@
   auto [q_, sq] = ensure_4d(query);
   auto [k_, sk] = ensure_4d(key);
   auto [v_, sv] = ensure_4d(value);
+
   std::optional<Tensor> mask_;
   if (attn_mask) {
     auto maskExpandedDims = query.sizes().vec();
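The new TORCH_CHECK means the low-level MPS math kernel now assumes head counts already match; the GQA expansion happens one level up, in attention.cpp below. As a rough illustration (assuming the internal aten op is directly callable via torch.ops.aten, which bypasses that expansion), a mismatched call now fails with a readable error instead of crashing:

import torch

q = torch.randn(2, 32, 7, 23, device="mps")   # 32 heads
k = torch.randn(2, 16, 17, 23, device="mps")  # 16 heads
v = torch.randn(2, 16, 17, 23, device="mps")

try:
    # Direct call to the internal kernel, skipping the GQA expansion.
    torch.ops.aten._scaled_dot_product_attention_math_for_mps(q, k, v)
except RuntimeError as e:
    print(e)  # "number of heads in query/key/value should match"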

aten/src/ATen/native/transformers/attention.cpp

Lines changed: 22 additions & 0 deletions
@@ -759,6 +759,28 @@ Tensor scaled_dot_product_attention(
         && !(GradMode::is_enabled() && any_inputs_require_grad)
         && (all_contiguous || mps::is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_15_0_PLUS))
         && !any_nested) {
+      if (enable_gqa) {
+        int64_t q_heads = query_.size(-3);
+        int64_t k_heads = key.size(-3);
+        int64_t repeat_factor = q_heads / k_heads;
+
+        if (repeat_factor > 1) {
+          TORCH_CHECK(q_heads % k_heads == 0,
+                      "For GQA, the query tensor's head dimension (" + std::to_string(q_heads) +
+                      ") must be divisible by the key tensor's head dimension (" + std::to_string(k_heads) + ").");
+          auto repeated_key = key.repeat_interleave(repeat_factor, /*dim=*/-3);
+          auto repeated_value = value.repeat_interleave(repeat_factor, /*dim=*/-3);
+          return std::get<0>(at::_scaled_dot_product_attention_math_for_mps(
+              query_,
+              repeated_key,
+              repeated_value,
+              attn_mask,
+              dropout_p,
+              is_causal,
+              std::nullopt, /*dropout_mask*/
+              scale));
+        }
+      }
       return std::get<0>(at::_scaled_dot_product_attention_math_for_mps(
           query_,
           key,
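The fix reduces GQA to ordinary multi-head attention: each key/value head is repeated q_heads / k_heads times along the head dimension (-3) before dispatching to the existing MPS math kernel. A device-independent sketch of that equivalence (illustrative shapes; runs on CPU):

import torch
import torch.nn.functional as F

q = torch.randn(2, 32, 7, 23)   # [batch, q_heads, L, head_dim]
k = torch.randn(2, 16, 17, 23)  # [batch, k_heads, S, head_dim]
v = torch.randn(2, 16, 17, 23)

repeat_factor = q.size(-3) // k.size(-3)  # 2 query heads per kv head
# repeat_interleave duplicates each kv head in place, so query head i
# attends against kv head i // repeat_factor, which is the GQA pairing.
k_rep = k.repeat_interleave(repeat_factor, dim=-3)
v_rep = v.repeat_interleave(repeat_factor, dim=-3)

y_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
y_mha = F.scaled_dot_product_attention(q, k_rep, v_rep)
torch.testing.assert_close(y_gqa, y_mha)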

test/test_mps.py

Lines changed: 23 additions & 0 deletions
@@ -9909,6 +9909,29 @@ def test_sdpa_mask_5d(
         y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), attn_mask=mask.cpu(), dropout_p=0.0, is_causal=False)
         self._compare_tensors(y.cpu(), y_ref)
 
+    @parametrize("dtype", [torch.float16, torch.float32])
+    @parametrize("is_causal", [True, False])
+    def test_sdpa_enable_gqa(self, dtype, is_causal):
+        q_heads = 32
+        key_heads = 16
+        L = 7
+        S = 17
+        HS = 23
+
+        q = torch.randn([2, q_heads, L, HS], dtype=dtype, device="mps")
+        k = torch.randn([2, key_heads, S, HS], dtype=dtype, device="mps")
+        v = torch.randn([2, key_heads, S, HS], dtype=dtype, device="mps")
+
+        y_ref = F.scaled_dot_product_attention(
+            q.cpu(), k.cpu(), v.cpu(), dropout_p=0.0, is_causal=is_causal, enable_gqa=True,
+        )
+
+        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+            y = F.scaled_dot_product_attention(
+                q, k, v, dropout_p=0.0, is_causal=is_causal, enable_gqa=True,
+            )
+        self._compare_tensors(y.cpu(), y_ref)
+
 
 class TestGatherScatter(TestCaseMPS):
     def test_slicing_with_step(self):
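The test pins the MATH backend via sdpa_kernel so the fixed MPS math path is the one exercised, then compares against the CPU reference. One common way to run just this test locally (an MPS device is required; the -k filter is the usual PyTorch test-runner mechanism):

python test/test_mps.py -k test_sdpa_enable_gqa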

0 commit comments