POC for mixed prec optim frontend · pytorch/pytorch@010cb06 · GitHub

Commit 010cb06

POC for mixed prec optim frontend

ghstack-source-id: 251b9c3
Pull Request resolved: #146640

1 parent 99dd846 · commit 010cb06

File tree: 3 files changed, +113 −11 lines

torch/_meta_registrations.py

Lines changed: 4 additions & 4 deletions

@@ -7060,10 +7060,10 @@ def _fn(self, *args, **kwargs):
 @register_meta(aten.lerp)
 @out_wrapper()
 def lerp(start, end, weight):
-    torch._check(
-        start.dtype == end.dtype,
-        lambda: f"expected dtype {start.dtype} for `end`, but got dtype {end.dtype}",
-    )
+    # torch._check(
+    #     start.dtype == end.dtype,
+    #     lambda: f"expected dtype {start.dtype} for `end`, but got dtype {end.dtype}",
+    # )
     args = [start, end]
     if isinstance(weight, TensorLike):
         if weight.ndim != 0:
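The guard that this change comments out is torch._check, which raises when its predicate is false and builds its message lazily; with it disabled, the meta kernel for aten.lerp no longer rejects mismatched start/end dtypes at trace time. A small standalone illustration of the helper's behavior (the tensors and dtypes below are made up for this example and are not part of the commit):

import torch

start = torch.zeros(2, dtype=torch.float32)
end = torch.ones(2, dtype=torch.bfloat16)

# torch._check(cond, message_fn) raises RuntimeError when cond is False;
# the message callable is only evaluated on failure.
try:
    torch._check(
        start.dtype == end.dtype,
        lambda: f"expected dtype {start.dtype} for `end`, but got dtype {end.dtype}",
    )
except RuntimeError as err:
    print(err)  # expected dtype torch.float32 for `end`, but got dtype torch.bfloat16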

torch/optim/adam.py

Lines changed: 41 additions & 6 deletions

@@ -166,24 +166,46 @@ def _init_group(
                     state["step"] = (
                         torch.zeros(
                             (),
-                            dtype=_get_scalar_dtype(is_fused=group["fused"]),
+                            dtype=(
+                                _get_scalar_dtype(is_fused=group["fused"])
+                                if "step" not in self._dtype_policy
+                                else self._dtype_policy["step"](p)
+                            ),
                             device=p.device,
                         )
                         if group["capturable"] or group["fused"]
                         else torch.tensor(0.0, dtype=_get_scalar_dtype())
                     )
                     # Exponential moving average of gradient values
                     state["exp_avg"] = torch.zeros_like(
-                        p, memory_format=torch.preserve_format
+                        p,
+                        dtype=(
+                            p.dtype
+                            if "exp_avg" not in self._dtype_policy
+                            else self._dtype_policy["exp_avg"](p)
+                        ),
+                        memory_format=torch.preserve_format,
                     )
                     # Exponential moving average of squared gradient values
                     state["exp_avg_sq"] = torch.zeros_like(
-                        p, memory_format=torch.preserve_format
+                        p,
+                        dtype=(
+                            p.dtype
+                            if "exp_avg_sq" not in self._dtype_policy
+                            else self._dtype_policy["exp_avg_sq"](p)
+                        ),
+                        memory_format=torch.preserve_format,
                     )
                     if group["amsgrad"]:
                         # Maintains max of all exp. moving avg. of sq. grad. values
                         state["max_exp_avg_sq"] = torch.zeros_like(
-                            p, memory_format=torch.preserve_format
+                            p,
+                            dtype=(
+                                p.dtype
+                                if "max_exp_avg_sq" not in self._dtype_policy
+                                else self._dtype_policy["max_exp_avg_sq"](p)
+                            ),
+                            memory_format=torch.preserve_format,
                         )
 
                 exp_avgs.append(state["exp_avg"])

@@ -384,8 +406,16 @@ def _single_tensor_adam(
 
     for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
-        exp_avg = exp_avgs[i]
-        exp_avg_sq = exp_avg_sqs[i]
+        exp_avg = (
+            exp_avgs[i]
+            if exp_avgs[i].dtype == grad.dtype
+            else exp_avgs[i].to(grad.dtype)
+        )
+        exp_avg_sq = (
+            exp_avg_sqs[i]
+            if exp_avg_sqs[i].dtype == grad.dtype
+            else exp_avg_sqs[i].to(grad.dtype)
+        )
         step_t = state_steps[i]
 
         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]

@@ -530,6 +560,11 @@ def _single_tensor_adam(
         if amsgrad and torch.is_complex(params[i]):
             max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i])
 
+        if exp_avgs[i].dtype != exp_avg.dtype:
+            exp_avgs[i].copy_(exp_avg)
+        if exp_avg_sqs[i].dtype != exp_avg_sq.dtype:
+            exp_avg_sqs[i].copy_(exp_avg_sq)
+
 
 def _multi_tensor_adam(
     params: list[Tensor],
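The single-tensor path above keeps the persistent state in whatever dtype the policy chose but performs the update math in the gradient's dtype: the state is upcast on entry and copied back at the end of the loop body. A minimal sketch of that pattern in isolation (the shape, beta value, and bf16 state dtype are assumptions for illustration, not part of the commit):

import torch

beta1 = 0.9
grad = torch.randn(4)                                 # fp32 gradient
exp_avg_state = torch.zeros(4, dtype=torch.bfloat16)  # persistent bf16 momentum state

# Work in the gradient's dtype so in-place ops such as lerp_ see matching dtypes.
exp_avg = (
    exp_avg_state
    if exp_avg_state.dtype == grad.dtype
    else exp_avg_state.to(grad.dtype)
)
exp_avg.lerp_(grad, 1 - beta1)  # Adam first-moment update, computed in fp32

# Persist the result back into the lower precision state tensor.
if exp_avg_state.dtype != exp_avg.dtype:
    exp_avg_state.copy_(exp_avg)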

torch/optim/optimizer.py

Lines changed: 68 additions & 1 deletion

@@ -338,6 +338,8 @@ class Optimizer:
     _optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
     _optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
 
+    _dtype_policy: dict[str, Callable[[torch.Tensor], torch.dtype]]
+
     def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None:  # noqa: D107
         torch._C._log_api_usage_once("python.optimizer")
         self.defaults = defaults

@@ -347,6 +349,7 @@ def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None:  # noqa:
         self._optimizer_state_dict_post_hooks = OrderedDict()
         self._optimizer_load_state_dict_pre_hooks = OrderedDict()
         self._optimizer_load_state_dict_post_hooks = OrderedDict()
+        self._dtype_policy = OrderedDict()
 
         self._patch_step_function()
 

@@ -864,7 +867,7 @@ def load_state_dict(self, state_dict: StateDict) -> None:
 
         if len(groups) != len(saved_groups):
             raise ValueError(
-                "loaded state dict has a different number of parameter groups"
+                "loaded state dict has a different number of " "parameter groups"
             )
         param_lens = (len(g["params"]) for g in groups)
         saved_lens = (len(g["params"]) for g in saved_groups)

@@ -1000,6 +1003,70 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
         """
         raise NotImplementedError
 
+    def dtype_policy(self) -> dict[str, Callable[[torch.Tensor], torch.dtype]]:
+        r"""Gets the dtype policy for the optimizer.
+
+        Returns the optimizer's dtype_policy. See the docs for set_dtype_policy for more details.
+
+        """
+        return self._dtype_policy
+
+    def set_dtype_policy(
+        self, policy: dict[str, Callable[[torch.Tensor], torch.dtype]]
+    ) -> None:
+        r"""Set the dtype policy for the optimizer.
+
+        By default, the optimizer initializes state to be the same dtype as the parameter. This
+        function allows the user to enable mixed precision training for the optimizer by specifying
+        lower or higher precision dtypes for state corresponding to a parameter.
+
+        A dtype policy is a dictionary mapping optimizer state to a desired dtype given a parameter.
+        For example, Adam(W) has state ``exp_avg`` and ``exp_avg_sq`` mapping to momentum and
+        variance respectively. The default policy would semantically be the following:
+
+        .. code-block:: python
+
+            default_dtype_policy = {
+                "exp_avg": lambda p: p.dtype,
+                "exp_avg_sq": lambda p: p.dtype,
+            }
+
+
+        If we wanted momentum (exp_avg) to match the param but variance (exp_avg_sq) to be BF16 when
+        the parameter is a float, then the policy would look like:
+
+        .. code-block:: python
+
+            mixed_precision_dtype_policy = {
+                "exp_avg_sq": lambda p: torch.bfloat16 if p.dtype == torch.float else p.dtype
+                # no need to specify "exp_avg" since the default will fall back to p's dtype already
+            }
+
+            model = ...
+            optim = torch.optim.AdamW(model.named_parameters())
+            optim.set_dtype_policy(mixed_precision_dtype_policy)
+
+            # at this point, state has not been initialized
+
+            # run forward and backward
+            loss = model(...)
+            loss.backward()
+
+            # at first step, state will be initialized according to the set policy
+            optim.step()
+            optim.zero_grad()
+
+
+        The new policy will only be applied to any new state initialized after the policy has been
+        set. State loaded from an existing state_dict will not be affected. Previously initialized
+        state will also not be affected.
+
+        Args:
+            policy (Dict[str, Callable]): A dictionary mapping optimizer state keys (str) to a Callable
+                that takes the parameter and returns the desired dtype.
+        """
+        self._dtype_policy = policy
+
     @torch._disable_dynamo
     def add_param_group(self, param_group: dict[str, Any]) -> None:
         r"""Add a param group to the :class:`Optimizer` s `param_groups`.
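Taken together, end-to-end use of this POC frontend could look roughly like the sketch below. It assumes the set_dtype_policy API lands as written above and that the optimizer takes the single-tensor path (the only implementation adapted in this commit), so foreach is disabled explicitly; the model and dtype choices are placeholders.

import torch
import torch.nn as nn

model = nn.Linear(16, 16)
# foreach=False: only _single_tensor_adam handles mismatched state dtypes in this POC.
optim = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=False)

# Store exp_avg_sq in bf16 for fp32 params; exp_avg falls back to the param dtype.
optim.set_dtype_policy(
    {"exp_avg_sq": lambda p: torch.bfloat16 if p.dtype == torch.float else p.dtype}
)

# State is created lazily on the first step, using the policy.
loss = model(torch.randn(4, 16)).sum()
loss.backward()
optim.step()
optim.zero_grad()

p = next(model.parameters())
print(optim.state[p]["exp_avg"].dtype)     # torch.float32
print(optim.state[p]["exp_avg_sq"].dtype)  # torch.bfloat16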
