scaled_dot_product_attention crashes on apple silicon · Issue #149132 · pytorch/pytorch · GitHub
Closed · jjh42 opened this issue Mar 13, 2025 · 2 comments
Labels
module: crash - Problem manifests as a hard crash, as opposed to a RuntimeError
module: mps - Related to Apple Metal Performance Shaders framework
module: sdpa - All things related to torch.nn.functional.scaled_dot_product_attention
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Milestone
2.7.0
Comments

@jjh42 (Contributor) commented Mar 13, 2025

🐛 Describe the bug

The following Python code fails and crashes the process on macOS 15.3.1 (M1 Pro).

import torch
import torch.nn.functional as F

print(torch.__version__)

device = torch.device('mps')

B = 2          # batch size
T = 3          # sequence length
n_kv_head = 2  # fewer key/value heads than query heads -> grouped-query attention
n_q_head = 4
dim = 8

attn_mask = torch.ones((T, T)).to(device)

q = torch.rand(B, n_q_head, T, dim).to(device)
k = torch.rand(B, n_kv_head, T, dim).to(device)
v = torch.rand(B, n_kv_head, T, dim).to(device)

# Crashes the process on MPS; works on CPU.
F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=True)

with the following logs:

2.7.0.dev20250311
loc("mps_matmul"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/d187755d-b9a3-11ef-83e5-aabfac210453/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":43:0)): error: incompatible dimensions
loc("mps_matmul"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/d187755d-b9a3-11ef-83e5-aabfac210453/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":43:0)): error: invalid shape
LLVM ERROR: Failed to infer result type(s).

Changing the device to CPU makes it work fine. Setting n_kv_head to 4 (so it matches n_q_head) also resolves the issue. A possible workaround is sketched below.
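
As a stopgap, and assuming the standard GQA semantics of repeating each key/value head across its group of query heads, you can expand k and v manually and call SDPA without enable_gqa. A minimal sketch, not verified on the affected nightly:

# Workaround sketch: expand the KV heads so the head dimensions match
# and enable_gqa (which triggers the MPS crash) is no longer needed.
# Requires n_q_head to be an integer multiple of n_kv_head, as GQA does.
group = n_q_head // n_kv_head
k_expanded = k.repeat_interleave(group, dim=1)  # (B, n_q_head, T, dim)
v_expanded = v.repeat_interleave(group, dim=1)
out = F.scaled_dot_product_attention(q, k_expanded, v_expanded, attn_mask=attn_mask)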

Versions

I'm using uv, so the version-collection script fails.

I've tested with PyTorch 2.6.0 and the 2025-03-11 nightly.

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Isalia20 added the module: mps and module: crash labels on Mar 13, 2025
@Isalia20 (Collaborator) commented

Confirmed that it crashes with the latest main branch build. The crash is due to the mismatch in the num_heads dimension (dim=1) between q and k/v. I can take a look at this.
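
For illustration, assuming the failing op is the batched q @ k^T matmul inside the attention graph (a reading of the "mps_matmul ... incompatible dimensions" log above, not confirmed against the MPS source), the same head-dimension mismatch raises an ordinary RuntimeError on CPU instead of hard-crashing:

import torch

# q has 4 heads, k has 2; the batch dims (2, 4) vs (2, 2) do not broadcast.
q_cpu = torch.rand(2, 4, 3, 8)
k_cpu = torch.rand(2, 2, 3, 8)
try:
    torch.matmul(q_cpu, k_cpu.transpose(-2, -1))
except RuntimeError as e:
    print(e)  # size mismatch at non-singleton dimension 1 (4 vs 2)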

malfet added this to the 2.7.0 milestone on Mar 13, 2025
malfet added the triaged and module: sdpa labels on Mar 13, 2025
pytorchbot pushed a commit that referenced this issue on Mar 26, 2025
[MPS] fix attention enable_gqa crash on mps (#149147)

Fixes #149132

Pull Request resolved: #149147
Approved by: https://github.com/malfet

(cherry picked from commit dd6e9df)
malfet pushed a commit that referenced this issue on Mar 27, 2025
[MPS] fix attention enable_gqa crash on mps (#149147)

Fixes #149132

Pull Request resolved: #149147
Approved by: https://github.com/malfet

(cherry picked from commit dd6e9df)

Co-authored-by: Isalia20 <irakli.salia854@gmail.com>
@ZainRizvi (Contributor) commented

Verified that the repro works on the current RC build and that it fails on torch==2.7.0.dev20250312:

(release2.7) ~/test/release2.7/.venv/lib/python3.12/site-packages/torch/lib python
Python 3.12.5 (main, Aug 14 2024, 04:32:18) [Clang 18.1.8 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
/Users/zainr/test/release2.7/.venv/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py:276: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:81.)
  cpu = _conversion_method_template(device=torch.device("cpu"))
>>> import torch.nn.functional as F
>>> print(torch.__version__)
2.7.0
>>>
>>> device = torch.device('mps')
>>>
>>> B=2
>>> T=3
>>> n_kv_head = 2
>>> n_q_head = 4
>>> dim = 8
>>>
>>> attn_mask = torch.ones((T, T)).to(device)
>>>
>>>
>>> q = torch.rand(B, n_q_head, T, dim).to(device)
>>> k = torch.rand(B, n_kv_head, T, dim).to(device)
>>> v = torch.rand(B, n_kv_head, T, dim).to(device)
>>>
>>> F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=True)
tensor([[[[0.3207, 0.5862, 0.7752, 0.6000, 0.3268, 0.4634, 0.4741, 0.4143],
          [0.3193, 0.5866, 0.7741, 0.6013, 0.3311, 0.4643, 0.4692, 0.4183],
          [0.3203, 0.5806, 0.7754, 0.6006, 0.3322, 0.4632, 0.4670, 0.4235]],

         [[0.3168, 0.5912, 0.7715, 0.6035, 0.3365, 0.4662, 0.4636, 0.4207],
          [0.3198, 0.5843, 0.7747, 0.6009, 0.3311, 0.4638, 0.4688, 0.4199],
          [0.3160, 0.5963, 0.7703, 0.6041, 0.3357, 0.4672, 0.4653, 0.4164]],

         [[0.0807, 0.4647, 0.3779, 0.1464, 0.4681, 0.7168, 0.4045, 0.4266],
          [0.0738, 0.4313, 0.3509, 0.1397, 0.4528, 0.6858, 0.4098, 0.4429],
          [0.0766, 0.4433, 0.3602, 0.1427, 0.4568, 0.6951, 0.4075, 0.4393]],

         [[0.0771, 0.4514, 0.3679, 0.1426, 0.4648, 0.7079, 0.4075, 0.4288],
          [0.0771, 0.4577, 0.3746, 0.1418, 0.4727, 0.7199, 0.4080, 0.4183],
          [0.0891, 0.4930, 0.3977, 0.1561, 0.4711, 0.7311, 0.3970, 0.4274]]],


        [[[0.5090, 0.6858, 0.4873, 0.5611, 0.3307, 0.4444, 0.4684, 0.3867],
          [0.5283, 0.6929, 0.4951, 0.5742, 0.3484, 0.4612, 0.4860, 0.3723],
          [0.4933, 0.6818, 0.4845, 0.5608, 0.3178, 0.4360, 0.4607, 0.3909]],

         [[0.5551, 0.7061, 0.5121, 0.6100, 0.3760, 0.4935, 0.5220, 0.3397],
          [0.5354, 0.6975, 0.5017, 0.5898, 0.3567, 0.4729, 0.4995, 0.3593],
          [0.5098, 0.6843, 0.4843, 0.5519, 0.3297, 0.4401, 0.4628, 0.3931]],

         [[0.4532, 0.6665, 0.6869, 0.5664, 0.5450, 0.3771, 0.4980, 0.5236],
          [0.4709, 0.6238, 0.6603, 0.5809, 0.5787, 0.4170, 0.5109, 0.5358],
          [0.4480, 0.6593, 0.6786, 0.5667, 0.5330, 0.3960, 0.4869, 0.5466]],

         [[0.4455, 0.6766, 0.6915, 0.5620, 0.5293, 0.3729, 0.4892, 0.5299],
          [0.4483, 0.6576, 0.6773, 0.5672, 0.5333, 0.3984, 0.4867, 0.5484],
          [0.4601, 0.6591, 0.6840, 0.5699, 0.5590, 0.3784, 0.5064, 0.5161]]]],
       device='mps:0')
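
A quick way to sanity-check the fixed output, assuming the CPU path (which was unaffected by the bug) as the reference, using the same q/k/v and mask as the repro; a sketch, with the tolerance chosen somewhat arbitrarily:

# Compare the MPS GQA output against the CPU implementation.
out_mps = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=True)
out_cpu = F.scaled_dot_product_attention(
    q.cpu(), k.cpu(), v.cpu(), attn_mask=attn_mask.cpu(), enable_gqa=True
)
print(torch.allclose(out_mps.cpu(), out_cpu, atol=1e-5))  # expect True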
