Fix doc cosineannealinglr 152081 (#152936) · pytorch/pytorch@6a80064

Commit 6a80064

Jacobgoss30 authored and pytorchmergebot committed
Fix doc cosineannealinglr 152081 (#152936)
## Summary

This PR updates the docstring for `CosineAnnealingLR` to accurately reflect its recursive learning rate schedule. The previous docstring displayed only the SGDR closed-form expression, which doesn't match the actual recursive implementation in code.

Changes:

- Added the recursive update formula used in `get_lr()`
- Retained the original closed-form SGDR expression for reference
- Clarified that warm restarts are not implemented in this scheduler

This addresses confusion raised in issue #152081.

## Related issue

[#152081](#152081)

## Testing

Doc-only change. Ran pre-commit to verify formatting.

Pull Request resolved: #152936
Approved by: https://github.com/janeyx99
1 parent 3cd6935 commit 6a80064
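
For context, here is a minimal usage sketch of `CosineAnnealingLR`, the scheduler whose docstring this commit updates. The toy model, optimizer, and hyperparameter values below are illustrative and not part of the PR:

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

# Toy model and optimizer; the initial lr (0.1) plays the role of eta_max.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Anneal from 0.1 toward eta_min over T_max steps, with no restarts.
scheduler = CosineAnnealingLR(optimizer, T_max=20, eta_min=0.001)

for epoch in range(20):
    # ... forward pass, loss.backward() for one epoch would go here ...
    optimizer.step()   # apply the current learning rate
    scheduler.step()   # then advance the cosine schedule
    print(epoch, scheduler.get_last_lr())
```

The values printed by `get_last_lr()` decay from the initial lr toward `eta_min` over `T_max` steps, following the recursive cosine update that the new docstring documents.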

1 file changed (+22, -22 lines)

torch/optim/lr_scheduler.py (22 additions, 22 deletions)

```diff
@@ -1000,39 +1000,39 @@ def _get_closed_form_lr(self):
 
 
 class CosineAnnealingLR(LRScheduler):
-    r"""Set the learning rate of each parameter group using a cosine annealing schedule.
+    r"""
+    Set the learning rate of each parameter group using a cosine annealing schedule.
 
-    The :math:`\eta_{max}` is set to the initial lr and
-    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
+    The learning rate is updated recursively using:
 
     .. math::
-        \begin{aligned}
-            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
-            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
-            & T_{cur} \neq (2k+1)T_{max}; \\
-            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
-            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
-            & T_{cur} = (2k+1)T_{max}.
-        \end{aligned}
-
-    When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
-    is defined recursively, the learning rate can be simultaneously modified
-    outside this scheduler by other operators. If the learning rate is set
-    solely by this scheduler, the learning rate at each step becomes:
+        \eta_{t+1} = \eta_{\min} + (\eta_t - \eta_{\min}) \cdot
+        \frac{1 + \cos\left(\frac{(T_{cur}+1) \pi}{T_{max}}\right)}
+        {1 + \cos\left(\frac{T_{cur} \pi}{T_{max}}\right)}
+
+    This implements a recursive approximation of the closed-form schedule proposed in
+    `SGDR: Stochastic Gradient Descent with Warm Restarts`_:
 
     .. math::
-        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
-        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
+        \eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min}) \left(
+        1 + \cos\left(\frac{T_{cur} \pi}{T_{max}}\right) \right)
 
-    It has been proposed in
-    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
-    implements the cosine annealing part of SGDR, and not the restarts.
+    where:
+
+    - :math:`\eta_t` is the learning rate at step :math:`t`
+    - :math:`T_{cur}` is the number of epochs since the last restart
+    - :math:`T_{max}` is the maximum number of epochs in a cycle
+
+    Note:
+        Although SGDR includes periodic restarts, this implementation performs cosine annealing
+        **without restarts**, so :math:`T_{cur} = t` and increases monotonically with each call
+        to :meth:`step`.
 
     Args:
         optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
         eta_min (float): Minimum learning rate. Default: 0.
-        last_epoch (int): The index of last epoch. Default: -1.
+        last_epoch (int): The index of the last epoch. Default: -1.
 
     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
         https://arxiv.org/abs/1608.03983
```
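
As a self-contained sanity check of the two formulas in the updated docstring (a sketch written for this note, not code from the PR), the recursive update reproduces the closed-form SGDR expression when the learning rate is driven only by this scheduler:

```python
import math

def closed_form(t, eta_max, eta_min, T_max):
    # Closed-form SGDR expression kept in the docstring for reference.
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(t * math.pi / T_max))

def recursive_step(eta_t, t, eta_min, T_max):
    # Recursive update documented for get_lr(): maps eta at step t to eta at step t+1.
    return eta_min + (eta_t - eta_min) * (
        (1 + math.cos((t + 1) * math.pi / T_max))
        / (1 + math.cos(t * math.pi / T_max))
    )

eta_max, eta_min, T_max = 0.1, 0.001, 10
eta = eta_max  # eta_0 is the initial lr
for t in range(T_max):  # stop before t = T_max, where the denominator is zero
    assert math.isclose(eta, closed_form(t, eta_max, eta_min, T_max), rel_tol=1e-9)
    eta = recursive_step(eta, t, eta_min, T_max)
```

With `eta_0` equal to the initial lr, the ratio telescopes so the recursion matches the closed form at every step; the denominator only vanishes when `T_cur` is an odd multiple of `T_max`, which is the case the old piecewise formula treated separately.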
