Use opmath_t and not double compute in fused optimizers #153649
Where in the call stack does it make the most sense to do the casting? For example, in Adam, which of the following would be the best place to cast:
void _fused_adam_cuda_impl_(
at::TensorList params,
at::TensorList grads,
at::TensorList exp_avgs,
at::TensorList exp_avg_sqs,
at::TensorList state_steps,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool maximize,
const std::optional<at::Tensor>& grad_scale,
const std::optional<at::Tensor>& found_inf) {
std::vector<std::vector<at::Tensor>> tensor_lists{
params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};
const float* grad_scale_ptr =
grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
const float* found_inf_ptr =
found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
const float* lr_ptr = nullptr;
AT_DISPATCH_FLOATING_TYPES_AND2(
kHalf,
kBFloat16,
params[0].scalar_type(),
"fused_adam_kernel_cuda",
[&]() {
multi_tensor_apply_for_fused_optimizer<4>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ORIGINAL, false>(),
lr_ptr, // unused
/// CAST HERE <-----------------------------
/// using opmath_t = at::opmath_type<scalar_type>;
/// static_cast<opmath_t>(lr),
lr,
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
}
template <typename scalar_type, int depth, ADAM_MODE adam_mode, bool amsgrad>
struct FusedAdamMathFunctor {
static_assert(
depth == 4 || depth == 5,
"depth of 4 for Adam, depth of 5 for Adam with AMSGrad.");
using opmath_t = at::opmath_type<scalar_type>;
C10_DEVICE __forceinline__ void operator()(
int chunk_size,
FusedOptimizerTensorListMetadata<depth>& tl,
const float* lr_ptr,
const double& lr,
const double& beta1,
const double& beta2,
const double& weight_decay,
const double& eps,
const bool& maximize,
const float* grad_scale_ptr,
const float* found_inf_ptr) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
const double lr_double = lr_ptr ? *lr_ptr : lr;
if (found_inf_ptr && *found_inf_ptr == 1) {
return;
}
const auto [bias_correction1, bias_correction2_sqrt] =
[&]() -> std::pair<double, double> {
auto* step_count =
reinterpret_cast<const float*>(tl.state_steps_addresses[tensor_loc]);
const auto bias_correction1 = 1 - at::native::pow_(beta1, *step_count);
const auto bias_correction2 = 1 - at::native::pow_(beta2, *step_count);
const auto bias_correction2_sqrt = std::sqrt(bias_correction2);
return {bias_correction1, bias_correction2_sqrt};
}();
scalar_type* args[depth];
scalar_type r_args[depth][kILP];
const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size;
const bool all_aligned{
init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc)};
if ((n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned) {
for (int64_t i_start = threadIdx.x;
i_start * kILP < n && i_start * kILP < chunk_size;
i_start += blockDim.x) {
#pragma unroll
for (int i = 0; i < depth; i++) {
load_store(r_args[i], args[i], 0, i_start);
}
adam_math<scalar_type, opmath_t, depth, adam_mode, amsgrad>(
r_args,
/// CAST HERE <-----------------------------
lr_double,
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr,
bias_correction1,
bias_correction2_sqrt);
#pragma unroll
for (int i = 0; i < depth; i++) {
if (i != kGradIdx || grad_scale_ptr) {
load_store(args[i], r_args[i], i_start, 0);
}
}
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * kILP) {
load_args<depth>(r_args, args, i_start, chunk_size, n);
adam_math<scalar_type, opmath_t, depth, adam_mode, amsgrad>(
r_args,
lr_double,
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr,
bias_correction1,
bias_correction2_sqrt);
#pragma unroll
for (int i = 0; i < depth; i++) {
if (i != kGradIdx || grad_scale_ptr) {
store_args(args[i], r_args[i], i_start, chunk_size, n);
}
}
}
}
}
};
template <
typename scalar_type,
typename opmath_t,
int depth,
ADAM_MODE adam_mode,
bool amsgrad>
C10_DEVICE inline void adam_math(
scalar_type r_args[depth][kILP],
const double& lr,
const double& beta1,
const double& beta2,
const double& weight_decay,
const double& eps,
const bool& maximize,
const float* grad_scale_ptr,
const float* found_inf_ptr,
const opmath_t& bias_correction1,
const opmath_t& bias_correction2_sqrt) {
static_assert(depth == 4 || depth == 5);
#pragma unroll
for (int ii = 0; ii < kILP; ii++) {
// Load values.
opmath_t param = static_cast<opmath_t>(r_args[kParamIdx][ii]);
opmath_t grad = static_cast<opmath_t>(r_args[kGradIdx][ii]);
opmath_t casted_beta1 = static_cast<opmath_t>(beta1);
opmath_t casted_beta2 = static_cast<opmath_t>(beta2);
if (grad_scale_ptr) {
grad /= (static_cast<double>(*grad_scale_ptr));
}
const opmath_t grad_to_store = grad;
if (maximize) {
grad = -grad;
}
opmath_t exp_avg = static_cast<opmath_t>(r_args[kExpAvgIdx][ii]);
opmath_t exp_avg_sq = static_cast<opmath_t>(r_args[kExpAvgSqIdx][ii]);
opmath_t max_exp_avg_sq;
if (amsgrad) {
max_exp_avg_sq = static_cast<opmath_t>(r_args[kMaxExpAvgSqIdx][ii]);
}
// Update param, grad, 1st and 2nd order momentum.
if (weight_decay != 0) { /// CAST HERE <-----------------------------
if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
grad += param * weight_decay;
} else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
param -= lr * weight_decay * param;
}
}
exp_avg =
std::fma(casted_beta1, exp_avg, std::fma(-casted_beta1, grad, grad));
exp_avg_sq = std::fma(
casted_beta2,
exp_avg_sq,
std::fma(-casted_beta2, grad * grad, grad * grad));
const opmath_t step_size = lr / bias_correction1;
opmath_t denom;
if (amsgrad) {
max_exp_avg_sq = std::max(max_exp_avg_sq, exp_avg_sq);
denom = (std::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps;
} else {
denom = (std::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps;
}
param -= step_size * exp_avg / denom;
// Store results.
r_args[kParamIdx][ii] = param;
if (grad_scale_ptr) {
r_args[kGradIdx][ii] = grad_to_store;
}
r_args[kExpAvgIdx][ii] = exp_avg;
r_args[kExpAvgSqIdx][ii] = exp_avg_sq;
if (amsgrad) {
r_args[kMaxExpAvgSqIdx][ii] = max_exp_avg_sq;
}
}
}
@MeetThePatel After talking to @ngimel, we think computing in double was a mistake from the start, so we should cast everything as early as possible (so option 1)! @ngimel, would it also save us the slightest bit of perf to not send doubles over to the kernel?
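For concreteness, here is a minimal sketch of what "option 1" could look like: the double hyperparameters are narrowed to opmath_t inside the dispatch lambda, before they ever reach multi_tensor_apply_for_fused_optimizer. This assumes the FusedAdamMathFunctor and adam_math signatures are updated to accept opmath_t (today they take const double&), so it is illustrative rather than a drop-in patch; all the surrounding names come from the code quoted above.

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_adam_kernel_cuda",
      [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        multi_tensor_apply_for_fused_optimizer<4>(
            tensor_lists,
            state_steps,
            FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ORIGINAL, false>(),
            lr_ptr, // unused
            // Narrow once here; for half/bfloat16 params opmath_t is float,
            // so no double is handed to the kernel.
            static_cast<opmath_t>(lr),
            static_cast<opmath_t>(beta1),
            static_cast<opmath_t>(beta2),
            static_cast<opmath_t>(weight_decay),
            static_cast<opmath_t>(eps),
            maximize,
            grad_scale_ptr,
            found_inf_ptr);
      });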
To be clear, all three options here result in the same computed values; this is purely a readability issue.
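On the question of not sending doubles to the kernel: as a small compile-time reference (not part of this issue's code, just an illustration), at::opmath_type resolves as follows for the dtypes dispatched above, i.e. this is what the scalar arguments would be narrowed to instead of double.

#include <ATen/OpMathType.h>
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>
#include <type_traits>

// Reduced-precision parameter dtypes map to float, so an early cast means the
// kernel receives float scalars; full-precision dtypes keep their own type.
static_assert(std::is_same_v<at::opmath_type<at::Half>, float>);
static_assert(std::is_same_v<at::opmath_type<at::BFloat16>, float>);
static_assert(std::is_same_v<at::opmath_type<float>, float>);
static_assert(std::is_same_v<at::opmath_type<double>, double>);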
Look into and fix #153097 (comment)
cc @msaroufim @jerryzh168 @vincentqb @jbschlosser @albanD @crcrpar @ptrblck @eqy