diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst
index 2b209ee91c027..3e5e5083dd708 100644
--- a/doc/whats_new/v1.3.rst
+++ b/doc/whats_new/v1.3.rst
@@ -179,6 +179,15 @@ Changelog
   dissimilarity is not a metric and cannot be supported by the BallTree.
   :pr:`25417` by :user:`Guillaume Lemaitre <glemaitre>`.
 
+:mod:`sklearn.neural_network`
+.............................
+
+- |Fix| :class:`neural_network.MLPRegressor` and :class:`neural_network.MLPClassifier`
+  report the right `n_iter_` when `warm_start=True`. It corresponds to the number
+  of iterations performed on the current call to `fit` instead of the total number
+  of iterations performed since the initialization of the estimator.
+  :pr:`25443` by :user:`Marvin Krawutschke `.
+
 :mod:`sklearn.pipeline`
 .......................
 
diff --git a/sklearn/neural_network/_multilayer_perceptron.py b/sklearn/neural_network/_multilayer_perceptron.py
index 61d97e37b32a3..ec470c07d17ab 100644
--- a/sklearn/neural_network/_multilayer_perceptron.py
+++ b/sklearn/neural_network/_multilayer_perceptron.py
@@ -607,6 +607,7 @@ def _fit_stochastic(
         batch_size = np.clip(self.batch_size, 1, n_samples)
 
         try:
+            self.n_iter_ = 0
             for it in range(self.max_iter):
                 if self.shuffle:
                     # Only shuffle the sample indices instead of X and y to
diff --git a/sklearn/neural_network/tests/test_mlp.py b/sklearn/neural_network/tests/test_mlp.py
index a4d4831766170..6db1f965dad7e 100644
--- a/sklearn/neural_network/tests/test_mlp.py
+++ b/sklearn/neural_network/tests/test_mlp.py
@@ -752,7 +752,7 @@ def test_warm_start_full_iteration(MLPEstimator):
     clf.fit(X, y)
     assert max_iter == clf.n_iter_
     clf.fit(X, y)
-    assert 2 * max_iter == clf.n_iter_
+    assert max_iter == clf.n_iter_
 
 
 def test_n_iter_no_change():
@@ -926,3 +926,25 @@ def test_mlp_warm_start_with_early_stopping(MLPEstimator):
     mlp.set_params(max_iter=20)
     mlp.fit(X_iris, y_iris)
     assert len(mlp.validation_scores_) > n_validation_scores
+
+
+@pytest.mark.parametrize("MLPEstimator", [MLPClassifier, MLPRegressor])
+@pytest.mark.parametrize("solver", ["sgd", "adam", "lbfgs"])
+def test_mlp_warm_start_no_convergence(MLPEstimator, solver):
+    """Check that we stop the number of iterations at `max_iter` when warm starting.
+
+    Non-regression test for:
+    https://github.com/scikit-learn/scikit-learn/issues/24764
+    """
+    model = MLPEstimator(
+        solver=solver, warm_start=True, early_stopping=False, max_iter=10
+    )
+
+    with pytest.warns(ConvergenceWarning):
+        model.fit(X_iris, y_iris)
+    assert model.n_iter_ == 10
+
+    model.set_params(max_iter=20)
+    with pytest.warns(ConvergenceWarning):
+        model.fit(X_iris, y_iris)
+    assert model.n_iter_ == 20