FIX Fixes bug in mlp when loading a pickle and partial_fit (#19631) · rth/scikit-learn@1d1f65a · GitHub
Commit 1d1f65a

FIX Fixes bug in mlp when loading a pickle and partial_fit (scikit-learn#19631)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 44c40b3 commit 1d1f65a

File tree

5 files changed: +56 -26 lines changed

doc/whats_new/v1.0.rst

Lines changed: 7 additions & 1 deletion

@@ -555,7 +555,6 @@ Changelog
   Use ``var_`` instead.
   :pr:`18842` by :user:`Hong Shao Yang <hongshaoyang>`.
 
-
 :mod:`sklearn.neighbors`
 ........................
 
@@ -574,6 +573,13 @@ Changelog
   `__init__` and validates `weights` in `fit` instead. :pr:`20072` by
   :user:`Juan Carlos Alfaro Jiménez <alfaro96>`.
 
+:mod:`sklearn.neural_network`
+.............................
+
+- |Fix| :class:`neural_network.MLPClassifier` and
+  :class:`neural_network.MLPRegressor` now correctly support continued training
+  when loading from a pickled file. :pr:`19631` by `Thomas Fan`_.
+
 :mod:`sklearn.pipeline`
 .......................
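
In user code, the behaviour this changelog entry restores looks roughly like the sketch below; the toy data, hyperparameters, and file name are illustrative and not taken from the commit. Before this fix, calling partial_fit on the reloaded estimator could leave the weights effectively unchanged.

    import joblib
    from sklearn.neural_network import MLPRegressor

    # Fit a small model, persist it with joblib, then resume training.
    est = MLPRegressor(hidden_layer_sizes=(10,), random_state=0, max_iter=200)
    est.fit([[2.0]], [4.0])

    joblib.dump(est, "mlp.pkl")
    loaded = joblib.load("mlp.pkl")

    # With the fix, this continues to update the reloaded model's weights.
    loaded.partial_fit([[2.0]], [1.0])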

sklearn/neural_network/_multilayer_perceptron.py

Lines changed: 2 additions & 3 deletions

@@ -557,9 +557,8 @@ def _fit_stochastic(
         incremental,
     ):
 
+        params = self.coefs_ + self.intercepts_
         if not incremental or not hasattr(self, "_optimizer"):
-            params = self.coefs_ + self.intercepts_
-
             if self.solver == "sgd":
                 self._optimizer = SGDOptimizer(
                     params,
@@ -642,7 +641,7 @@ def _fit_stochastic(
 
                 # update weights
                 grads = coef_grads + intercept_grads
-                self._optimizer.update_params(grads)
+                self._optimizer.update_params(params, grads)
 
             self.n_iter_ += 1
             self.loss_ = accumulated_loss / X.shape[0]
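
The change rebinds `params = self.coefs_ + self.intercepts_` on every call and threads it through to the optimizer, instead of relying on a parameter list captured when the optimizer was first created. A minimal NumPy sketch of the aliasing pitfall this avoids (the array round trip below merely stands in for a save/load cycle and is not the commit's code):

    import numpy as np

    coefs = [np.zeros(3)]
    held = list(coefs)         # an optimizer-style copy of the references
    held[0] += 1.0             # in-place update still reaches the model's array
    assert coefs[0][0] == 1.0

    # If the arrays are recreated, as can happen across a save/load round trip,
    # the held references go stale and updates stop reaching the model:
    coefs = [np.array(c) for c in coefs]
    held[0] += 1.0
    assert coefs[0][0] == 1.0  # unchanged: the update went to the old array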

sklearn/neural_network/_stochastic_optimizers.py

Lines changed: 10 additions & 11 deletions

@@ -12,10 +12,6 @@ class BaseOptimizer:
 
     Parameters
     ----------
-    params : list, length = len(coefs_) + len(intercepts_)
-        The concatenated list containing coefs_ and intercepts_ in MLP model.
-        Used for initializing velocities and updating params
-
     learning_rate_init : float, default=0.1
         The initial learning rate used. It controls the step-size in updating
         the weights
@@ -26,22 +22,25 @@ class BaseOptimizer:
         the current learning rate
     """
 
-    def __init__(self, params, learning_rate_init=0.1):
-        self.params = [param for param in params]
+    def __init__(self, learning_rate_init=0.1):
         self.learning_rate_init = learning_rate_init
         self.learning_rate = float(learning_rate_init)
 
-    def update_params(self, grads):
+    def update_params(self, params, grads):
         """Update parameters with given gradients
 
         Parameters
         ----------
-        grads : list, length = len(params)
+        params : list of length = len(coefs_) + len(intercepts_)
+            The concatenated list containing coefs_ and intercepts_ in MLP
+            model. Used for initializing velocities and updating params
+
+        grads : list of length = len(params)
             Containing gradients with respect to coefs_ and intercepts_ in MLP
             model. So length should be aligned with params
         """
         updates = self._get_updates(grads)
-        for param, update in zip(self.params, updates):
+        for param, update in zip((p for p in params), updates):
             param += update
 
     def iteration_ends(self, time_step):
@@ -128,7 +127,7 @@ def __init__(
         nesterov=True,
         power_t=0.5,
     ):
-        super().__init__(params, learning_rate_init)
+        super().__init__(learning_rate_init)
 
         self.lr_schedule = lr_schedule
         self.momentum = momentum
@@ -246,7 +245,7 @@ class AdamOptimizer(BaseOptimizer):
     def __init__(
         self, params, learning_rate_init=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
     ):
-        super().__init__(params, learning_rate_init)
+        super().__init__(learning_rate_init)
 
         self.beta_1 = beta_1
         self.beta_2 = beta_2
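
With this change the base optimizer no longer owns the parameter list; it keeps only its internal state (velocities for SGD, first and second moments for Adam) and receives the current parameters on each update. A minimal sketch of the new calling convention, using the private module shown above (illustrative shapes and values, not a public API):

    import numpy as np
    from sklearn.neural_network._stochastic_optimizers import SGDOptimizer

    # Parameters live with the caller; the optimizer only tracks its own state
    # (here, velocities) and is handed the current params on every update.
    params = [np.zeros((2, 2)), np.zeros(2)]
    grads = [np.ones((2, 2)), np.ones(2)]

    opt = SGDOptimizer(params, learning_rate_init=0.1, momentum=0, nesterov=False)
    opt.update_params(params, grads)   # new signature: (params, grads)

    # Plain SGD step without momentum: param <- param - lr * grad
    assert np.allclose(params[0], -0.1)
    assert np.allclose(params[1], -0.1)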

sklearn/neural_network/tests/test_mlp.py

Lines changed: 28 additions & 0 deletions

@@ -11,6 +11,7 @@
 import re
 
 import numpy as np
+import joblib
 
 from numpy.testing import (
     assert_almost_equal,
@@ -869,3 +870,30 @@ def test_mlp_param_dtypes(dtype, Estimator):
 
     if Estimator == MLPRegressor:
         assert pred.dtype == dtype
+
+
+def test_mlp_loading_from_joblib_partial_fit(tmp_path):
+    """Loading an MLP from disk and partial fitting updates weights.
+    Non-regression test for #19626."""
+    pre_trained_estimator = MLPRegressor(
+        hidden_layer_sizes=(42,), random_state=42, learning_rate_init=0.01, max_iter=200
+    )
+    features, target = [[2]], [4]
+
+    # Fit on x=2, y=4
+    pre_trained_estimator.fit(features, target)
+
+    # dump and load model
+    pickled_file = tmp_path / "mlp.pkl"
+    joblib.dump(pre_trained_estimator, pickled_file)
+    load_estimator = joblib.load(pickled_file)
+
+    # Train for more epochs on the point x=2, y=1
+    fine_tune_features, fine_tune_target = [[2]], [1]
+
+    for _ in range(200):
+        load_estimator.partial_fit(fine_tune_features, fine_tune_target)
+
+    # the fine-tuned model learned the new target
+    predicted_value = load_estimator.predict(fine_tune_features)
+    assert_allclose(predicted_value, fine_tune_target, rtol=1e-4)

sklearn/neural_network/tests/test_stochastic_optimizers.py

Lines changed: 9 additions & 11 deletions

@@ -12,10 +12,8 @@
 
 
 def test_base_optimizer():
-    params = [np.zeros(shape) for shape in shapes]
-
     for lr in [10 ** i for i in range(-3, 4)]:
-        optimizer = BaseOptimizer(params, lr)
+        optimizer = BaseOptimizer(lr)
         assert optimizer.trigger_stopping("", False)
 
 
@@ -27,9 +25,9 @@ def test_sgd_optimizer_no_momentum():
     optimizer = SGDOptimizer(params, lr, momentum=0, nesterov=False)
     grads = [rng.random_sample(shape) for shape in shapes]
     expected = [param - lr * grad for param, grad in zip(params, grads)]
-    optimizer.update_params(grads)
+    optimizer.update_params(params, grads)
 
-    for exp, param in zip(expected, optimizer.params):
+    for exp, param in zip(expected, params):
         assert_array_equal(exp, param)
 
 
@@ -47,9 +45,9 @@ def test_sgd_optimizer_momentum():
         momentum * velocity - lr * grad for velocity, grad in zip(velocities, grads)
     ]
     expected = [param + update for param, update in zip(params, updates)]
-    optimizer.update_params(grads)
+    optimizer.update_params(params, grads)
 
-    for exp, param in zip(expected, optimizer.params):
+    for exp, param in zip(expected, params):
         assert_array_equal(exp, param)
 
 
@@ -79,9 +77,9 @@ def test_sgd_optimizer_nesterovs_momentum():
         momentum * update - lr * grad for update, grad in zip(updates, grads)
     ]
     expected = [param + update for param, update in zip(params, updates)]
-    optimizer.update_params(grads)
+    optimizer.update_params(params, grads)
 
-    for exp, param in zip(expected, optimizer.params):
+    for exp, param in zip(expected, params):
         assert_array_equal(exp, param)
 
 
@@ -110,6 +108,6 @@ def test_adam_optimizer():
     ]
     expected = [param + update for param, update in zip(params, updates)]
 
-    optimizer.update_params(grads)
-    for exp, param in zip(expected, optimizer.params):
+    optimizer.update_params(params, grads)
+    for exp, param in zip(expected, params):
         assert_array_equal(exp, param)
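
For reference, the momentum tests updated above check the classical (non-Nesterov) momentum recurrence. A tiny worked example with made-up numbers, not taken from the test file:

    import numpy as np

    # velocity <- momentum * velocity - lr * grad
    # param    <- param + velocity
    lr, momentum = 0.1, 0.9
    param = np.array([1.0])
    velocity = np.array([0.0])
    grad = np.array([0.5])

    velocity = momentum * velocity - lr * grad   # -> [-0.05]
    param = param + velocity                     # -> [0.95]
    print(param)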

0 commit comments

Comments
 (0)
0