[ROCm] sdpa group query attention bf16 numeric error #139352
Comments
Same issue: bf16 LoRA training, backward pass produces NaN.
cc @xinyazhang
@OrenLeung GQA is not supported by the current PyTorch FA/ME backend on ROCm. Support is still under review (ROCm/aotriton#49) and is targeting the PyTorch 2.6 release. (Actually, I'm surprised this didn't trigger a GPU segfault.)
@xinyazhang maybe y'all can add a simple "if ROCm, raise a not-supported error" to prevent users from hitting this code path on ROCm until it is supported?
Certainly we can, but I have concerns about its effectiveness:
I need to discuss with others to see if it can make it into 2.5.2; if the answer is yes, you'll see it there (otherwise just wait for 2.6, which is actually not far from the 2.5.2 release...)
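Until that lands, a minimal user-side workaround sketch (not the in-tree check discussed above) could pin SDPA to the math backend on ROCm builds. This assumes PyTorch ≥ 2.5 for `enable_gqa` and uses `torch.nn.attention.sdpa_kernel`; `gqa_sdpa` is a hypothetical helper name:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

def gqa_sdpa(q, k, v):
    """Grouped-query SDPA that avoids the FA/ME backends on ROCm builds."""
    if torch.version.hip is not None:
        # ROCm build: force the slower but numerically safe math backend
        # until GQA support in the flash-attention backend is released.
        with sdpa_kernel(SDPBackend.MATH):
            return F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
    return F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
```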
I need to dig deeper into this case. FA should have been disabled if GQA is used: pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp, lines 637 to 644 at 6a1c451.
Hmmm, interesting. I wonder why I am running into numeric issues if that is disabled.
Okay, there is one thing we missed:
This line uses the MATH backend, because FA is disabled.
This one uses the FA backend instead, since enable_gqa is False by default and FA has higher priority. I tested an updated version of the script on MI300X (no fundamental changes, but it makes everything explicit)
and got the following result.
The precise numbers differ due to the random seed. However, @OrenLeung, can you figure out a seed for
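Separately, for reference, here is a small sketch of the two call patterns being contrasted (shapes are hypothetical; which backend each call takes is taken from the explanation above, not verified by this snippet):

```python
import torch
import torch.nn.functional as F

# 32 query heads sharing 8 KV heads (grouped-query attention), bf16.
q = torch.randn(1, 32, 128, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 128, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn_like(k)

# enable_gqa=True: per the explanation above, FA is disabled for this case,
# so the math backend handles it.
out_math = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

# Same computation with K/V expanded to 32 heads: enable_gqa stays at its
# default (False) and the higher-priority FA backend becomes eligible.
k_rep = k.repeat_interleave(32 // 8, dim=1)
v_rep = v.repeat_interleave(32 // 8, dim=1)
out_fa = F.scaled_dot_product_attention(q, k_rep, v_rep)
```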
@xinyazhang regarding the one line we missed: if it is a bug, could you please send a PR to fix it? Thank you.
If it is a bug.
@OrenLeung I just compiled the latest PyTorch and did a few sweeps using your code. My data show the mismatch is about 7.6%. By the way, what is the mismatch rate on H100? Thanks.

seed: 103
Mismatched elements: 1278733 / 16777216 (7.6%)
Mismatched elements: 1278264 / 16777216 (7.6%)
Mismatched elements: 1278453 / 16777216 (7.6%)
Mismatched elements: 1273980 / 16777216 (7.6%)
Mismatched elements: 1269015 / 16777216 (7.6%)
Mismatched elements: 1266750 / 16777216 (7.6%)
Mismatched elements: 1278990 / 16777216 (7.6%)
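For context, a mismatch rate like the one reported above can be computed along these lines (tolerances are illustrative, not necessarily the ones used in the sweeps):

```python
import torch

def mismatch_rate(actual: torch.Tensor, expected: torch.Tensor,
                  rtol: float = 1.6e-2, atol: float = 1e-5):
    """Count elements outside tolerance, similar to torch.testing.assert_close's report."""
    bad = ~torch.isclose(actual, expected, rtol=rtol, atol=atol)
    return bad.sum().item(), bad.numel()

# n_bad, n_total = mismatch_rate(out_device.float(), out_reference.float())
# print(f"Mismatched elements: {n_bad} / {n_total} ({100.0 * n_bad / n_total:.1f}%)")
```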
I ran this reproducer using upstream PyTorch and it seems to be working. @OrenLeung @hliuca Can we close this issue?
@OrenLeung Just want to follow up on my last message. Is this issue resolved for you?
Is this still an open bug? |
🐛 Describe the bug
Hi AMD Team,
On MI300X with PyTorch nightly, grouped query attention runs into numeric errors. I have confirmed on H100 that this script does not produce numeric errors.
Could you look into this and potentially add a numeric unit test for it?
cc: @hliuca
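Not the original repro script (see the "Reprod Script" section below); purely as an illustration of the kind of numeric unit test requested, a check might compare the bf16 GQA output and its gradients against an fp32 math-backend reference, with assumed shapes and tolerances:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

torch.manual_seed(0)
B, Hq, Hkv, L, E = 1, 32, 8, 512, 128  # hypothetical GQA shapes
q = torch.randn(B, Hq, L, E, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, Hkv, L, E, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, Hkv, L, E, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Forward/backward under the default backend selection.
out = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
out.sum().backward()
assert not torch.isnan(q.grad).any(), "NaN in query gradients"

# fp32 reference on the math backend for the same inputs.
with sdpa_kernel(SDPBackend.MATH):
    ref = F.scaled_dot_product_attention(
        q.detach().float(), k.detach().float(), v.detach().float(), enable_gqa=True
    )

# Illustrative bf16-level tolerances.
torch.testing.assert_close(out.detach().float(), ref, rtol=1.6e-2, atol=1e-2)
```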
ROCm Error
Reprod Script
Versions
ROCm Versions
H100 Versions
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd