bug fix: ensure 4d input in _scaled_dot_product_attention_math_mps by hellopahe · Pull Request #146623 · pytorch/pytorch · GitHub

bug fix: ensure 4d input in _scaled_dot_product_attention_math_mps #146623


Closed
hellopahe wants to merge 7 commits

Conversation

hellopahe (Contributor)

This PR addresses an issue in the MPS backend's `_scaled_dot_product_attention_math_mps`: a 3d input like (num_heads, seq_len, query_dim) is not automatically treated as (1, num_heads, seq_len, query_dim), as it is on CPU and CUDA. The fix adds a util function that ensures a 4d shape.

The issue was found in hiyouga/LLaMA-Factory#6835: in transformers' qwen2_vl, 3d q/k/v tensors were passed into the SDPA function, which led to an error.

For consistency, and since this pattern might pop up elsewhere in the transformers codebase, I think it makes more sense to maintain the same behavior across all platforms.
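Conceptually, the util function promotes a 3d input to 4d before the kernel runs and squeezes the result back afterwards. A minimal Python sketch of that logic (the actual helper is C++ in the MPS backend; `sdpa_ensure_4d` is a hypothetical name for illustration only):

```
import torch
import torch.nn.functional as F

def sdpa_ensure_4d(q, k, v, attn_mask=None, dropout_p=0.0):
    # Promote (num_heads, seq_len, dim) to (1, num_heads, seq_len, dim).
    squeeze = q.dim() == 3
    if squeeze:
        q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p=dropout_p)
    # Squeeze the leading batch dim back off so callers see the original rank.
    return out.squeeze(0) if squeeze else out
```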


reproduce code:

```
import torch
import torch.nn.functional as F

head_num, seq_len, embed_dim = 16, 16, 80
bsz = 1

q = torch.randn(head_num, seq_len, embed_dim)
k = torch.randn(head_num, seq_len, embed_dim)
v = torch.randn(head_num, seq_len, embed_dim)
attention_mask = torch.ones(1, seq_len, seq_len)

oo_cpu = F.scaled_dot_product_attention(
    q.to("cpu"),
    k.to("cpu"),
    v.to("cpu"),
    attention_mask.to("cpu"),
    dropout_p=0.0
)

if torch.backends.mps.is_available():
    oo_mps = F.scaled_dot_product_attention(
        q.to("mps"),
        k.to("mps"),
        v.to("mps"),
        attention_mask.to("mps"),
        dropout_p=0.0
    )
    assert torch.allclose(oo_cpu, oo_mps.to("cpu"), atol=1e-5)
```

error outputs:

```
Traceback (most recent call last):
  File "/opt/homebrew/Caskroom/miniconda/base/envs/torch-dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-5169b8d2c5dd>", line 21, in <module>
    oo_mps = F.scaled_dot_product_attention(
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
```

hardware and envs:

```
torch               2.6.0
apple m3 max
```

pytorch-bot bot commented Feb 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146623

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 71b52ac with merge base 8a4dd76:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla bot commented Feb 6, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

@pytorch-bot pytorch-bot bot added the release notes: mps Release notes category label Feb 6, 2025
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
```
// If the input was promoted from 3d, squeeze the leading batch dim back off.
auto final_out = (sq ? out.squeeze(0) : out);
auto final_attn = (sq ? attn.squeeze(0) : attn);

return {final_out, final_attn};
```

A collaborator suggested returning with `std::move` to avoid extra Tensor refcount copies:

Suggested change:
```
- return {final_out, final_attn};
+ return {std::move(final_out), std::move(final_attn)};
```

malfet (Contributor) commented Feb 6, 2025

@hellopahe thank you for the PR, can you please sign the CLA?

malfet (Contributor) commented Feb 6, 2025

And last but not least, this PR could benefit from a test, so that it would not regress again
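For illustration only, a regression test along these lines could cover it (a sketch based on the repro above; not the test that actually landed in the PR):

```
import torch
import torch.nn.functional as F

def test_sdpa_3d_input_mps():
    # 3d q/k/v on MPS should match the CPU result after the fix.
    torch.manual_seed(0)
    q, k, v = (torch.randn(16, 16, 80) for _ in range(3))
    mask = torch.ones(1, 16, 16)
    ref = F.scaled_dot_product_attention(q, k, v, mask, dropout_p=0.0)
    out = F.scaled_dot_product_attention(
        q.to("mps"), k.to("mps"), v.to("mps"), mask.to("mps"), dropout_p=0.0
    )
    torch.testing.assert_close(out.cpu(), ref, atol=1e-5, rtol=1e-5)
```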

@malfet malfet added topic: bug fixes topic category ciflow/mps Run MPS tests (subset of trunk) labels Feb 6, 2025
@malfet malfet added this to the 2.6.1 milestone Feb 6, 2025
malfet (Contributor) commented Feb 6, 2025

I wonder if, for the 2.6.1 milestone, one could land a smaller fix that just falls back to the Math implementation if ndim is 3 (cc: @manuelcandales)

hellopahe and others added 3 commits February 7, 2025 01:42
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
@mikaylagawarecki mikaylagawarecki added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Feb 7, 2025
malfet (Contributor) commented Feb 13, 2025

@pytorchbot merge -f "Lint + MPS are green"

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

Raymo111 pushed a commit that referenced this pull request Feb 20, 2025
bug fix: ensure 4d input in _scaled_dot_product_attention_math_mps (#146623)

Pull Request resolved: #146623
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Ryo-not-rio pushed a commit to Ryo-not-rio/pytorch that referenced this pull request Feb 24, 2025
bug fix: ensure 4d input in _scaled_dot_product_attention_math_mps (pytorch#146623)

Pull Request resolved: pytorch#146623
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
bug fix: ensure 4d input in _scaled_dot_product_attention_math_mps (pytorch#146623)

Pull Request resolved: pytorch#146623
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Labels
ciflow/mps (Run MPS tests, subset of trunk) · Merged · open source · release notes: mps (Release notes category) · topic: bug fixes (topic category) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
6 participants