ROCm SDPA: Ensure attn_mask has the same dtype with q by xinyazhang · Pull Request #143242 · pytorch/pytorch · GitHub

ROCm SDPA: Ensure attn_mask has the same dtype with q #143242


Closed

Conversation

xinyazhang (Collaborator) commented on Dec 14, 2024:

This is required by the current AOTriton backend.

Fixes a NaN when calling the SDPA ME (memory-efficient) backend with q.dtype() != attn_mask.dtype(), observed when training llama2 using transformers+deepspeed+pytorch.

The corresponding CUDA check appears to be here:

if (bias.has_value()) {
  CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA((*bias));
  TORCH_CHECK(
      bias->scalar_type() == CutlassToAtenDtype<scalar_t>::atScalarType(),
      "invalid dtype for bias - should match query's dtype");
  p.attn_bias_ptr = (const scalar_t*)bias->const_data_ptr();
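
For context, here is a minimal Python sketch of the failure mode described above (my own illustration, not code from this PR): on the ROCm memory-efficient (AOTriton) backend, passing an attn_mask whose dtype differs from q's could previously yield NaN, and casting the mask to the query dtype was the user-level workaround. Shapes and names below are arbitrary.

    import torch
    import torch.nn.functional as F

    # Hypothetical repro sketch (not part of this PR). Assumes a ROCm build where
    # the memory-efficient (AOTriton) SDPA backend handles this call.
    q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # fp32 additive mask while q/k/v are fp16 -- the mismatched-dtype case.
    # Passing `mask` directly is what could produce NaN on the ROCm ME backend.
    mask = torch.zeros(1, 1, 128, 128, device="cuda", dtype=torch.float32)

    # Workaround before this fix: cast the mask to the query dtype.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.to(q.dtype))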

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

pytorch-bot (bot) commented on Dec 14, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143242

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1489d35 with merge base a1ae8fa:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the module: rocm (AMD GPU support for Pytorch) label on Dec 14, 2024
xinyazhang (Collaborator, Author) commented:

@pytorchbot label "topic: not user facing"

jeffdaily (Collaborator) commented:

@pytorchbot rebase

jeffdaily added the ciflow/rocm (Trigger "default" config CI on ROCm) label on Dec 16, 2024
pytorch-bot (bot) commented on Dec 16, 2024:

Please seek CI approval before scheduling CIFlow labels

pytorch-bot removed the ciflow/rocm (Trigger "default" config CI on ROCm) label on Dec 16, 2024
jeffdaily added the ciflow/rocm (Trigger "default" config CI on ROCm) label on Dec 16, 2024
pytorchmergebot (Collaborator) commented:

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot (Collaborator) commented:

Successfully rebased xinyazhang/check_attn_mask_dtype onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout xinyazhang/check_attn_mask_dtype && git pull --rebase)

pytorchmergebot force-pushed the xinyazhang/check_attn_mask_dtype branch from 60b555e to 29bf7a8 on December 16, 2024, 17:44
xinyazhang (Collaborator, Author) commented:

@pytorchbot rebase

pytorchmergebot (Collaborator) commented:

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot (Collaborator) commented:

Successfully rebased xinyazhang/check_attn_mask_dtype onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout xinyazhang/check_attn_mask_dtype && git pull --rebase)

pytorchmergebot force-pushed the xinyazhang/check_attn_mask_dtype branch from 29bf7a8 to 857b9c3 on January 3, 2025, 18:55
jithunnair-amd (Collaborator) commented:

@xinyazhang The flex_attention failures look legit.

xinyazhang (Collaborator, Author) commented on Jan 6, 2025:

@xinyazhang The flex_attention failures look legit.

This is because we switch to the math backend for certain inputs, and the math backend has known numerical-accuracy problems. Update: the actual problem is that the current patch excludes binary masks, which SDPA converts to floating-point dtypes.
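
To illustrate the binary-mask case just mentioned (a sketch of my own, not code from this PR): a boolean attn_mask is materialized by SDPA into an additive floating-point bias, so the tensor the kernel ultimately sees matches the query dtype even though the user-supplied mask does not.

    import torch
    import torch.nn.functional as F

    # Illustrative only: boolean masks are converted by SDPA into an additive
    # float bias, so a dtype-equality check on the raw user tensor does not
    # apply to them.
    q = torch.randn(1, 8, 32, 64, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    bool_mask = torch.rand(1, 1, 32, 32, device="cuda") > 0.5  # torch.bool

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)

    # Roughly what the conversion amounts to (simplified; internals may differ):
    additive = torch.zeros(bool_mask.shape, device="cuda", dtype=q.dtype)
    additive = additive.masked_fill(~bool_mask, float("-inf"))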

However, I have another concern. According to the code:

if (bias.has_value()) {
  CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA((*bias));
  TORCH_CHECK(
      bias->scalar_type() == CutlassToAtenDtype<scalar_t>::atScalarType(),
      "invalid dtype for bias - should match query's dtype");
  p.attn_bias_ptr = (const scalar_t*)bias->const_data_ptr();

CUDA's ME backend also assumes q.dtype() == attn_bias.dtype(). How does the CUDA backend make this work?

Update: Okay, I found the differences. (Although I'm still not sure how fp16 qkv + fp32 attn_mask works on NVIDIA.)
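
For completeness, here is a hedged way to probe that question (my own sketch, assuming a PyTorch recent enough to expose torch.nn.attention.sdpa_kernel): pin the memory-efficient backend, feed it fp16 q/k/v with an fp32 mask, and inspect the output.

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Illustrative probe, not part of this PR.
    q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    mask_fp32 = torch.zeros(1, 1, 128, 128, device="cuda", dtype=torch.float32)

    # Restrict dispatch to the memory-efficient backend; if that backend rejects
    # the mixed-dtype input, this call raises instead of silently falling back.
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask_fp32)

    print("output contains NaN:", out.isnan().any().item())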

xinyazhang marked this pull request as ready for review on January 7, 2025, 15:24
jithunnair-amd (Collaborator) commented:

@pytorchbot merge -f "ROCm CI passed. Change only impacts ROCm"

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort; instead, consider -i/--ignore-current to continue the merge while ignoring current failures. This allows currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

jithunnair-amd deleted the xinyazhang/check_attn_mask_dtype branch on January 8, 2025, 15:22
jithunnair-amd (Collaborator) commented:

@pytorchbot cherry-pick --onto release/2.6

pytorch-bot (bot) commented on Jan 8, 2025:

❌ 🤖 pytorchbot command failed:

@pytorchbot cherry-pick: error: the following arguments are required: -c/--classification

usage: @pytorchbot cherry-pick --onto ONTO [--fixes FIXES] -c
                               {regression,critical,fixnewfeature,docs,release}

Try @pytorchbot --help for more info.

jithunnair-amd (Collaborator) commented:

@pytorchbot cherry-pick --onto release/2.6 -c critical

jithunnair-amd added this to the 2.6.0 milestone on Jan 8, 2025
pytorchbot pushed a commit that referenced this pull request Jan 8, 2025
This is required by current AOTriton's backend.

Fixes NaN when calling SDPA ME backend with `q.dtype() != attn_mask.dtype()` when training llama2 using transformers+deepspeed+pytorch

Corresponding CUDA check seems to be here:
https://github.com/pytorch/pytorch/blob/708ce3c0082d670d9eaff84bc3c43cad4554a75d/aten/src/ATen/native/transformers/cuda/attention.cu#L1331-L1336

Pull Request resolved: #143242
Approved by: https://github.com/jeffdaily

(cherry picked from commit 3068ce0)
pytorchbot (Collaborator) commented:

Cherry picking #143242

The cherry pick PR is at #144398 and it is recommended to link a critical cherry pick PR with an issue. The following tracker issues are updated:

Details for Dev Infra team: raised by workflow job.

kit1980 pushed a commit that referenced this pull request Jan 10, 2025
ROCm SDPA: Ensure attn_mask has the same dtype with q (#143242)

This is required by current AOTriton's backend.

Fixes NaN when calling SDPA ME backend with `q.dtype() != attn_mask.dtype()` when training llama2 using transformers+deepspeed+pytorch

Corresponding CUDA check seems to be here:
https://github.com/pytorch/pytorch/blob/708ce3c0082d670d9eaff84bc3c43cad4554a75d/aten/src/ATen/native/transformers/cuda/attention.cu#L1331-L1336

Pull Request resolved: #143242
Approved by: https://github.com/jeffdaily

(cherry picked from commit 3068ce0)

Co-authored-by: Xinya Zhang <Xinya.Zhang@amd.com>
Labels: ciflow/rocm (Trigger "default" config CI on ROCm), Merged, module: rocm (AMD GPU support for Pytorch), open source, topic: not user facing (topic category)

5 participants