Use std::fma for CUDA Adam kernel's lerps. #153097
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153097
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 2b31134 with merge base 11c64b7:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Would you have before/after perf numbers?
I am not sure this brings much of a perf benefit here, to be honest; the real win is that it should be more accurate: https://en.cppreference.com/w/cpp/numeric/math/fma. My question is whether ROCm/CUDA actually maps this to the proper __fma_rn / __fmaf_rn intrinsics. I think so, but I don't see any documentation guaranteeing it doesn't fall back to CPU-like emulation.
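As an aside, the accuracy argument can be made concrete with a tiny host-side example (my own illustration, not part of the PR): a plain multiply-add rounds twice, while fma rounds once.

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // With x = 1 + 2^-30, the exact square is 1 + 2^-29 + 2^-60. The plain
  // multiply rounds the 2^-60 term away before the subtraction, while
  // std::fma computes x * x - 1 with a single rounding and keeps it.
  double x = 1.0 + std::ldexp(1.0, -30);
  double manual = x * x - 1.0;           // two roundings
  double fused  = std::fma(x, x, -1.0);  // one rounding
  std::printf("manual: %.17g\nfma:    %.17g\n", manual, fused);
  return 0;
}
```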
@Skylion007 I'm not sure if this is quite exactly what you're asking, but on Godbolt with NVCC 12.5.1 and the example code:

```cpp
double testing_fma(double a, double b, double c) {
    return std::fma(a, b, c);
}

double testing_manual(double a, double b, double c) {
    return a * b + c;
}
```

the optimized output is:

```asm
testing_fma(double, double, double):
jmp fma
testing_manual(double, double, double):
mulsd %xmm1, %xmm0
addsd %xmm2, %xmm0
ret
```

and the unoptimized output is:

```asm
testing_fma(double, double, double):
pushq %rbp
movq %rsp, %rbp
subq $32, %rsp
movsd %xmm0, -8(%rbp)
movsd %xmm1, -16(%rbp)
movsd %xmm2, -24(%rbp)
movsd -24(%rbp), %xmm1
movsd -16(%rbp), %xmm0
movq -8(%rbp), %rax
movapd %xmm1, %xmm2
movapd %xmm0, %xmm1
movq %rax, %xmm0
call fma
movq %xmm0, %rax
movq %rax, %xmm0
leave
ret
testing_manual(double, double, double):
pushq %rbp
movq %rsp, %rbp
movsd %xmm0, -8(%rbp)
movsd %xmm1, -16(%rbp)
movsd %xmm2, -24(%rbp)
movsd -8(%rbp), %xmm0
mulsd -16(%rbp), %xmm0
addsd -24(%rbp), %xmm0
movq %xmm0, %rax
movq %rax, %xmm0
popq %rbp
ret
```
https://godbolt.org/z/oedrhM1v1 for device code (the above is host code); there it maps to fma.rn.f32.
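For reference, a device-side analogue of the experiment above (my own sketch of roughly what the Godbolt link compiles, not a copy of it):

```cuda
#include <cmath>

// nvcc is expected to emit a single fma.rn.f32 for the std::fma call; the manual
// form may also be contracted into an FMA depending on --fmad (on by default).
__device__ float testing_fma_device(float a, float b, float c) {
  return std::fma(a, b, c);
}

__device__ float testing_manual_device(float a, float b, float c) {
  return a * b + c;
}
```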
This probably won't have any unintended numerics consequences, right?
It shouldn't. Neither the original formula nor the new one recovers the value when grad and moment are equal, but we've been living with that for a long time, so it should be fine.
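For context on the "doesn't recover the value" remark, the usual fma-based lerp looks like the sketch below (an assumption about the pattern being used, not the PR's exact code). Even when v0 == v1, it evaluates v0*(1 - t) and t*v0 with separate roundings, so the result is not guaranteed to equal v0 bit-for-bit.

```cuda
#include <cmath>

// Two-FMA lerp pattern; the names are illustrative.
// lerp(v0, v1, t) = v0 + t * (v1 - v0), rewritten as
//   fma(t, v1, fma(-t, v0, v0)) = t*v1 + (v0 - t*v0),
// which rounds each fma once but can still be off by an ulp when v0 == v1.
template <typename T>
__host__ __device__ T lerp_fma(T v0, T v1, T t) {
  return std::fma(t, v1, std::fma(-t, v0, v0));
}
```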
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed
Reason: 1 job has failed, first few of them are: trunk / win-vs2022-cuda12.6-py3 / build
Details for Dev Infra team: raised by workflow job
Windows error is real
For the three fma's: do you think it makes more sense to static_cast the betas to the opmath type? How much would the loss of precision on the betas matter?
Casting betas to opmath is acceptable
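In code, the suggestion amounts to something like the following (an illustrative sketch with made-up names; "opmath" here is the kernel's computation type). Without the cast, the mixed float/double arguments make std::fma resolve to the double overload, so the update runs in double precision.

```cuda
#include <cmath>

// Sketch only: cast the double-typed beta down to the computation type once,
// so the fused multiply-add stays in float (fma.rn.f32) instead of being
// promoted to double (fma.rn.f64).
__device__ float moment_update(float m, float g, double beta1) {
  const float b = static_cast<float>(beta1);  // "cast betas to opmath"
  return std::fma(b, m, (1.0f - b) * g);      // beta1*m + (1 - beta1)*g, in float
}
```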
Compare 3cf034b to 2b31134
cc @janeyx99: it looks like we are now doing computations in double in the fused optimizer; I don't think we should.
Switch the calculation of lerps in Adam's fused CUDA kernel to use std::fma, as proposed by @crcrpar.
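Putting the thread's points together, the change amounts to something like the sketch below: an fma-based lerp for the moment update, with the betas cast to the computation type. All names (opmath_t, exp_avg, grad, beta1) are placeholders, not the actual kernel code.

```cuda
#include <cmath>

// Illustrative sketch of an Adam first-moment update using std::fma.
template <typename opmath_t>
__device__ opmath_t adam_moment_update(opmath_t exp_avg, opmath_t grad, double beta1) {
  const opmath_t b = static_cast<opmath_t>(beta1);  // keep the math in opmath_t
  // lerp(exp_avg, grad, 1 - beta1) == beta1*exp_avg + (1 - beta1)*grad,
  // written as two fused multiply-adds:
  return std::fma(b, exp_avg, std::fma(-b, grad, grad));
}
```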