[ROCm] sdpa group query attention bf16 numeric error #139352

Closed
functionstackx opened this issue Oct 31, 2024 · 15 comments
Labels
module: rocm (AMD GPU support for PyTorch)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@functionstackx
Contributor
functionstackx commented Oct 31, 2024

🐛 Describe the bug

Hi AMD Team,

On MI300X with PyTorch nightly, grouped query attention runs into numeric errors. I have confirmed that this script does not produce numeric errors on H100.

Can you look into this and potentially add a numerics unit test for it?

cc: @hliuca

ROCm Error

    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 8388584 / 8388608 (100.0%)
Greatest absolute difference: 0.99609375 at index (0, 0, 0, 15) (up to 1e-05 allowed)
Greatest relative difference: inf at index (0, 0, 1, 0) (up to 0.016 allowed)

Repro Script

import torch
from torch.nn.functional import scaled_dot_product_attention

batch = 4
seq_len_q = 1024
seq_len_kv = 1024
D = 128

query = torch.randn(batch, 32, seq_len_q, D, device='cuda', dtype=torch.bfloat16)
key = torch.randn(batch, 8, seq_len_kv, D, device='cuda', dtype=torch.bfloat16)
value = torch.randn(batch, 8, seq_len_kv, D, device='cuda', dtype=torch.bfloat16)


output_gqa = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

key = key.repeat_interleave(4,1)
value = value.repeat_interleave(4,1)
output_repeat = scaled_dot_product_attention(query, key, value, is_causal=True)

torch.testing.assert_close(output_gqa, output_repeat)

Versions

ROCm Versions

~$ pip list | grep torch
pytorch-triton-rocm 3.1.0+cf34004b8a
torch               2.6.0.dev20241030+rocm6.2

H100 Versions

~$ pip list | grep torch
pytorch-triton               3.1.0+cf34004b8a
torch                        2.6.0.dev20241030+cu124

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

@pytorch-bot pytorch-bot bot added the module: rocm (AMD GPU support for PyTorch) label Oct 31, 2024
@malfet malfet added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Oct 31, 2024
@zhipuch
zhipuch commented Oct 31, 2024

Same issue here: bf16 LoRA training, NaN in the backward pass.

@hongxiayang
Collaborator

cc @xinyazhang

@xinyazhang
Collaborator
xinyazhang commented Oct 31, 2024

@OrenLeung GQA is not supported by the current PyTorch FA/ME backends on ROCm. Support is still under review (ROCm/aotriton#49) and is targeting the PyTorch 2.6 release.

(Actually I'm surprised this didn't trigger GPU segfault)
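
Until then, a minimal workaround sketch for this repro (my own illustration, based on the repeat_interleave equivalence used in the script above, not an official API): expand the grouped K/V heads so the standard non-GQA path handles them, at the cost of materializing the repeated tensors.

import torch
from torch.nn.functional import scaled_dot_product_attention

# Illustrative workaround (not an official API): expand grouped K/V heads so the
# standard, non-GQA SDPA path handles them. Mirrors the repeat_interleave trick
# from the repro script above, at the cost of materializing the repeated tensors.
def sdpa_gqa_fallback(query, key, value, is_causal=True):
    groups = query.shape[1] // key.shape[1]   # e.g. 32 query heads / 8 KV heads = 4
    key = key.repeat_interleave(groups, dim=1)
    value = value.repeat_interleave(groups, dim=1)
    return scaled_dot_product_attention(query, key, value, is_causal=is_causal)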

@functionstackx
Contributor Author

(Actually I'm surprised this didn't trigger GPU segfault)
same LOL.

@xinyazhang maybe y'all can add a simple "if ROCm, throw a not-supported error" check to prevent users from using this code path on ROCm until it is supported?
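
As a rough sketch of what such a guard could look like from the Python side (a hypothetical helper, not an existing PyTorch check):

import torch

# Hypothetical user-side guard, not an actual PyTorch API: reject enable_gqa on
# ROCm builds until the fused SDPA backends support it. torch.version.hip is a
# version string on ROCm builds and None on CUDA builds.
def assert_gqa_supported(enable_gqa: bool) -> None:
    if enable_gqa and torch.version.hip is not None:
        raise NotImplementedError(
            "enable_gqa=True is not yet supported by the fused SDPA backends on ROCm")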

@xinyazhang
Collaborator
xinyazhang commented Oct 31, 2024

@xinyazhang maybe y'all can add a simple "if ROCm, throw a not-supported error" check to prevent users from using this code path on ROCm until it is supported?

Certainly we can, but I have concerns about the effectiveness:

  1. nightly/main is considered experimental, and ideally this should be shipped to a point release
  2. There are two upcoming releases, 2.5.2 and 2.6. In 2.6 GQA will be supported and this patch would be removed in January
  3. Hence this patch is only needed for 2.5.2

I need to discuss with others to see if it can make it into 2.5.2. If the answer is yes, you'll see it there (otherwise just wait for 2.6, which is actually not too far from the 2.5.2 release...).

@functionstackx
Contributor Author

nightly/main is considered experimental, and ideally this should be shipped to a point release

Yep, that makes sense in theory, but 2.3/2.4/2.5 are so unusable due to #137414, #138532, #135431 and a lot of other issues that, ironically, nightly/main is the most stable ROCm version in my testing.

@xinyazhang
Collaborator

I need to dig deeper into this case. FA should have been disabled if GQA is used:

#if USE_ROCM
  constexpr bool backend_supports_grouped_query_attention = false;
#else
  constexpr bool backend_supports_grouped_query_attention = true;
#endif
  if (has_only_dense_inputs(params)) {
    constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
        check_batch_size_and_num_heads_dense<backend_supports_grouped_query_attention>,

@functionstackx
Contributor Author

Hmmm interesting. I wonder why I am running into numerics issues if that is disabled

@xinyazhang
Collaborator
xinyazhang commented Nov 1, 2024

Hmmm interesting. I wonder why I am running into numerics issues if that is disabled

Okay, there is one thing we missed:

output_gqa = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

This line uses the MATH backend, because FA is disabled.

output_repeat = scaled_dot_product_attention(query, key, value, is_causal=True)

This one uses the FA backend instead, since enable_gqa is False by default and FA has higher priority.

I tested an updated version of the script on MI300X (no fundamental changes, but it makes everything explicit):

import contextlib
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import sdpa_kernel, SDPBackend

batch = 4
seq_len_q = 1024
seq_len_kv = 1024
D = 128

query = torch.randn(batch, 32, seq_len_q, D, device='cuda', dtype=torch.bfloat16)
key = torch.randn(batch, 8, seq_len_kv, D, device='cuda', dtype=torch.bfloat16)
value = torch.randn(batch, 8, seq_len_kv, D, device='cuda', dtype=torch.bfloat16)

ctxmgr = contextlib.nullcontext()
# ctxmgr = sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION])
# ctxmgr = sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])
# ctxmgr = sdpa_kernel(backends=[SDPBackend.MATH])
# ctxmgr = sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION])

with ctxmgr:
    output_gqa = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
    print(key.data_ptr())
    key = key.repeat_interleave(4,1)
    print(key.data_ptr())
    value = value.repeat_interleave(4,1)
    output_repeat = scaled_dot_product_attention(query, key, value, is_causal=True)

torch.testing.assert_close(output_gqa, output_repeat)

and got the following result:

AssertionError: Tensor-likes are not close!

Mismatched elements: 843398 / 16777216 (5.0%)
Greatest absolute difference: 0.0029296875 at index (1, 14, 8, 110) (up to 1e-05 allowed)
Greatest relative difference: inf at index (1, 18, 1, 43) (up to 0.016 allowed)

The precise numbers differ due to the random seed. However, 0.0029296875 is a reasonable level of numerical error for bf16. In other words, I cannot reproduce your problem.

@OrenLeung can you figure out a seed for torch.manual_seed that reproduces the 100% mismatch and a greatest absolute difference close to 1.0?
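
A sweep along those lines could look like the following sketch (the rtol/atol values just mirror the bf16 defaults shown in the assert_close output above; the loop actually used later in this thread may differ):

import torch
from torch.nn.functional import scaled_dot_product_attention

# Sketch of a seed sweep hunting for a 100% mismatch / ~1.0 max abs difference.
for seed in range(100, 110):
    torch.manual_seed(seed)
    q = torch.randn(4, 32, 1024, 128, device='cuda', dtype=torch.bfloat16)
    k = torch.randn(4, 8, 1024, 128, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(4, 8, 1024, 128, device='cuda', dtype=torch.bfloat16)
    out_gqa = scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
    out_rep = scaled_dot_product_attention(
        q, k.repeat_interleave(4, 1), v.repeat_interleave(4, 1), is_causal=True)
    mismatch = (~torch.isclose(out_gqa, out_rep, rtol=0.016, atol=1e-5)).float().mean()
    max_abs = (out_gqa - out_rep).abs().max()
    print(f"seed {seed}: {mismatch.item():.1%} mismatched, max abs diff {max_abs.item():.6g}")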

@hliuca
hliuca commented Nov 1, 2024

@xinyazhang regarding the one line that was missed: if it is a bug, could you please send a PR to fix it? Thank you.

@xinyazhang
Collaborator

@xinyazhang regarding the one line that was missed: if it is a bug, could you please send a PR to fix it? Thank you.

If it is a bug.
I haven't been able to reproduce the problem yet.

@hliuca
hliuca commented Nov 7, 2024

@OrenLeung I just compiled the latest PyTorch and did a few sweeps. My data show the mismatch is about 7.6%.

I used your code. By the way, what is the mismatch rate on H100? Thanks.

seed: 103
Traceback (most recent call last):
File "/root/te.py", line 29, in
torch.testing.assert_close(output_gqa, output_repeat)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/testing/_comparison.py", line 1530, in assert_close
raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 1278733 / 16777216 (7.6%)
Greatest absolute difference: 0.005859375 at index (2, 0, 2, 42) (up to 1e-05 allowed)
Greatest relative difference: 83968.0 at index (0, 27, 241, 40) (up to 0.016 allowed)
seed: 104
Traceback (most recent call last):
File "/root/te.py", line 29, in
torch.testing.assert_close(output_gqa, output_repeat)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/testing/_comparison.py", line 1530, in assert_close
raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 1278264 / 16777216 (7.6%)
Greatest absolute difference: 0.006500244140625 at index (3, 11, 1, 99) (up to 1e-05 allowed)
Greatest relative difference: 37632.0 at index (3, 25, 43, 20) (up to 0.016 allowed)
seed: 105
Traceback (most recent call last):
File "/root/te.py", line 29, in
torch.testing.assert_close(output_gqa, output_repeat)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/testing/_comparison.py", line 1530, in assert_close
raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 1278453 / 16777216 (7.6%)
Greatest absolute difference: 0.005859375 at index (0, 25, 1, 102) (up to 1e-05 allowed)
Greatest relative difference: inf at index (2, 17, 1, 89) (up to 0.016 allowed)
seed: 106
Traceback (most recent call last):
File "/root/te.py", line 29, in
torch.testing.assert_close(output_gqa, output_repeat)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/testing/_comparison.py", line 1530, in assert_close
raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 1273980 / 16777216 (7.6%)
Greatest absolute difference: 0.00634765625 at index (2, 0, 5, 71) (up to 1e-05 allowed)
Greatest relative difference: 13120.0 at index (1, 31, 822, 11) (up to 0.016 allowed)
seed: 107
Traceback (most recent call last):
File "/root/te.py", line 29, in
torch.testing.assert_close(output_gqa, output_repeat)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/testing/_comparison.py", line 1530, in assert_close
raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 1269015 / 16777216 (7.6%)
Greatest absolute difference: 0.0078125 at index (1, 24, 1, 81) (up to 1e-05 allowed)
Greatest relative difference: 13888.0 at index (0, 6, 234, 83) (up to 0.016 allowed)
seed: 108
Traceback (most recent call last):
File "/root/te.py", line 29, in
torch.testing.assert_close(output_gqa, output_repeat)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/testing/_comparison.py", line 1530, in assert_close
raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 1266750 / 16777216 (7.6%)
Greatest absolute difference: 0.0078125 at index (0, 5, 1, 18) (up to 1e-05 allowed)
Greatest relative difference: 27776.0 at index (3, 20, 299, 13) (up to 0.016 allowed)
seed: 109
Traceback (most recent call last):
File "/root/te.py", line 29, in
torch.testing.assert_close(output_gqa, output_repeat)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/testing/_comparison.py", line 1530, in assert_close
raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 1278990 / 16777216 (7.6%)
Greatest absolute difference: 0.009765625 at index (1, 8, 1, 127) (up to 1e-05 allowed)
Greatest relative difference: 69120.0 at index (2, 25, 766, 35) (up to 0.016 allowed)
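
One way to judge whether a ~7.6% mismatch at roughly 6e-3 absolute difference is ordinary bf16 rounding rather than a kernel bug is to compare both bf16 outputs against an fp32 MATH-backend reference. A minimal sketch of that check (my own illustration, not something run in this thread):

import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import sdpa_kernel, SDPBackend

torch.manual_seed(103)
q = torch.randn(4, 32, 1024, 128, device='cuda', dtype=torch.bfloat16)
k = torch.randn(4, 8, 1024, 128, device='cuda', dtype=torch.bfloat16)
v = torch.randn(4, 8, 1024, 128, device='cuda', dtype=torch.bfloat16)

# fp32 reference computed with the MATH backend and explicitly repeated K/V heads.
with sdpa_kernel(backends=[SDPBackend.MATH]):
    ref = scaled_dot_product_attention(
        q.float(), k.float().repeat_interleave(4, 1), v.float().repeat_interleave(4, 1),
        is_causal=True)

out_gqa = scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
out_rep = scaled_dot_product_attention(
    q, k.repeat_interleave(4, 1), v.repeat_interleave(4, 1), is_causal=True)

# Both bf16 outputs should sit within bf16 rounding distance of the fp32 reference.
print("gqa    vs fp32 ref:", (out_gqa.float() - ref).abs().max().item())
print("repeat vs fp32 ref:", (out_rep.float() - ref).abs().max().item())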

@naromero77amd
Collaborator

I ran this reproducer using upstream PyTorch and it seems to be working.

@OrenLeung @hliuca Can we close this issue?

@naromero77amd
Collaborator

@OrenLeung Just want to follow-up from my last message. Is this issue resolved for you?

@naromero77amd
Collaborator

Is this still an open bug?

@github-project-automation github-project-automation bot moved this from Todo to Done in PyTorch on ROCm May 16, 2025