Update on "Attempt a mixed precision fused adam" · pytorch/pytorch@6a74269 · GitHub


Commit 6a74269

Update on "Attempt a mixed precision fused adam"
[ghstack-poisoned]
1 parent 6815585 commit 6a74269

File tree

6 files changed: +100 -75 lines changed

aten/src/ATen/native/ForeachUtils.h
Lines changed: 35 additions & 35 deletions

@@ -326,41 +326,41 @@ inline FlatMap _group_tensors_by_first_tensors_device_and_dtype(
           "-th Tensor is not.");
       return {t->device(), t->scalar_type()};
     }();
-    TORCH_CHECK(
-        std::all_of(
-            nested_tensorlist.cbegin(),
-            nested_tensorlist.cend(),
-            [&](const auto& tensorlist) -> bool {
-              if (tensorlist.size() == 0) {
-                return true;
-              }
-              const auto& tensor = tensorlist[tensor_index];
-              // note(crcrpar): Currently the scope of this function is
-              // optimizers so there could be `state_steps` and other scalars
-              // whose elements are float tensors no matter what the parameter's
-              // dtype is.
-              if (!tensor.has_value()) {
-                return true;
-              } else {
-                const auto s = tensor->scalar_type();
-                const auto d = tensor->device();
-                // Note: `step` or `state_step` is float32 by default.
-                if (key.first == d) {
-                  return key.second == s || s == at::ScalarType::Float ||
-                      s == at::ScalarType::Double;
-                } else if (d.is_cpu()) {
-                  // note(crcrpar): There are some test cases (e.g.
-                  // TestOptim::test_adam) where state_steps are on CPU and the
-                  // others are on CUDA. Currently a state_step Tensor has the
-                  // dtype of float.
-                  return s == at::ScalarType::Float ||
-                      s == at::ScalarType::Double;
-                } else {
-                  return false;
-                }
-              }
-            }),
-        "Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding");
+    // TORCH_CHECK(
+    //     std::all_of(
+    //         nested_tensorlist.cbegin(),
+    //         nested_tensorlist.cend(),
+    //         [&](const auto& tensorlist) -> bool {
+    //           if (tensorlist.size() == 0) {
+    //             return true;
+    //           }
+    //           const auto& tensor = tensorlist[tensor_index];
+    //           // note(crcrpar): Currently the scope of this function is
+    //           // optimizers so there could be `state_steps` and other scalars
+    //           // whose elements are float tensors no matter what the parameter's
+    //           // dtype is.
+    //           if (!tensor.has_value()) {
+    //             return true;
+    //           } else {
+    //             const auto s = tensor->scalar_type();
+    //             const auto d = tensor->device();
+    //             // Note: `step` or `state_step` is float32 by default.
+    //             if (key.first == d) {
+    //               return key.second == s || s == at::ScalarType::Float ||
+    //                   s == at::ScalarType::Double;
+    //             } else if (d.is_cpu()) {
+    //               // note(crcrpar): There are some test cases (e.g.
+    //               // TestOptim::test_adam) where state_steps are on CPU and the
+    //               // others are on CUDA. Currently a state_step Tensor has the
+    //               // dtype of float.
+    //               return s == at::ScalarType::Float ||
+    //                   s == at::ScalarType::Double;
+    //             } else {
+    //               return false;
+    //             }
+    //           }
+    //         }),
+    //     "Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding");
     if (!grouped_tensors_with_indices.count(key)) {
       grouped_tensors_with_indices.insert(
           {key,
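
Note: below is a minimal Python sketch (not part of the commit; the helper name is
illustrative) of the invariant the now-commented-out TORCH_CHECK enforced: tensors
at the same index must share the first list's device and dtype, except fp32/fp64
scalars such as `state_steps`, which may also live on CPU.

# Sketch only: mirrors the grouping rule relaxed by this commit.
from typing import Optional
import torch

def grouping_key_accepts(
    key_device: torch.device,
    key_dtype: torch.dtype,
    tensor: Optional[torch.Tensor],
) -> bool:
    if tensor is None:
        return True  # undefined optional entries (e.g. empty max_exp_avg_sqs) pass
    s, d = tensor.dtype, tensor.device
    if d == key_device:
        # exact dtype match, or a float32/float64 scalar such as a step counter
        return s == key_dtype or s in (torch.float32, torch.float64)
    if d.type == "cpu":
        # state_steps may sit on CPU while the other tensors are on CUDA
        return s in (torch.float32, torch.float64)
    return False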

aten/src/ATen/native/cuda/FusedAdamWKernel.cu
Lines changed: 5 additions & 4 deletions

@@ -50,10 +50,11 @@ void _fused_adamw_kernel_cuda_(
         grad_scale,
         found_inf);
   } else {
-    TORCH_CHECK(
-        at::native::check_fast_path_restrictions(
-            {params, grads, exp_avgs, exp_avg_sqs}),
-        "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
+    // TORCH_CHECK(
+    //     at::native::check_fast_path_restrictions(
+    //         {params, grads, exp_avgs, exp_avg_sqs}),
+    //     "params, grads, exp_avgs, and exp_avg_sqs must have same dtype,
+    //     device, and layout");
     _fused_adamw_cuda_impl_(
         params,
         grads,

aten/src/ATen/native/cuda/fused_adam_impl.cu
Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ void _fused_adam_cuda_impl_(
       kHalf,
       kBFloat16,
       params[0].scalar_type(),
-      "fused_adam_kernel_cuda",
+      "fused_adam_mp_kernel_cuda",
       [&]() {
         multi_tensor_apply_for_fused_optimizer<4>(
             tensor_lists,

aten/src/ATen/native/cuda/fused_adamw_impl.cu
Lines changed: 51 additions & 20 deletions

@@ -31,26 +31,57 @@ void _fused_adamw_cuda_impl_(
       found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
   const float* lr_ptr = nullptr;

-  AT_DISPATCH_FLOATING_TYPES_AND2(
-      kHalf,
-      kBFloat16,
-      params[0].scalar_type(),
-      "fused_adamw_kernel_cuda",
-      [&]() {
-        multi_tensor_apply_for_fused_optimizer<4>(
-            tensor_lists,
-            state_steps,
-            FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ADAMW, false>(),
-            lr_ptr, // unused
-            lr,
-            beta1,
-            beta2,
-            weight_decay,
-            eps,
-            maximize,
-            grad_scale_ptr,
-            found_inf_ptr);
-      });
+  if (params[0].scalar_type() != exp_avgs[0].scalar_type()) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        kHalf,
+        kBFloat16,
+        params[0].scalar_type(),
+        "fused_adamw_kernel_cuda",
+        [&]() {
+          multi_tensor_apply_for_fused_optimizer<4>(
+              tensor_lists,
+              state_steps,
+              FusedAdamMathFunctorMP<
+                  scalar_t,
+                  float,
+                  float,
+                  BFloat16,
+                  BFloat16,
+                  4,
+                  ADAM_MODE::ADAMW,
+                  false>(),
+              lr_ptr, // unused
+              lr,
+              beta1,
+              beta2,
+              weight_decay,
+              eps,
+              maximize,
+              grad_scale_ptr,
+              found_inf_ptr);
+        });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        kHalf,
+        kBFloat16,
+        params[0].scalar_type(),
+        "fused_adamw_kernel_cuda",
+        [&]() {
+          multi_tensor_apply_for_fused_optimizer<4>(
+              tensor_lists,
+              state_steps,
+              FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ADAMW, false>(),
+              lr_ptr, // unused
+              lr,
+              beta1,
+              beta2,
+              weight_decay,
+              eps,
+              maximize,
+              grad_scale_ptr,
+              found_inf_ptr);
+        });
+  }
 }

 // The following overload simply has a Tensor lr
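
Note: the new branch keys purely on a dtype mismatch between params[0] and
exp_avgs[0], and the FusedAdamMathFunctorMP template arguments above pair the
param dtype with bfloat16 moments. The sketch below is one hypothetical way such
a mismatch could arise; casting materialized state like this is an assumption,
not a workflow shown in this commit.

# Hypothetical sketch: fp32 params with Adam moments cast to bf16 produce the
# params-vs-exp_avgs dtype mismatch that selects the FusedAdamMathFunctorMP branch.
import torch

model = torch.nn.Linear(16, 16, device="cuda")  # fp32 params
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)

model(torch.randn(4, 16, device="cuda")).sum().backward()
opt.step()  # materializes exp_avg / exp_avg_sq in fp32

for state in opt.state.values():
    # assumption: moments stored in bf16 to save memory
    state["exp_avg"] = state["exp_avg"].to(torch.bfloat16)
    state["exp_avg_sq"] = state["exp_avg_sq"].to(torch.bfloat16)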

test/test_optim.py
Lines changed: 4 additions & 5 deletions

@@ -2237,17 +2237,16 @@ def test_non_empty_state(self, device, dtype, optim_info):

     @onlyCUDA
     @optims(
-        [o for o in optim_db if o.optim_cls.__name__ == "Adam"], dtypes=[torch.float32]
+        [o for o in optim_db if o.optim_cls.__name__ in ["Adam", "AdamW"]],
+        dtypes=[torch.float32],
     )
-    def test_bf16_fused_adam(self, device, dtype, optim_info):
+    def test_bf16_fused(self, device, dtype, optim_info):
         optim_inputs = optim_info.optim_inputs_func(device=device, dtype=dtype)
         optim_cls = optim_info.optim_cls
         for optim_input in optim_inputs:
             kwargs = optim_input.kwargs
             # currently not supported
-            if kwargs.get("amsgrad", False) or kwargs.get(
-                "decoupled_weight_decay", False
-            ):
+            if kwargs.get("amsgrad", False):
                 continue
             kwargs["fused"] = True
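
Note: a rough standalone sketch (assumptions only, not the test body) of the path
test_bf16_fused now exercises for both Adam and AdamW: bf16 parameters stepped
through the fused CUDA kernel.

# Sketch, assuming a CUDA device is available.
import torch

param = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16, requires_grad=True)
opt = torch.optim.AdamW([param], lr=1e-2, fused=True)

param.grad = torch.randn_like(param)
opt.step()
opt.zero_grad()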

torch/optim/adam.py
Lines changed: 4 additions & 10 deletions

@@ -837,16 +837,10 @@ def _fused_adam(
     lr_dict: Optional[DeviceDict] = (
         {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
     )
-    # grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
-    #     [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
-    # )
-    # replace this with better implementation
-    grouped_tensors = {
-        (params[0].device, params[0].dtype): (
-            (params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps),
-            None,
-        )
-    }
+    # TODO: currently the check that the state are properly correspondent to their param dtype + device is removed!!!!
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
+    )
     for (device, _), (
         (
             device_params_,
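
Note: the restored call hands grouping back to the optimizer utility. The sketch
below shows how that grouping is typically consumed, inferred from the unpacking
in the loop above; the helper is private, so its exact return shape is treated as
an assumption here.

# Sketch: assumes the private helper returns a dict keyed by (device, dtype) whose
# values unpack to ((grouped tensor lists...), indices).
import torch
from torch.optim.optimizer import Optimizer

params = [
    torch.zeros(4, device="cuda", dtype=torch.bfloat16),
    torch.zeros(4, device="cuda", dtype=torch.float32),
]
grads = [torch.ones_like(p) for p in params]

grouped = Optimizer._group_tensors_by_device_and_dtype([params, grads])
for (device, dtype), ((g_params, g_grads), _) in grouped.items():
    # one homogeneous group per (device, dtype), i.e. one fused kernel launch each
    print(device, dtype, len(g_params), len(g_grads))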
