ROCm SDPA: Ensure attn_mask has the same dtype with q #143242
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143242
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 1489d35 with merge base a1ae8fa.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
@pytorchbot rebase
Please seek CI approval before scheduling CIFlow labels
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased 60b555e to 29bf7a8
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased 29bf7a8 to 857b9c3
@xinyazhang The flex_attention failures look legit.
However, I have another concern. According to pytorch/aten/src/ATen/native/transformers/cuda/attention.cu, lines 1331 to 1336 at e56768f, CUDA's ME also assumes q.dtype() == attn_bias.dtype(). How does the CUDA backend make it work?
Update: Okay, I found the differences. (Although I'm still not sure how fp16 qkv + fp32 attn_mask works on NVIDIA.)
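For context, the scenario under discussion can be sketched at the Python level. The snippet below is a hypothetical, minimal illustration (not taken from the PR or the llama2 training setup): fp16 q/k/v with an fp32 additive attn_mask, explicitly routed to the memory-efficient (ME) SDPA backend. Shapes, values, and the explicit backend selection are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Hypothetical repro sketch: fp16 q/k/v with an fp32 additive mask.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
attn_mask = torch.zeros(1, 1, 128, 128, device="cuda", dtype=torch.float32)

# Restrict dispatch to the memory-efficient backend so the q/attn_mask
# dtype mismatch is exercised there rather than in flash or math.
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# Per the PR description, this dtype combination could yield NaNs on ROCm
# before the fix; the question above is how the NVIDIA path handles it.
print(out.isnan().any())
```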
@pytorchbot merge -f "ROCm CI passed. Change only impacts ROCm"
Merge started
Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot cherry-pick --onto release/2.6
❌ 🤖 pytorchbot command failed.
@pytorchbot cherry-pick --onto release/2.6 -c critical
Cherry picking #143242
The cherry pick PR is at #144398 and it is recommended to link a critical cherry pick PR with an issue. The following tracker issues are updated:
Details for Dev Infra team: Raised by workflow job
ROCm SDPA: Ensure attn_mask has the same dtype with q (#143242)
This is required by the current AOTriton backend. Fixes NaN when calling the SDPA ME backend with `q.dtype() != attn_mask.dtype()` when training llama2 using transformers+deepspeed+pytorch.
Corresponding CUDA check seems to be here: https://github.com/pytorch/pytorch/blob/708ce3c0082d670d9eaff84bc3c43cad4554a75d/aten/src/ATen/native/transformers/cuda/attention.cu#L1331-L1336
Pull Request resolved: #143242
Approved by: https://github.com/jeffdaily
(cherry picked from commit 3068ce0)
Co-authored-by: Xinya Zhang <Xinya.Zhang@amd.com>
This is required by the current AOTriton backend.
Fixes NaN when calling the SDPA ME backend with `q.dtype() != attn_mask.dtype()` when training llama2 using transformers+deepspeed+pytorch.
The corresponding CUDA check appears to be in pytorch/aten/src/ATen/native/transformers/cuda/attention.cu, lines 1331 to 1336 at 708ce3c.
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd
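As a user-level illustration of the same-dtype requirement (a sketch only, not code from this PR; the helper name below is hypothetical), a floating-point attn_mask can be pre-cast to the query dtype before calling SDPA:

```python
import torch.nn.functional as F

def sdpa_with_matched_mask(q, k, v, attn_mask=None, **kwargs):
    # Hypothetical wrapper: cast a floating-point attn_mask to q's dtype up
    # front, so the ME backend never sees q.dtype() != attn_mask.dtype().
    if attn_mask is not None and attn_mask.is_floating_point() and attn_mask.dtype != q.dtype:
        attn_mask = attn_mask.to(dtype=q.dtype)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, **kwargs)
```

This mirrors, from the caller's side, the constraint that the ROCm ME backend enforces here because of AOTriton's requirements.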