You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
bug fix: ensure 4d input in _scaled_dot_product_attention_math_mps (#146623)
This pr addresses the issue in the MPS backend for `_scaled_dot_product_attention_math_mps` where a 3d input like (num_heads, seq_len, query_dim) cannot be automatically treated as (1, num_heads, seq_len, query_dim), which can be inferred on cpu or cuda, which can be circumvented by adding a util function to ensure a 4d shape.
The issue was found in hiyouga/LLaMA-Factory#6835, in [transformers qwen2_vl](https://github.com/huggingface/transformers/blob/1590c664306766f32ba68c50e67f14d61b16925d/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L373C14-L373C93), 3d q/k/v were passed into sdpa function, which lead to an error.
Considering consistency, since this pattern might pop up elsewhere in the transformers codebase, I think it makes more sense to maintain the same intuition across all platforms.
---
reproduce code:
```
import torch
import torch.nn.functional as F
head_num, seq_len, embed_dim = 16, 16, 80
bsz = 1
q = torch.randn(head_num, seq_len, embed_dim)
k = torch.randn(head_num, seq_len, embed_dim)
v = torch.randn(head_num, seq_len, embed_dim)
attention_mask = torch.ones(1, seq_len, seq_len)
oo_cpu = F.scaled_dot_product_attention(
q.to("cpu"),
k.to("cpu"),
v.to("cpu"),
attention_mask.to("cpu"),
dropout_p=0.0
)
if torch.backends.mps.is_available():
oo_mps = F.scaled_dot_product_attention(
q.to("mps"),
k.to("mps"),
v.to("mps"),
attention_mask.to("mps"),
dropout_p=0.0
)
assert torch.allclose(oo_cpu, oo_mps.to("cpu"), atol=1e-5)
```
error outputs:
```
Traceback (most recent call last):
File "/opt/homebrew/Caskroom/miniconda/base/envs/torch-dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-2-5169b8d2c5dd>", line 21, in <module>
oo_mps = F.scaled_dot_product_attention(
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
```
hardware and envs:
```
torch 2.6.0
apple m3 max
```
---
Pull Request resolved: #146623
Approved by: https://github.com/malfet
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
0 commit comments