Fix PolynomialLR power type. (#1440) · dotnet/TorchSharp@9a5ac0b · GitHub
[go: up one dir, main page]

Skip to content

Commit 9a5ac0b

Browse files
hiyuh and Masaru Kimura
authored
Fix PolynomialLR power type. (#1440)
* Fix PolynomialLR power type. torch.optim.lr_scheduler.PolynomialLR power was typed int, but should be double. Non-integer power is widely used for common training recipe. E.g. torchvision's pre-trained semantic segmentation models uses PolynomialLR as main LR scheduler with power = 0.9; https://github.com/pytorch/vision/blob/main/references/segmentation/train.py#L201 See also https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.PolynomialLR.html * Update RELEASENOTES.md. --------- Co-authored-by: Masaru Kimura <masaru@hacarus.com>
1 parent 14e351f commit 9a5ac0b

File tree

4 files changed

+7
-6
lines changed

4 files changed

+7
-6
lines changed

RELEASENOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ __Bug Fixes__:
77

88
#1426 Sequential.eval() does not put model into eval mode<br/>
99
`torch.optim.lr_scheduler.LinearLR` `end_factor` default has been corrected, is now 1.0.<br/>
10+
`torch.optim.lr_scheduler.PolynomialLR` `power` type has been corrected, is now double.<br/>
1011

1112
# NuGet Version 0.105.0
1213

src/TorchSharp/Optimizers/LRScheduler.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ public class PolynomialLR : LRScheduler
325325
/// <param name="last_epoch">The index of last epoch. Default: -1.</param>
326326
/// <param name="verbose"> If true, prints a message to stdout for each update. Default: false.</param>
327327
/// <returns>A scheduler</returns>
328-
public PolynomialLR(Optimizer optimizer, int total_iters = 5, int power = 1, int last_epoch = -1, bool verbose = false) : base(optimizer, last_epoch, verbose)
328+
public PolynomialLR(Optimizer optimizer, int total_iters = 5, double power = 1.0, int last_epoch = -1, bool verbose = false) : base(optimizer, last_epoch, verbose)
329329
{
330330
if (optimizer == null) throw new ArgumentNullException("optimizer");
331331
_power = power;
@@ -359,7 +359,7 @@ protected override IEnumerable<double> get_closed_form_lr()
359359
}
360360

361361
private double _total_iters;
362-
private int _power;
362+
private double _power;
363363
}
364364

365365
/// <summary>
@@ -1306,7 +1306,7 @@ public static LRScheduler MultiStepLR(Optimizer optimizer, IList<int> milestones
13061306
/// <param name="last_epoch">The index of last epoch. Default: -1.</param>
13071307
/// <param name="verbose"> If true, prints a message to stdout for each update. Default: false.</param>
13081308
/// <returns>A scheduler</returns>
1309-
public static LRScheduler PolynomialLR(Optimizer optimizer, int total_iters = 5, int power = 1, int last_epoch = -1, bool verbose = false)
1309+
public static LRScheduler PolynomialLR(Optimizer optimizer, int total_iters = 5, double power = 1, int last_epoch = -1, bool verbose = false)
13101310
{
13111311
return new impl.PolynomialLR(optimizer, total_iters, power, last_epoch, verbose);
13121312
}

test/TorchSharpTest/TestTorchTensorBugs.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,7 @@ public void ValidatePolynomialLR()
873873

874874
double learning_rate = 0.1;
875875
var optimizer = torch.optim.SGD(seq.parameters(), learning_rate);
876-
var scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, 10, 1);
876+
var scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, 10, 1.0);
877877

878878
optimizer.zero_grad();
879879
optimizer.step();
@@ -907,7 +907,7 @@ public void ValidatePolynomialLR()
907907

908908
double learning_rate = 0.1;
909909
var optimizer = torch.optim.SGD(seq.parameters(), learning_rate);
910-
var scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, 10, 2);
910+
var scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, 10, 2.0);
911911

912912
optimizer.zero_grad();
913913
optimizer.step();

test/TorchSharpTest/TestTraining.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1654,7 +1654,7 @@ public void TrainingSGDSequentialLRWithAllClosedFormSchedulers()
16541654
var scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer, 2);
16551655
var scheduler3 = torch.optim.lr_scheduler.MultiStepLR(optimizer, new[] { 2, 4 });
16561656
var scheduler4 = torch.optim.lr_scheduler.ExponentialLR(optimizer);
1657-
var scheduler5 = torch.optim.lr_scheduler.PolynomialLR(optimizer, power: 2);
1657+
var scheduler5 = torch.optim.lr_scheduler.PolynomialLR(optimizer, power: 2.0);
16581658
var scheduler6 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 5, 0.1);
16591659
var scheduler7 = torch.optim.lr_scheduler.LinearLR(optimizer, end_factor: 0.75);
16601660
var scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, new[] { scheduler0, scheduler1, scheduler2, scheduler3, scheduler4, scheduler5, scheduler6, scheduler7}, new[] { 5, 5, 5, 5, 5, 5, 5 });

0 commit comments

Comments (0)