Use std::fma for CUDA Adam kernel's lerps. by MeetThePatel · Pull Request #153097 · pytorch/pytorch

Open · wants to merge 2 commits into base: main

Conversation

MeetThePatel (Contributor)

Switch the calculation of the lerps in Adam's fused CUDA kernel to use std::fma, as proposed by @crcrpar.
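For context, a minimal sketch of the kind of rewrite being described, assuming the moment updates are the usual Adam lerps; the function name, signature, and opmath_t placeholder below are illustrative, not the kernel's actual code:

#include <cmath>

// Sketch only: the fused Adam kernel updates its moments with lerps of the form
//   exp_avg    = beta1 * exp_avg    + (1 - beta1) * grad
//   exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad
// and each multiply-add can be fused into a single rounding with std::fma.
template <typename opmath_t>
void adam_moment_update(opmath_t& exp_avg, opmath_t& exp_avg_sq,
                        opmath_t grad, opmath_t beta1, opmath_t beta2) {
    // (1 - beta1) * grad is computed as fma(-beta1, grad, grad), then folded into the lerp.
    exp_avg    = std::fma(beta1, exp_avg, std::fma(-beta1, grad, grad));
    exp_avg_sq = std::fma(beta2, exp_avg_sq,
                          std::fma(-beta2, grad * grad, grad * grad));
}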

pytorch-bot bot commented May 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153097

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 2b31134 with merge base 11c64b7:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: cuda release notes category label May 7, 2025
@eqy eqy requested a review from crcrpar May 7, 2025 21:14
@crcrpar (Collaborator) left a comment


Would you have perf numbers of before/after?

@Skylion007 (Collaborator)

Would you have perf numbers of before/after?

I'm not sure this makes much of a perf difference here, TBH; the benefit is that it should be more accurate: https://en.cppreference.com/w/cpp/numeric/math/fma. My question is whether ROCm/CUDA actually maps this to the proper __fma_rn and __fmaf_rn intrinsics. I think so, but I don't see any documentation ensuring it doesn't fall back to software emulation.

@Skylion007 Skylion007 requested a review from ngimel May 8, 2025 14:09
@MeetThePatel (Contributor, Author)

@Skylion007 I'm not sure this is exactly what you're asking, but on Godbolt with NVCC 12.5.1, with the example code:

#include <cmath>

double testing_fma(double a, double b, double c) {
    return std::fma(a, b, c);
}

double testing_manual(double a, double b, double c) {
    return a * b + c;
}

the output for -O3 is:

testing_fma(double, double, double):
        jmp     fma
testing_manual(double, double, double):
        mulsd   %xmm1, %xmm0
        addsd   %xmm2, %xmm0
        ret

and the output for -O0 is:

testing_fma(double, double, double):
        pushq   %rbp
        movq    %rsp, %rbp
        subq    $32, %rsp
        movsd   %xmm0, -8(%rbp)
        movsd   %xmm1, -16(%rbp)
        movsd   %xmm2, -24(%rbp)
        movsd   -24(%rbp), %xmm1
        movsd   -16(%rbp), %xmm0
        movq    -8(%rbp), %rax
        movapd  %xmm1, %xmm2
        movapd  %xmm0, %xmm1
        movq    %rax, %xmm0
        call    fma
        movq    %xmm0, %rax
        movq    %rax, %xmm0
        leave
        ret
testing_manual(double, double, double):
        pushq   %rbp
        movq    %rsp, %rbp
        movsd   %xmm0, -8(%rbp)
        movsd   %xmm1, -16(%rbp)
        movsd   %xmm2, -24(%rbp)
        movsd   -8(%rbp), %xmm0
        mulsd   -16(%rbp), %xmm0
        addsd   -24(%rbp), %xmm0
        movq    %xmm0, %rax
        movq    %rax, %xmm0
        popq    %rbp
        ret

@ngimel (Collaborator) commented May 8, 2025

https://godbolt.org/z/oedrhM1v1 for device code (the above is host code); it maps to fma.rn.f32.
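For reference, a minimal sketch of such a device-side check (not from the PR; the file name and flags are illustrative). Compiling with, e.g., nvcc -arch=sm_80 -ptx fma_check.cu lets you confirm which PTX instruction std::fma lowers to:

#include <cmath>

// Sketch only: device-side analogue of the host example above.
__global__ void testing_fma(const float* a, const float* b, const float* c, float* out) {
    int i = threadIdx.x;
    // Expected to lower to a single fma.rn.f32 (one rounding).
    out[i] = std::fma(a[i], b[i], c[i]);
}

__global__ void testing_manual(const float* a, const float* b, const float* c, float* out) {
    int i = threadIdx.x;
    // nvcc usually contracts this to an FMA as well, unless -fmad=false is passed;
    // the explicit std::fma keeps the single rounding regardless of that flag.
    out[i] = a[i] * b[i] + c[i];
}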

@eqy (Collaborator) left a comment


Probably won't have any unintended numerics consequences, right?

@ngimel (Collaborator) commented May 8, 2025

Shouldn't. Neither the original formula nor the new one recovers the value exactly when grad and moment are equal, but we've been living with that for a long time, so it should be fine.
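To make that concrete, a small host-side sketch (values are illustrative, not from the PR) that compares both formulations against the input when grad and the moment are equal; because (1 - beta1) is itself rounded, either result can differ from x by a rounding error:

#include <cmath>
#include <cstdio>

int main() {
    const float beta1 = 0.9f;
    const float x     = 0.1f;  // grad == exp_avg == x; an exact lerp would return x.

    // Old formulation: two separately rounded products plus an add.
    float original = beta1 * x + (1.0f - beta1) * x;
    // New formulation: the same lerp expressed with fused multiply-adds.
    float fused = std::fma(beta1, x, std::fma(-beta1, x, x));

    std::printf("original - x = %.9g\n", original - x);
    std::printf("fused    - x = %.9g\n", fused - x);
    return 0;
}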

@Skylion007 (Collaborator)

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 13, 2025
@Skylion007 Skylion007 added the ciflow/rocm Trigger "default" config CI on ROCm label May 13, 2025
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 job has failed, first few of them are: trunk / win-vs2022-cuda12.6-py3 / build

Details for Dev Infra team (raised by workflow job)

@ngimel (Collaborator) commented May 13, 2025

The Windows error is real (fma<double, float, float, (int)0>), and we shouldn't be calling fma with these mixed args anyway. @MeetThePatel can you please take a look?

@MeetThePatel (Contributor, Author)

For the three fma's:

  • std::fma(-beta1, grad, grad)
  • std::fma(-beta2, grad * grad, grad * grad)
  • std::fma(beta2, exp_avg_sq, std::fma(-beta2, grad * grad, grad * grad))

do you think it makes more sense to static_cast beta1 and beta2 down to opmath_t, rather than upcasting grad and exp_avg_sq to double?

How much would the loss of precision on beta1 and beta2 be (going from double to opmath_t)?

@ngimel (Collaborator) commented May 13, 2025

Casting betas to opmath is acceptable
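A minimal sketch of that fix (hypothetical function and variable names; the kernel's actual code may differ): casting the betas down keeps every std::fma argument at opmath_t, so the float overload is selected and the mixed double/float instantiation that broke the Windows build is avoided.

#include <cmath>

// Sketch only: betas arrive as double; grad, exp_avg, exp_avg_sq are opmath_t (e.g. float).
template <typename opmath_t>
__device__ void adam_moment_update(opmath_t& exp_avg, opmath_t& exp_avg_sq,
                                   opmath_t grad, double beta1, double beta2) {
    // Cast the betas down so every std::fma argument has the same type.
    const opmath_t b1 = static_cast<opmath_t>(beta1);
    const opmath_t b2 = static_cast<opmath_t>(beta2);

    exp_avg    = std::fma(b1, exp_avg, std::fma(-b1, grad, grad));
    exp_avg_sq = std::fma(b2, exp_avg_sq,
                          std::fma(-b2, grad * grad, grad * grad));
}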

@pytorch-bot pytorch-bot bot removed ciflow/trunk Trigger trunk jobs on your pull request ciflow/rocm Trigger "default" config CI on ROCm labels May 14, 2025
@ngimel (Collaborator) commented May 14, 2025

cc @janeyx99: it looks like we are now doing computations in double in the fused optimizer; I don't think we should.
