10000 Use FMA for CUDA Adam kernel's lerps. by MeetThePatel · Pull Request #153097 · pytorch/pytorch · GitHub
Use FMA for CUDA Adam kernel's lerps. #153097

Open · wants to merge 2 commits into main
12 changes: 8 additions & 4 deletions aten/src/ATen/native/cuda/fused_adam_utils.cuh
@@ -41,6 +41,8 @@ C10_DEVICE inline void adam_math(
     // Load values.
     opmath_t param = static_cast<opmath_t>(r_args[kParamIdx][ii]);
     opmath_t grad = static_cast<opmath_t>(r_args[kGradIdx][ii]);
+    opmath_t casted_beta1 = static_cast<opmath_t>(beta1);
+    opmath_t casted_beta2 = static_cast<opmath_t>(beta2);
     if (grad_scale_ptr) {
       grad /= (static_cast<double>(*grad_scale_ptr));
     }
@@ -62,10 +64,12 @@ C10_DEVICE inline void adam_math(
         param -= lr * weight_decay * param;
       }
     }
-    // todo(crcrpar): use lerp
-    // ref: https://developer.nvidia.com/blog/lerp-faster-cuda/
-    exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
-    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
+    exp_avg =
+        std::fma(casted_beta1, exp_avg, std::fma(-casted_beta1, grad, grad));
+    exp_avg_sq = std::fma(
+        casted_beta2,
+        exp_avg_sq,
+        std::fma(-casted_beta2, grad * grad, grad * grad));
     const opmath_t step_size = lr / bias_correction1;
     opmath_t denom;
     if (amsgrad) {
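For context (not part of the diff itself): the change rewrites the exponential-moving-average updates using the lerp-via-FMA identity from the referenced NVIDIA blog post, beta * x + (1 - beta) * y == fma(beta, x, fma(-beta, y, y)), which follows from (1 - beta) * y == y - beta * y == fma(-beta, y, y). Below is a minimal host-side sketch of that equivalence; the values are hypothetical and the snippet is only illustrative, not code from the PR.

// Illustrative sketch of the FMA-based lerp used in the PR.
// beta * x + (1 - beta) * y  ==  std::fma(beta, x, std::fma(-beta, y, y))
#include <cmath>
#include <cstdio>

int main() {
  const double beta = 0.9;      // hypothetical beta1
  const double exp_avg = 0.25;  // hypothetical running average
  const double grad = 1.5;      // hypothetical gradient

  // Original formulation: two multiplies plus separate adds, each rounded.
  const double plain = beta * exp_avg + (1.0 - beta) * grad;

  // FMA formulation: two fused multiply-adds, each rounded once,
  // which maps directly onto FMA hardware on the GPU.
  const double fused = std::fma(beta, exp_avg, std::fma(-beta, grad, grad));

  std::printf("plain = %.17g\nfused = %.17g\n", plain, fused);
  return 0;
}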