ExponentialLR unexpectedly calls `step()` when init argument `last_epoch` is larger than -1 · Issue #102261 · pytorch/pytorch · GitHub
Open
yuxinyuan opened this issue May 25, 2023 · 2 comments · May be fixed by #149312
Labels
actionable module: LrScheduler module: optimizer Related to torch.optim triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

yuxinyuan commented May 25, 2023

🐛 Describe the bug

Currently, the `__init__` method of `torch.optim.lr_scheduler._LRScheduler` calls `self.step()` once. When the init argument `last_epoch` is larger than -1, this extra step causes a mismatch between the learning rate actually used by the optimizer and the closed-form learning rate of `ExponentialLR`.
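A minimal sketch of the implicit step (SGD is used here just for brevity; any optimizer shows the same behavior):

```python
import torch

model = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Constructing the scheduler already performs one step():
# last_epoch advances from the default -1 to 0.
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.5)
print(sched.last_epoch)   # 0, not -1
print(sched._step_count)  # 1
```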

```python
import torch

model = torch.nn.Linear(3, 3)
optim = torch.optim.AdamW(model.parameters())
sched = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.999)

optim.step()
sched.step()
print("Optim & sched:")
print(optim.state_dict())
print(optim.param_groups[0]["lr"])
print(sched.state_dict())
print(sched._get_closed_form_lr())
print("")

# As if we are restoring from a checkpoint
optim2 = torch.optim.AdamW(model.parameters())
optim2.load_state_dict(optim.state_dict())
# Init scheduler with last_epoch=0
sched2 = torch.optim.lr_scheduler.ExponentialLR(optim2, gamma=0.999, last_epoch=0)

print("Optim2 & sched2:")
print(optim2.state_dict())
print(optim2.param_groups[0]["lr"])
print(sched2.state_dict())
print(sched2._get_closed_form_lr())
print("")
```

As the output shows, optim2 ends up with lr 0.000998001 (the base lr of 0.001 multiplied by gamma twice), while the closed-form lr of sched2 is 0.000999 (multiplied by gamma only once): the extra `step()` in `__init__` advances `last_epoch` from 0 to 1 and multiplies the restored group lr by gamma again. This behavior causes confusion and inconsistency when one resumes training from a checkpoint.
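The mismatch can be reproduced by hand (0.001 is the AdamW default lr; the values below mirror the output of the repro above):

```python
base_lr = 1e-3   # AdamW default learning rate
gamma = 0.999

# lr stored in the checkpoint after one real scheduler step
checkpoint_lr = base_lr * gamma        # 0.000999

# ExponentialLR(..., last_epoch=0).__init__ calls step(), advancing
# last_epoch to 1 and multiplying the restored group lr by gamma again
optimizer_lr = checkpoint_lr * gamma   # 0.000998001

# the closed form only accounts for last_epoch == 1
closed_form_lr = base_lr * gamma ** 1  # 0.000999

print(optimizer_lr, closed_form_lr)
```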

Versions

Collecting environment information...
PyTorch version: 1.13.1
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

cc @vincentqb @jbschlosser @albanD @janeyx99 @crcrpar

@soulitzer added module: optimizer Related to torch.optim triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module module: LrScheduler labels May 25, 2023
@janeyx99 (Contributor)

We will accept a fix for this discrepancy to unblock checkpointing cases.
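Until such a fix lands, one workaround for the checkpointing case (a sketch, not an officially documented pattern) is to restore the scheduler with `load_state_dict()` instead of the `last_epoch` argument. For `ExponentialLR`, the extra `step()` in `__init__` then runs with `last_epoch == 0`, which returns the base lrs unchanged, and loading the state afterwards restores `last_epoch` without touching the group lr:

```python
import torch

model = torch.nn.Linear(3, 3)
optim = torch.optim.AdamW(model.parameters())
sched = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.999)
optim.step()
sched.step()

# Save a "checkpoint"
opt_state, sched_state = optim.state_dict(), sched.state_dict()

# Restore: construct with the default last_epoch=-1, then load the state.
optim2 = torch.optim.AdamW(model.parameters())
optim2.load_state_dict(opt_state)
sched2 = torch.optim.lr_scheduler.ExponentialLR(optim2, gamma=0.999)
sched2.load_state_dict(sched_state)

# The restored lr now matches the original optimizer's lr (~0.000999).
print(optim2.param_groups[0]["lr"])
```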

@Vetti420

Yes, this also causes an unexpected `step()` when we call `.build` in RLlib.
