8000 FIX report properly n_iter_ when warm_start=True (#25443) · scikit-learn/scikit-learn@6df0f13 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6df0f13

Browse files
Marvvxiglemaitre
andauthored
FIX report properly n_iter_ when warm_start=True (#25443)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 677a4cf commit 6df0f13

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

doc/whats_new/v1.3.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,15 @@ Changelog
179179
dissimilarity is not a metric and cannot be supported by the BallTree.
180180
:pr:`25417` by :user:`Guillaume Lemaitre <glemaitre>`.
181181

182+
:mod:`sklearn.neural_network`
183+
.............................
184+
185+
- |Fix| :class:`neural_network.MLPRegressor` and :class:`neural_network.MLPClassifier`
186+
reports the right `n_iter_` when `warm_start=True`. It corresponds to the number
187+
of iterations performed on the current call to `fit` instead of the total number
188+
of iterations performed since the initialization of the estimator.
189+
:pr:`25443` by :user:`Marvin Krawutschke <Marvvxi>`.
190+
182191
:mod:`sklearn.pipeline`
183192
.......................
184193

sklearn/neural_network/_multilayer_perceptron.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,7 @@ def _fit_stochastic(
607607
batch_size = np.clip(self.batch_size, 1, n_samples)
608608

609609
try:
610+
self.n_iter_ = 0
610611
for it in range(self.max_iter):
611612
if self.shuffle:
612613
# Only shuffle the sample indices instead of X and y to

sklearn/neural_network/tests/test_mlp.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ def test_warm_start_full_iteration(MLPEstimator):
752752
clf.fit(X, y)
753753
assert max_iter == clf.n_iter_
754754
clf.fit(X, y)
755-
assert 2 * max_iter == clf.n_iter_
755+
assert max_iter == clf.n_iter_
756756

757757

758758
def test_n_iter_no_change():
@@ -926,3 +926,25 @@ def test_mlp_warm_start_with_early_stopping(MLPEstimator):
926926
mlp.set_params(max_iter=20)
927927
mlp.fit(X_iris, y_iris)
928928
assert len(mlp.validation_scores_) > n_validation_scores
929+
930+
931+
@pytest.mark.parametrize("MLPEstimator", [MLPClassifier, MLPRegressor])
932+
@pytest.mark.parametrize("solver", ["sgd", "adam", "lbfgs"])
933+
def test_mlp_warm_start_no_convergence(MLPEstimator, solver):
934+
"""Check that we stop the number of iteration at `max_iter` when warm starting.
935+
936+
Non-regression test for:
937+
https://github.com/scikit-learn/scikit-learn/issues/24764
938+
"""
939+
model = MLPEstimator(
940+
solver=solver, warm_start=True, early_stopping=False, max_iter=10
941+
)
942+
943+
with pytest.warns(ConvergenceWarning):
944+
model.fit(X_iris, y_iris)
945+
assert model.n_iter_ == 10
946+
947+
model.set_params(max_iter=20)
948+
with pytest.warns(ConvergenceWarning):
949+
model.fit(X_iris, y_iris)
950+
assert model.n_iter_ == 20

0 commit comments

Comments
 (0)
0