Softmax Decomp Causes Incorrect Gradients when Using torch.compile with F.multi_head_attention_forward #152309
Comments
I couldn't repro this on A100 or H100.
Can you try again on main, and reopen if this issue still exists?
On the CPU, the correct results can be obtained:
CPU Code
Version
The above results were obtained using torch-nightly, with the version information as follows:
It seems that I don't have permission to reopen the issue.
cc @drisspg can you repro?
Okay, I can repro this by forcing the math backend. And the error is due to the softmax decomp.
@drisspg, any thoughts?
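(For reference, a minimal sketch of what "forcing the math backend" can look like, assuming the torch.nn.attention API and a CUDA device; this is not the exact script used above, and the shapes and seed are assumptions.)

```python
# Hedged sketch: force the math SDPA backend so the softmax decomp path
# is exercised under torch.compile, then compare against eager.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

torch.manual_seed(0)
q = torch.randn(2, 8, 128, 64, device="cuda", requires_grad=True)
k = torch.randn(2, 8, 128, 64, device="cuda", requires_grad=True)
v = torch.randn(2, 8, 128, 64, device="cuda", requires_grad=True)

def attn(q, k, v):
    with sdpa_kernel(SDPBackend.MATH):
        return F.scaled_dot_product_attention(q, k, v)

out_eager = attn(q, k, v)
out_compiled = torch.compile(attn)(q, k, v)
print((out_eager - out_compiled).abs().max().item())
```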
I have a pretty small repro of the softmax decomp causing noticeable differences compared to the eager kernel for
A few things to call out:
(1) The difference in my tiny repro is
(2) All of the inputs/outputs are float32, so there don't seem to be any funky intermediate dtype casting differences going on.
(3) The decomp uses FMA. I tried both the FMA and non-FMA flavor, and I get similar divergences.
(4) How large the difference is depends on the input data distribution. I got a similar data distribution to the MHA example in my repro above.
I spent some time staring at the CUDA kernel but so far I haven't figured out what causes the difference.
cc @ngimel, in case you know
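(An illustrative sketch of the kind of comparison described above: eager torch.softmax vs. a hand-written max-subtract/exp/sum decomposition, both measured against an fp64 reference. The input scale is an assumption, not the exact distribution from the repro.)

```python
# Hedged sketch: compare eager float32 softmax and a decomposed softmax
# against an fp64 reference. The input scale below is an assumption.
import torch

torch.manual_seed(0)
x = torch.randn(64, 1024) * 5  # assumed data distribution

eager = torch.softmax(x, dim=-1)

# Roughly what a softmax decomposition computes: subtract the row max,
# exponentiate, normalize by the row sum.
shifted = x - x.amax(dim=-1, keepdim=True)
decomp = shifted.exp() / shifted.exp().sum(dim=-1, keepdim=True)

ref = torch.softmax(x.double(), dim=-1)  # fp64 "gold" reference
print("eager  vs fp64:", (eager.double() - ref).abs().max().item())
print("decomp vs fp64:", (decomp.double() - ref).abs().max().item())
```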
It might be FMA; the CUDA kernel might or might not produce FMA instructions. Also, torch.sum() and the result produced in the kernel might differ slightly. As you said, the error for the tiny repro is not huge, so it's those small things. Softmax, as we know, is particularly sensitive to FMA vs. no FMA.
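(A tiny illustration of the reduction-order point; the sizes are arbitrary.)

```python
# Hedged sketch: the same float32 values summed in a different order
# give slightly different results, which is enough to perturb softmax.
import torch

torch.manual_seed(0)
x = torch.randn(1_000_000, dtype=torch.float32) * 10

s1 = x.sum()                              # one reduction order
s2 = x.view(1000, 1000).sum(dim=0).sum()  # a different, blocked order
print((s1 - s2).abs().item())             # typically small but nonzero
```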
It didn't appear to be FMA, because I tried with FMA on and off and saw similar differences. Are we ok chalking these differences up to reduction numerics being "different" between eager and compile and living with the (small) difference?
One thing I will say is that there are two independent issues here:
(1) the softmax decomp causing slightly different numerics compared to eager (this can be fully repro'd with
(2) a second issue I found that only repros with inductor, which I could repro with the efficient_attention backend (not math, so no softmax decomp) and which causes larger differences:
I can file a separate issue for that one.
@bdhirsh for your micro repro it's definitely just fp error; when I modify the output to
the output I'm getting is
@ngimel doing the same calculation for the weight grad (instead of the forward out) gives me in the ballpark of
Do you have a standard set of utils that you recommend using for deciding if the "compile vs eager" delta is reasonable or not? I've been going off of
I did this after reading your comment:
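(Illustrative aside: one standard utility for this kind of tolerance check is torch.testing.assert_close with explicit rtol/atol; the tensors below are placeholders, not the actual gradients from this thread.)

```python
# Hedged sketch: torch.testing.assert_close with explicit tolerances.
# The tensors here are placeholders standing in for the eager and
# compiled weight grads.
import torch

torch.manual_seed(0)
grad_eager = torch.randn(512, 512)
grad_compiled = grad_eager + 1e-6 * torch.randn(512, 512)  # tiny assumed delta

# Defaults shown explicitly for float32; raises if the delta is too large.
torch.testing.assert_close(grad_compiled, grad_eager, rtol=1.3e-6, atol=1e-5)
```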
If Natalia / @eellison agree that the relative difference across compile vs eager here is within tolerance, I'll go ahead and close this issue.
@bdhirsh in the torch.compile torchbench test suite we compute the reference in fp64, and then check that the difference (in the cosine-similarity sense) from this gold standard is approximately the same for eager and compile, but at times this can hide real issues. Having a test with perfect accuracy and no false positives is hard.
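(A simplified sketch of that idea, not the actual torchbench utility: compute an fp64 reference and require that compile is not meaningfully farther from it, in the cosine-similarity sense, than eager is. The slack value is an assumption.)

```python
# Hedged sketch of the fp64-gold-standard check (not the torchbench code).
import torch
import torch.nn.functional as F

def cos_to_ref(x: torch.Tensor, ref_fp64: torch.Tensor) -> float:
    # Cosine similarity of a result against the fp64 reference.
    return F.cosine_similarity(x.double().flatten(), ref_fp64.flatten(), dim=0).item()

def compile_within_tolerance(eager: torch.Tensor,
                             compiled: torch.Tensor,
                             ref_fp64: torch.Tensor,
                             slack: float = 1e-4) -> bool:
    # compiled may be at most `slack` less similar to the fp64 reference
    # than eager is; the slack value here is an assumption.
    return cos_to_ref(compiled, ref_fp64) >= cos_to_ref(eager, ref_fp64) - slack
```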
Tentatively closing this issue.
Yeah, we should have checked the fp64 ref here. We do have this note in the torch.compile issue template...
🐛 Describe the bug
When using torch.compile to compile a model that internally calls torch.nn.functional.multi_head_attention_forward, the computed gradients differ significantly from the ones obtained via eager mode.
To Reproduce
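(A hedged sketch of the kind of repro described; shapes, seed, and weight scaling are assumptions, not the original script.)

```python
# Hedged sketch, not the original repro: compare eager vs. torch.compile
# gradients through F.multi_head_attention_forward.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
device = "cuda"
embed_dim, num_heads, seq_len, bsz = 64, 4, 16, 2

# Shared input data (assumed shapes and scaling).
query = torch.randn(seq_len, bsz, embed_dim, device=device)
in_proj_weight = torch.randn(3 * embed_dim, embed_dim, device=device) * 0.1
in_proj_bias = torch.zeros(3 * embed_dim, device=device)
out_proj_weight = torch.randn(embed_dim, embed_dim, device=device) * 0.1
out_proj_bias = torch.zeros(embed_dim, device=device)

def mha(q, in_w, in_b, out_w, out_b):
    out, _ = F.multi_head_attention_forward(
        q, q, q, embed_dim, num_heads,
        in_w, in_b, None, None, False, 0.0,  # no bias_k/bias_v, no dropout
        out_w, out_b,
    )
    return out

def run(fn):
    # Fresh leaf tensors each run so gradients don't accumulate across runs.
    leaves = [t.clone().requires_grad_(True)
              for t in (query, in_proj_weight, in_proj_bias,
                        out_proj_weight, out_proj_bias)]
    fn(*leaves).sum().backward()
    return [leaf.grad for leaf in leaves]

grads_eager = run(mha)
grads_compiled = run(torch.compile(mha))
for g_e, g_c in zip(grads_eager, grads_compiled):
    print((g_e - g_c).abs().max().item())
```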
Output
Versions
PyTorch version: 2.7.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @SherlockNoMad @bdhirsh