Add Trainer support for ReduceLROnPlateau (#23010) · githubhjs/transformers@9b43520 · GitHub
Commit 9b43520

pie3636 and mmeloux authored

Add Trainer support for ReduceLROnPlateau (huggingface#23010)

* Add Trainer support for ReduceLROnPlateau

  Fixes huggingface#16503

* Remove training argument and add default instance

---------

Co-authored-by: mmeloux <maxime.meloux@loria.fr>
1 parent cf7baf4 commit 9b43520

File tree

5 files changed: +103 additions, -4 deletions


src/transformers/optimization.py

Lines changed: 18 additions & 2 deletions

@@ -22,7 +22,7 @@
 import torch
 from torch import nn
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import LambdaLR
+from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

 from .trainer_utils import SchedulerType
 from .utils import logging
@@ -49,6 +49,21 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
     return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)


+def get_reduce_on_plateau_schedule(optimizer: Optimizer):
+    """
+    Create a schedule with a constant learning rate that decreases when a metric has stopped improving.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+
+    Return:
+        `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
+    """
+
+    return ReduceLROnPlateau(optimizer)
+
+
 def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
     if current_step < num_warmup_steps:
         return float(current_step) / float(max(1.0, num_warmup_steps))
@@ -309,6 +324,7 @@ def get_inverse_sqrt_schedule(
     SchedulerType.CONSTANT: get_constant_schedule,
     SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
     SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
+    SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
 }


@@ -335,7 +351,7 @@ def get_scheduler(
     """
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
-    if name == SchedulerType.CONSTANT:
+    if name == SchedulerType.CONSTANT or name == SchedulerType.REDUCE_ON_PLATEAU:
         return schedule_func(optimizer)

     # All other schedulers require `num_warmup_steps`
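With this mapping in place, the new scheduler name resolves through the public get_scheduler helper. A minimal sketch of that lookup (not part of the commit, assuming a transformers version that includes this change; the linear model and AdamW optimizer are placeholders):

# Hypothetical usage sketch: resolve the new scheduler name via get_scheduler.
import torch
from transformers import get_scheduler

model = torch.nn.Linear(4, 1)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# "reduce_lr_on_plateau" maps to get_reduce_on_plateau_schedule, which wraps the
# optimizer in torch.optim.lr_scheduler.ReduceLROnPlateau with default settings.
scheduler = get_scheduler("reduce_lr_on_plateau", optimizer=optimizer)
print(type(scheduler))  # <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>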

src/transformers/trainer.py

Lines changed: 7 additions & 1 deletion

@@ -1997,7 +1997,9 @@ def _inner_training_loop(
                     self.optimizer.step()

                 if optimizer_was_run and not self.deepspeed:
-                    self.lr_scheduler.step()
+                    # Delay optimizer scheduling until metrics are generated
+                    if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                        self.lr_scheduler.step()

                 model.zero_grad()
                 self.state.global_step += 1
@@ -2288,6 +2290,10 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
             metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
             self._report_to_hp_search(trial, self.state.global_step, metrics)

+            # Run delayed LR scheduler now that metrics are populated
+            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                self.lr_scheduler.step(metrics[self.args.metric_for_best_model])
+
         if self.control.should_save:
             self._save_checkpoint(model, trial, metrics=metrics)
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
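The Trainer change above splits scheduler stepping into two places: per-step schedulers still advance right after the optimizer, while ReduceLROnPlateau waits until evaluation metrics exist. A standalone sketch of that pattern in plain PyTorch (not part of the commit; the toy model, loop sizes, and loss are placeholders):

# Sketch of the delayed, metric-driven stepping pattern the Trainer now follows.
import torch

model = torch.nn.Linear(4, 1)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

for epoch in range(5):
    for _ in range(10):  # stand-in for the inner training loop
        optimizer.zero_grad()
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()
        # Step-based schedulers would be advanced here; ReduceLROnPlateau is skipped.
        if not isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            lr_scheduler.step()

    eval_loss = loss.item()  # stand-in for the metric returned by trainer.evaluate()
    # Delayed step once the metric is available, mirroring _maybe_log_save_evaluate.
    if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        lr_scheduler.step(eval_loss)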

src/transformers/trainer_utils.py

Lines changed: 1 addition & 0 deletions

@@ -367,6 +367,7 @@ class SchedulerType(ExplicitEnum):
     CONSTANT = "constant"
     CONSTANT_WITH_WARMUP = "constant_with_warmup"
     INVERSE_SQRT = "inverse_sqrt"
+    REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"


 class TrainerMemoryTracker:
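The enum value doubles as the string users pass on the command line or to TrainingArguments. A tiny sketch of the round-trip (not part of the commit):

from transformers.trainer_utils import SchedulerType

# Enum members are looked up by value, so the new string maps back to the member.
assert SchedulerType("reduce_lr_on_plateau") is SchedulerType.REDUCE_ON_PLATEAU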

src/transformers/training_args.py

Lines changed: 9 additions & 1 deletion

@@ -1194,7 +1194,9 @@ def __post_init__(self):
                 f"https://github.com/huggingface/safetensors!"
             )

-        if self.load_best_model_at_end and self.metric_for_best_model is None:
+        if (
+            self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU
+        ) and self.metric_for_best_model is None:
             self.metric_for_best_model = "loss"
         if self.greater_is_better is None and self.metric_for_best_model is not None:
             self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"]
@@ -1234,6 +1236,12 @@ def __post_init__(self):
             if not (self.sharded_ddp == "" or not self.sharded_ddp):
                 raise ValueError("sharded_ddp is not supported with bf16")

+        if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
+            if self.evaluation_strategy == IntervalStrategy.NO:
+                raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy")
+            if not is_torch_available():
+                raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0")
+
         self.optim = OptimizerNames(self.optim)
         if self.adafactor:
             warnings.warn(
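These validation rules surface directly when constructing TrainingArguments. A minimal sketch (not part of the commit, assuming a transformers version with this change; output_dir is a placeholder path):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./out",  # placeholder
    lr_scheduler_type="reduce_lr_on_plateau",
    evaluation_strategy="epoch",  # evaluation_strategy="no" would raise a ValueError here
)
print(args.metric_for_best_model)  # "loss", filled in by __post_init__ since it was left unset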

tests/trainer/test_trainer.py

Lines changed: 68 additions & 0 deletions

@@ -575,6 +575,74 @@ def test_custom_optimizer(self):
         self.assertFalse(torch.allclose(trainer.model.b, b))
         self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)

+    def test_reduce_lr_on_plateau_args(self):
+        # test passed arguments for a custom ReduceLROnPlateau scheduler
+        train_dataset = RegressionDataset(length=64)
+        eval_dataset = RegressionDataset(length=64)
+        args = TrainingArguments(
+            "./regression",
+            evaluation_strategy="epoch",
+            metric_for_best_model="eval_loss",
+        )
+        model = RegressionModel()
+        optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
+        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=5, cooldown=2)
+        trainer = Trainer(
+            model, args, train_dataset=train_dataset, eval_dataset=eval_dataset, optimizers=(optimizer, lr_scheduler)
+        )
+        trainer.train()
+
+        self.assertIsInstance(trainer.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
+        self.assertEqual(trainer.lr_scheduler.factor, 0.2)
+        self.assertEqual(trainer.lr_scheduler.patience, 5)
+        self.assertEqual(trainer.lr_scheduler.cooldown, 2)
+
+    def test_reduce_lr_on_plateau(self):
+        # test the ReduceLROnPlateau scheduler
+
+        class TrainerWithLRLogs(Trainer):
+            def log(self, logs):
+                # the LR is computed after metrics and does not exist for the first epoch
+                if hasattr(self.lr_scheduler, "_last_lr"):
+                    logs["learning_rate"] = self.lr_scheduler._last_lr
+                super().log(logs)
+
+        train_dataset = RegressionDataset(length=64)
+        eval_dataset = RegressionDataset(length=64)
+
+        args = TrainingArguments(
+            "./regression",
+            lr_scheduler_type="reduce_lr_on_plateau",
+            evaluation_strategy="epoch",
+            metric_for_best_model="eval_loss",
+            num_train_epochs=10,
+            learning_rate=0.2,
+        )
+        model = RegressionModel()
+        trainer = TrainerWithLRLogs(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+        trainer.train()
+
+        self.assertIsInstance(trainer.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
+        patience = trainer.lr_scheduler.patience
+
+        logs = trainer.state.log_history[1:]
+        best_loss = logs[0]["eval_loss"]
+        bad_epochs = 0
+        for i, log in enumerate(logs[:-1]):  # Compare learning rate to next epoch's
+            loss = log["eval_loss"]
+            just_decreased = False
+            if loss > best_loss:
+                bad_epochs += 1
+                if bad_epochs > patience:
+                    self.assertLess(logs[i + 1]["learning_rate"][0], log["learning_rate"][0])
+                    just_decreased = True
+                    bad_epochs = 0
+            else:
+                best_loss = loss
+                bad_epochs = 0
+            if not just_decreased:
+                self.assertEqual(logs[i + 1]["learning_rate"][0], log["learning_rate"][0])
+
     def test_adafactor_lr_none(self):
         # test the special case where lr=None, since Trainer can't not have lr_scheduler

0 commit comments
