Attempt a mixed precision fused adam · pytorch/pytorch@da5b584 · GitHub
Attempt a mixed precision fused adam
ghstack-source-id: 63ec4a7
Pull Request resolved: #147653
1 parent 358d92b commit da5b584
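
The diffs below add mixed-precision pointer/load/store helpers to ForeachFunctors.cuh, comment out the same-dtype fast-path check in FusedAdamKernel.cu, and route dtype-mismatched inputs in fused_adam_impl.cu to a new FusedAdamMathFunctorMP. As a rough sketch of the input combination that now takes the new path (the fp32/bf16 pairing is an assumption read off the functor's hardcoded BFloat16 state slots, not something stated in the commit), a libtorch setup might look like:

#include <torch/torch.h>

int main() {
  // Hypothetical setup, not from this commit: fp32 params/grads paired with
  // bf16 Adam state. This combination fails the old same-dtype fast-path
  // check and is what the new dispatch branch keys on.
  auto opts = torch::TensorOptions().device(torch::kCUDA);
  auto param = torch::randn({1024}, opts.dtype(torch::kFloat32));
  auto grad = torch::randn_like(param);
  auto exp_avg = torch::zeros_like(param, opts.dtype(torch::kBFloat16));
  auto exp_avg_sq = torch::zeros_like(param, opts.dtype(torch::kBFloat16));
  return 0;
}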

6 files changed: +424 -27 lines changed (three of the six diffs appear below)

aten/src/ATen/native/cuda/ForeachFunctors.cuh (+80)
@@ -76,6 +76,34 @@ __device__ bool init_args(
   return all_aligned;
 }
 
+template <
+    int depth,
+    typename param_type,
+    typename grad_type,
+    typename exp_avg_type,
+    typename exp_avg_sq_type>
+__device__ bool init_args_mixed_prec(
+    param_type** param_args,
+    grad_type** grad_args,
+    exp_avg_type** exp_avg_args,
+    exp_avg_sq_type** exp_avg_sq_args,
+    FusedOptimizerTensorListMetadata<depth>& tl,
+    const int64_t chunk_idx,
+    const int64_t chunk_size,
+    const int64_t tensor_loc) {
+  *param_args =
+      (param_type*)tl.addresses[0][tensor_loc] + chunk_idx * chunk_size;
+  *grad_args = (grad_type*)tl.addresses[1][tensor_loc] + chunk_idx * chunk_size;
+  *exp_avg_args =
+      (exp_avg_type*)tl.addresses[2][tensor_loc] + chunk_idx * chunk_size;
+  *exp_avg_sq_args =
+      (exp_avg_sq_type*)tl.addresses[3][tensor_loc] + chunk_idx * chunk_size;
+
+  bool all_aligned = is_aligned(*param_args) && is_aligned(*grad_args) &&
+      is_aligned(*exp_avg_args) && is_aligned(*exp_avg_sq_args);
+  return all_aligned;
+}
+
 template <int depth, typename T>
 __device__ void load_args(
     T r_args[][kILP],
@@ -95,6 +123,43 @@ __device__ void load_args(
   }
 }
 
+template <
+    typename T,
+    typename param_type,
+    typename grad_type,
+    typename exp_avg_type,
+    typename exp_avg_sq_type>
+__device__ void load_args(
+    T r_args[][kILP],
+    const param_type* param_args,
+    const grad_type* grad_args,
+    const exp_avg_type* exp_avg_args,
+    const exp_avg_sq_type* exp_avg_sq_args,
+    const int64_t i_start,
+    const int64_t chunk_size,
+    const int64_t n) {
+#pragma unroll
+  for (int ii = 0; ii < kILP; ii++) {
+    const auto i = i_start + threadIdx.x + ii * blockDim.x;
+    r_args[0][ii] = 0;
+    if (i < n && i < chunk_size) {
+      r_args[0][ii] = static_cast<T>(param_args[i]);
+    }
+    r_args[1][ii] = 0;
+    if (i < n && i < chunk_size) {
+      r_args[1][ii] = static_cast<T>(grad_args[i]);
+    }
+    r_args[2][ii] = 0;
+    if (i < n && i < chunk_size) {
+      r_args[2][ii] = static_cast<T>(exp_avg_args[i]);
+    }
+    r_args[3][ii] = 0;
+    if (i < n && i < chunk_size) {
+      r_args[3][ii] = static_cast<T>(exp_avg_sq_args[i]);
+    }
+  }
+}
+
 template <typename T>
 __device__ void store_args(
     T* dst,
@@ -110,6 +175,21 @@ __device__ void store_args(
   }
 }
 
+template <typename dT, typename sT>
+__device__ void store_args(
+    dT* dst,
+    sT* src,
+    const int64_t i_start,
+    const int64_t chunk_size,
+    const int64_t n) {
+#pragma unroll
+  for (int ii = 0; ii < kILP; ii++) {
+    const int64_t i = i_start + threadIdx.x + ii * blockDim.x;
+    if (i < n && i < chunk_size)
+      dst[i] = static_cast<dT>(src[ii]);
+  }
+}
+
 template <int res_arg_index, typename Op, typename T, typename opmath_t>
 __device__ __forceinline__ void binary_op_scalar(
     T r_args[][kILP],
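
Read together, the three additions suggest a load/widen/compute/narrow/store pattern: init_args_mixed_prec materializes one typed pointer per tensor list, the new load_args overload widens all four lists into a common register type T, and the two-type store_args narrows results back on the way out. The sketch below only illustrates that pattern; the function name, the Adam-math placeholder, and the caller-side plumbing are assumptions, not code from this commit.

// Hypothetical glue (not from this commit): how a mixed-precision functor
// could combine the three new helpers from the diff above.
template <typename param_t, typename grad_t, typename state_t>
__device__ void adam_chunk_mixed_prec(
    FusedOptimizerTensorListMetadata<4>& tl,
    const int64_t chunk_idx,
    const int64_t chunk_size,
    const int64_t tensor_loc,
    const int64_t n) {
  param_t* param_ptr;
  grad_t* grad_ptr;
  state_t* exp_avg_ptr;
  state_t* exp_avg_sq_ptr;
  // One typed pointer per tensor list, offset to this chunk.
  init_args_mixed_prec<4>(
      &param_ptr, &grad_ptr, &exp_avg_ptr, &exp_avg_sq_ptr,
      tl, chunk_idx, chunk_size, tensor_loc);

  float r_args[4][kILP]; // math happens in float regardless of storage dtypes
  for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
       i_start += blockDim.x * kILP) {
    // Widen param/grad/exp_avg/exp_avg_sq into float registers.
    load_args(r_args, param_ptr, grad_ptr, exp_avg_ptr, exp_avg_sq_ptr,
              i_start, chunk_size, n);
    // ... standard Adam math on r_args, as in FusedAdamMathFunctor ...
    // Narrow each result back to its own storage dtype.
    store_args(param_ptr, r_args[0], i_start, chunk_size, n);
    store_args(exp_avg_ptr, r_args[2], i_start, chunk_size, n);
    store_args(exp_avg_sq_ptr, r_args[3], i_start, chunk_size, n);
  }
}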

aten/src/ATen/native/cuda/FusedAdamKernel.cu (+5 -4)
@@ -50,10 +50,11 @@ void _fused_adam_kernel_cuda_(
         grad_scale,
         found_inf);
   } else {
-    TORCH_CHECK(
-        at::native::check_fast_path_restrictions(
-            {params, grads, exp_avgs, exp_avg_sqs}),
-        "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
+    // TORCH_CHECK(
+    //     at::native::check_fast_path_restrictions(
+    //         {params, grads, exp_avgs, exp_avg_sqs}),
+    //     "params, grads, exp_avgs, and exp_avg_sqs must have same dtype,
+    //     device, and layout");
     _fused_adam_cuda_impl_(
         params,
         grads,
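
The commit simply comments the fast-path check out, so the mixed-dtype path currently runs with no precondition at all. A plausible follow-up (hypothetical, not part of this commit) would relax the check rather than drop it: keep the dtype restriction for params and grads, but only require device and shape agreement from the optimizer state. A sketch:

// Hypothetical relaxed precondition (illustration only, not in this commit).
TORCH_CHECK(
    at::native::check_fast_path_restrictions({params, grads}),
    "params and grads must have same dtype, device, and layout");
for (const auto i : c10::irange(params.size())) {
  TORCH_CHECK(
      exp_avgs[i].device() == params[i].device() &&
          exp_avg_sqs[i].device() == params[i].device(),
      "exp_avgs and exp_avg_sqs must be on the same device as params");
  TORCH_CHECK(
      exp_avgs[i].sizes() == params[i].sizes() &&
          exp_avg_sqs[i].sizes() == params[i].sizes(),
      "exp_avgs and exp_avg_sqs must match the shapes of params");
}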

aten/src/ATen/native/cuda/fused_adam_impl.cu (+51 -20)
@@ -31,26 +31,57 @@ void _fused_adam_cuda_impl_(
       found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
   const float* lr_ptr = nullptr;
 
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-      kHalf,
-      kBFloat16,
-      params[0].scalar_type(),
-      "fused_adam_kernel_cuda",
-      [&]() {
-        multi_tensor_apply_for_fused_optimizer<4>(
-            tensor_lists,
-            state_steps,
-            FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ORIGINAL, false>(),
-            lr_ptr, // unused
-            lr,
-            beta1,
-            beta2,
-            weight_decay,
-            eps,
-            maximize,
-            grad_scale_ptr,
-            found_inf_ptr);
-      });
+  if (params[0].scalar_type() != exp_avgs[0].scalar_type()) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        kHalf,
+        kBFloat16,
+        params[0].scalar_type(),
+        "fused_adam_kernel_cuda",
+        [&]() {
+          multi_tensor_apply_for_fused_optimizer<4>(
+              tensor_lists,
+              state_steps,
+              FusedAdamMathFunctorMP<
+                  scalar_t,
+                  float,
+                  float,
+                  BFloat16,
+                  BFloat16,
+                  4,
+                  ADAM_MODE::ORIGINAL,
+                  false>(),
+              lr_ptr, // unused
+              lr,
+              beta1,
+              beta2,
+              weight_decay,
+              eps,
+              maximize,
+              grad_scale_ptr,
+              found_inf_ptr);
+        });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        kHalf,
+        kBFloat16,
+        params[0].scalar_type(),
+        "fused_adam_kernel_cuda",
+        [&]() {
+          multi_tensor_apply_for_fused_optimizer<4>(
+              tensor_lists,
+              state_steps,
+              FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ORIGINAL, false>(),
+              lr_ptr, // unused
+              lr,
+              beta1,
+              beta2,
+              weight_decay,
+              eps,
+              maximize,
+              grad_scale_ptr,
+              found_inf_ptr);
+        });
+  }
 }
 
 // The following overload simply has a Tensor lr
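
FusedAdamMathFunctorMP itself is defined elsewhere in this stack and is not shown on this page. Its template arguments here hardcode BFloat16 in the two state slots, so the likely intent is float params/grads with bf16 Adam state, computed in float; the exact slot meanings are an inference. Under those assumptions, the per-element math would look roughly like the following (all names hypothetical):

// Illustration only: per-element math a mixed-precision Adam functor could
// perform for ADAM_MODE::ORIGINAL. Names and slot meanings are assumptions.
template <typename param_t, typename grad_t, typename state_t>
__device__ void adam_math_mp(
    param_t& param,
    const grad_t& grad,
    state_t& exp_avg,
    state_t& exp_avg_sq,
    const float lr,
    const float beta1,
    const float beta2,
    const float weight_decay,
    const float eps,
    const float bias_correction1,        // 1 - beta1^step, precomputed
    const float bias_correction2_sqrt) { // sqrt(1 - beta2^step), precomputed
  // Widen every operand to float, whatever its storage precision.
  float p = static_cast<float>(param);
  float g = static_cast<float>(grad);
  float m = static_cast<float>(exp_avg);
  float v = static_cast<float>(exp_avg_sq);

  if (weight_decay != 0.f) {
    g += weight_decay * p; // ORIGINAL mode folds the L2 penalty into the grad
  }
  m = beta1 * m + (1.f - beta1) * g;
  v = beta2 * v + (1.f - beta2) * g * g;
  const float denom = sqrtf(v) / bias_correction2_sqrt + eps;
  p -= lr * (m / bias_correction1) / denom;

  // Narrow back to each list's storage dtype on the way out.
  param = static_cast<param_t>(p);
  exp_avg = static_cast<state_t>(m);
  exp_avg_sq = static_cast<state_t>(v);
}

Keeping the accumulation in float while storing exp_avg/exp_avg_sq in bf16 is what makes the memory savings possible, at the cost of rounding the optimizer state on every step.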
