@@ -338,6 +338,8 @@ class Optimizer:
    _optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
    _optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'

+    _dtype_policy: dict[str, Callable[[torch.Tensor], torch.dtype]]
+
    def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None:  # noqa: D107
        torch._C._log_api_usage_once("python.optimizer")
        self.defaults = defaults
@@ -347,6 +349,7 @@ def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None:  # noqa:
        self._optimizer_state_dict_post_hooks = OrderedDict()
        self._optimizer_load_state_dict_pre_hooks = OrderedDict()
        self._optimizer_load_state_dict_post_hooks = OrderedDict()
+        self._dtype_policy = OrderedDict()

        self._patch_step_function()

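For reference, once the two hunks above are applied, every optimizer is constructed with an empty policy, so newly created state keeps each parameter's own dtype until a policy is set. A minimal sketch of that default behavior, assuming this patch is applied and using a toy model purely for illustration:

    import torch

    # toy model/optimizer pair, chosen only for illustration
    model = torch.nn.Linear(4, 4)
    optim = torch.optim.AdamW(model.parameters())

    # a freshly constructed optimizer carries an empty dtype policy,
    # so optimizer state will default to each parameter's dtype
    assert len(optim.dtype_policy()) == 0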
@@ -864,7 +867,7 @@ def load_state_dict(self, state_dict: StateDict) -> None:

        if len(groups) != len(saved_groups):
            raise ValueError(
-                "loaded state dict has a different number of parameter groups"
+                "loaded state dict has a different number of " "parameter groups"
            )
        param_lens = (len(g["params"]) for g in groups)
        saved_lens = (len(g["params"]) for g in saved_groups)
@@ -1000,6 +1003,70 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
        """
        raise NotImplementedError

+    def dtype_policy(self) -> dict[str, Callable[[torch.Tensor], torch.dtype]]:
+        r"""Get the dtype policy for the optimizer.
+
+        Returns the optimizer's dtype policy. See :meth:`set_dtype_policy` for details.
+        """
+        return self._dtype_policy
+
+    def set_dtype_policy(
+        self, policy: dict[str, Callable[[torch.Tensor], torch.dtype]]
+    ) -> None:
+        r"""Set the dtype policy for the optimizer.
+
+        By default, the optimizer initializes state in the same dtype as the corresponding
+        parameter. This function lets the user enable mixed precision training for the optimizer
+        by specifying lower or higher precision dtypes for the state associated with a parameter.
+
+        A dtype policy is a dictionary mapping an optimizer state key to a callable that, given
+        a parameter, returns the desired dtype for that state. For example, Adam(W) has state
+        ``exp_avg`` and ``exp_avg_sq``, corresponding to the momentum and variance estimates,
+        respectively. The default policy is semantically the following:
+
+        .. code-block:: python
+
+            default_dtype_policy = {
+                "exp_avg": lambda p: p.dtype,
+                "exp_avg_sq": lambda p: p.dtype,
+            }
+
+        If we wanted momentum (``exp_avg``) to match the param but variance (``exp_avg_sq``) to be
+        BF16 when the parameter is a float, then the policy would look like:
+
+        .. code-block:: python
+
+            mixed_precision_dtype_policy = {
+                "exp_avg_sq": lambda p: torch.bfloat16 if p.dtype == torch.float else p.dtype
+                # no need to specify "exp_avg" since the default already falls back to p's dtype
+            }
+
+            model = ...
+            optim = torch.optim.AdamW(model.named_parameters())
+            optim.set_dtype_policy(mixed_precision_dtype_policy)
+
+            # at this point, state has not been initialized
+
+            # run forward and backward
+            loss = model(...)
+            loss.backward()
+
+            # at the first step, state will be initialized according to the set policy
+            optim.step()
+            optim.zero_grad()
+
+        The new policy only applies to state initialized after the policy has been set. State
+        loaded from an existing state_dict is not affected, and previously initialized state is
+        likewise not affected.
+
+        Args:
+            policy (Dict[str, Callable]): A dictionary mapping optimizer state keys (str) to a
+                callable that takes the parameter and returns the desired dtype.
+        """
+        self._dtype_policy = policy
+
    @torch._disable_dynamo
    def add_param_group(self, param_group: dict[str, Any]) -> None:
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.