-
Notifications
You must be signed in to change notification settings - Fork 24.3k
bug fix: ensure 4d input in _scaled_dot_product_attention_math_mps #146623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146623
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 71b52ac with merge base 8a4dd76 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
auto final_out = (sq ? out.squeeze(0) : out); | ||
auto final_attn = (sq ? attn.squeeze(0) : attn); | ||
|
||
return {final_out, final_attn}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return {final_out, final_attn}; | |
return {std::move(final_out), stdd::move(final_a 8000 ttn}); |
@hellopahe thank you for the PR, can you please sign the CLA? |
And last but not least, this PR could benefit from a test, so that it would not regress again |
I wonder if for 2.6.1 milestone one can land a smaller fix to just fall back to Math implementation if ndim is 3 (cc: @manuelcandales ) |
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
…tential dangling reference
@pytorchbot merge -f "Lint + MPS are green" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…146623) This pr addresses the issue in the MPS backend for `_scaled_dot_product_attention_math_mps` where a 3d input like (num_heads, seq_len, query_dim) cannot be automatically treated as (1, num_heads, seq_len, query_dim), which can be inferred on cpu or cuda, which can be circumvented by adding a util function to ensure a 4d shape. The issue was found in hiyouga/LLaMA-Factory#6835, in [transformers qwen2_vl](https://github.com/huggingface/transformers/blob/1590c664306766f32ba68c50e67f14d61b16925d/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L373C14-L373C93), 3d q/k/v were passed into sdpa function, which lead to an error. Considering consistency, since this pattern might pop up elsewhere in the transformers codebase, I think it makes more sense to maintain the same intuition across all platforms. --- reproduce code: ``` import torch import torch.nn.functional as F head_num, seq_len, embed_dim = 16, 16, 80 bsz = 1 q = torch.randn(head_num, seq_len, embed_dim) k = torch.randn(head_num, seq_len, embed_dim) v = torch.randn(head_num, seq_len, embed_dim) attention_mask = torch.ones(1, seq_len, seq_len) oo_cpu = F.scaled_dot_product_attention( q.to("cpu"), k.to("cpu"), v.to("cpu"), attention_mask.to("cpu"), dropout_p=0.0 ) if torch.backends.mps.is_available(): oo_mps = F.scaled_dot_product_attention( q.to("mps"), k.to("mps"), v.to("mps"), attention_mask.to("mps"), dropout_p=0.0 ) assert torch.allclose(oo_cpu, oo_mps.to("cpu"), atol=1e-5) ``` error outputs: ``` Traceback (most recent call last): File "/opt/homebrew/Caskroom/miniconda/base/envs/torch-dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code exec(code_obj, self.user_global_ns, self.user_ns) File "<ipython-input-2-5169b8d2c5dd>", line 21, in <module> oo_mps = F.scaled_dot_product_attention( IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3) ``` hardware and envs: ``` torch 2.6.0 apple m3 max ``` --- Pull Request resolved: #146623 Approved by: https://github.com/malfet Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com> Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
…ytorch#146623) This pr addresses the issue in the MPS backend for `_scaled_dot_product_attention_math_mps` where a 3d input like (num_heads, seq_len, query_dim) cannot be automatically treated as (1, num_heads, seq_len, query_dim), which can be inferred on cpu or cuda, which can be circumvented by adding a util function to ensure a 4d shape. The issue was found in hiyouga/LLaMA-Factory#6835, in [transformers qwen2_vl](https://github.com/huggingface/transformers/blob/1590c664306766f32ba68c50e67f14d61b16925d/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L373C14-L373C93), 3d q/k/v were passed into sdpa function, which lead to an error. Considering consistency, since this pattern might pop up elsewhere in the transformers codebase, I think it makes more sense to maintain the same intuition across all platforms. --- reproduce code: ``` import torch import torch.nn.functional as F head_num, seq_len, embed_dim = 16, 16, 80 bsz = 1 q = torch.randn(head_num, seq_len, embed_dim) k = torch.randn(head_num, seq_len, embed_dim) v = torch.randn(head_num, seq_len, embed_dim) attention_mask = torch.ones(1, seq_len, seq_len) oo_cpu = F.scaled_dot_product_attention( q.to("cpu"), k.to("cpu"), v.to("cpu"), attention_mask.to("cpu"), dropout_p=0.0 ) if torch.backends.mps.is_available(): oo_mps = F.scaled_dot_product_attention( q.to("mps"), k.to("mps"), v.to("mps"), attention_mask.to("mps"), dropout_p=0.0 ) assert torch.allclose(oo_cpu, oo_mps.to("cpu"), atol=1e-5) ``` error outputs: ``` Traceback (most recent call last): File "/opt/homebrew/Caskroom/miniconda/base/envs/torch-dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code exec(code_obj, self.user_global_ns, self.user_ns) File "<ipython-input-2-5169b8d2c5dd>", line 21, in <module> oo_mps = F.scaled_dot_product_attention( IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3) ``` hardware and envs: ``` torch 2.6.0 apple m3 max ``` --- Pull Request resolved: pytorch#146623 Approved by: https://github.com/malfet Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com> Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
…ytorch#146623) This pr addresses the issue in the MPS backend for `_ 8000 scaled_dot_product_attention_math_mps` where a 3d input like (num_heads, seq_len, query_dim) cannot be automatically treated as (1, num_heads, seq_len, query_dim), which can be inferred on cpu or cuda, which can be circumvented by adding a util function to ensure a 4d shape. The issue was found in hiyouga/LLaMA-Factory#6835, in [transformers qwen2_vl](https://github.com/huggingface/transformers/blob/1590c664306766f32ba68c50e67f14d61b16925d/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L373C14-L373C93), 3d q/k/v were passed into sdpa function, which lead to an error. Considering consistency, since this pattern might pop up elsewhere in the transformers codebase, I think it makes more sense to maintain the same intuition across all platforms. --- reproduce code: ``` import torch import torch.nn.functional as F head_num, seq_len, embed_dim = 16, 16, 80 bsz = 1 q = torch.randn(head_num, seq_len, embed_dim) k = torch.randn(head_num, seq_len, embed_dim) v = torch.randn(head_num, seq_len, embed_dim) attention_mask = torch.ones(1, seq_len, seq_len) oo_cpu = F.scaled_dot_product_attention( q.to("cpu"), k.to("cpu"), v.to("cpu"), attention_mask.to("cpu"), dropout_p=0.0 ) if torch.backends.mps.is_available(): oo_mps = F.scaled_dot_product_attention( q.to("mps"), k.to("mps"), v.to("mps"), attention_mask.to("mps"), dropout_p=0.0 ) assert torch.allclose(oo_cpu, oo_mps.to("cpu"), atol=1e-5) ``` error outputs: ``` Traceback (most recent call last): File "/opt/homebrew/Caskroom/miniconda/base/envs/torch-dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code exec(code_obj, self.user_global_ns, self.user_ns) File "<ipython-input-2-5169b8d2c5dd>", line 21, in <module> oo_mps = F.scaled_dot_product_attention( IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3) ``` hardware and envs: ``` torch 2.6.0 apple m3 max ``` --- Pull Request resolved: pytorch#146623 Approved by: https://github.com/malfet Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com> Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This pr addresses the issue in the MPS backend for
_scaled_dot_product_attention_math_mps
where a 3d input like (num_heads, seq_len, query_dim) cannot be automatically treated as (1, num_heads, seq_len, query_dim), which can be inferred on cpu or cuda, which can be circumvented by adding a util function to ensure a 4d shape.The issue was found in hiyouga/LLaMA-Factory#6835, in transformers qwen2_vl, 3d q/k/v were passed into sdpa function, which lead to an error.
Considering consistency, since this pattern might pop up elsewhere in the transformers codebase, I think it makes more sense to maintain the same intuition across all platforms.
reproduce code:
error outputs:
hardware and envs: