Warn in the main process when a fit fails during a cross-validation (… · scikit-learn/scikit-learn@7317416 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7317416

Browse files
lesteve, thomasjpfan, glemaitre, and ogrisel
authored
Warn in the main process when a fit fails during a cross-validation (#20619)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 4b8cd88 commit 7317416

File tree

5 files changed

+99
-52
lines changed

5 files changed

+99
-52
lines changed

doc/whats_new/v1.0.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,9 @@ Changelog
610610
:pr:`18649` by :user:`Leandro Hermida <hermidalc>` and
611611
:user:`Rodion Martynov <marrodion>`.
612612

613+
- |Enhancement| warn only once in the main process for per-split fit failures
614+
in cross-validation. :pr:`20619` by :user:`Loïc Estève <lesteve>`
615+
613616
:mod:`sklearn.naive_bayes`
614617
..........................
615618

sklearn/model_selection/_search.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ._validation import _aggregate_score_dicts
3232
from ._validation import _insert_error_scores
3333
from ._validation import _normalize_score_results
34+
from ._validation import _warn_about_fit_failures
3435
from ..exceptions import NotFittedError
3536
from joblib import Parallel
3637
from ..utils import check_random_state
@@ -793,14 +794,18 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None):
793794
"splits, got {}".format(n_splits, len(out) // n_candidates)
794795
)
795796

797+
_warn_about_fit_failures(out, self.error_score)
798+
796799
# For callable self.scoring, the return type is only know after
797800
# calling. If the return type is a dictionary, the error scores
798801
# can now be inserted with the correct key. The type checking
799802
# of out will be done in `_insert_error_scores`.
800803
if callable(self.scoring):
801804
_insert_error_scores(out, self.error_score)
805+
802806
all_candidate_params.extend(candidate_params)
803807
all_out.extend(out)
808+
804809
if more_results is not None:
805810
for key, value in more_results.items():
806811
all_more_results[key].extend(value)

sklearn/model_selection/_validation.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import time
1717
from traceback import format_exc
1818
from contextlib import suppress
19+
from collections import Counter
1920

2021
import numpy as np
2122
import scipy.sparse as sp
@@ -282,6 +283,8 @@ def cross_validate(
282283
for train, test in cv.split(X, y, groups)
283284
)
284285

286+
_warn_about_fit_failures(results, error_score)
287+
285288
# For callabe scoring, the return type is only know after calling. If the
286289
# return type is a dictionary, the error scores can now be inserted with
287290
# the correct key.
@@ -319,7 +322,7 @@ def _insert_error_scores(results, error_score):
319322
successful_score = None
320323
failed_indices = []
321324
for i, result in enumerate(results):
322-
if result["fit_failed"]:
325+
if result["fit_error"] is not None:
323326
failed_indices.append(i)
324327
elif successful_score is None:
325328
successful_score = result["test_scores"]
@@ -344,6 +347,31 @@ def _normalize_score_results(scores, scaler_score_key="score"):
344347
return {scaler_score_key: scores}
345348

346349

350+
def _warn_about_fit_failures(results, error_score):
351+
fit_errors = [
352+
result["fit_error"] for result in results if result["fit_error"] is not None
353+
]
354+
if fit_errors:
355+
num_failed_fits = len(fit_errors)
356+
num_fits = len(results)
357+
fit_errors_counter = Counter(fit_errors)
358+
delimiter = "-" * 80 + "\n"
359+
fit_errors_summary = "\n".join(
360+
f"{delimiter}{n} fits failed with the following error:\n{error}"
361+
for error, n in fit_errors_counter.items()
362+
)
363+
364+
some_fits_failed_message = (
365+
f"\n{num_failed_fits} fits failed out of a total of {num_fits}.\n"
366+
"The score on these train-test partitions for these parameters"
367+
f" will be set to {error_score}.\n"
368+
"If these failures are not expected, you can try to debug them "
369+
"by setting error_score='raise'.\n\n"
370+
f"Below are more details about the failures:\n{fit_errors_summary}"
371+
)
372+
warnings.warn(some_fits_failed_message, FitFailedWarning)
373+
374+
347375
def cross_val_score(
348376
estimator,
349377
X,
@@ -599,8 +627,8 @@ def _fit_and_score(
599627
The parameters that have been evaluated.
600628
estimator : estimator object
601629
The fitted estimator.
602-
fit_failed : bool
603-
The estimator failed to fit.
630+
fit_error : str or None
631+
Traceback str if the fit failed, None if the fit succeeded.
604632
"""
605633
if not isinstance(error_score, numbers.Number) and error_score != "raise":
606634
raise ValueError(
@@ -667,15 +695,9 @@ def _fit_and_score(
667695
test_scores = error_score
668696
if return_train_score:
669697
train_scores = error_score
670-
warnings.warn(
671-
"Estimator fit failed. The score on this train-test"
672-
" partition for these parameters will be set to %f. "
673-
"Details: \n%s" % (error_score, format_exc()),
674-
FitFailedWarning,
675-
)
676-
result["fit_failed"] = True
698+
result["fit_error"] = format_exc()
677699
else:
678-
result["fit_failed"] = False
700+
result["fit_error"] = None
679701

680702
fit_time = time.time() - start_time
681703
test_scores = _score(estimator, X_test, y_test, scorer, error_score)

sklearn/model_selection/tests/test_search.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,9 +1565,13 @@ def test_grid_search_failing_classifier():
15651565
refit=False,
15661566
error_score=0.0,
15671567
)
1568-
warning_message = (
1569-
"Estimator fit failed. The score on this train-test partition "
1570-
"for these parameters will be set to 0.0.*."
1568+
1569+
warning_message = re.compile(
1570+
"5 fits failed.+total of 15.+The score on these"
1571+
r" train-test partitions for these parameters will be set to 0\.0.+"
1572+
"5 fits failed with the following error.+ValueError.+Failing classifier failed"
1573+
" as required",
1574+
flags=re.DOTALL,
15711575
)
15721576
with pytest.warns(FitFailedWarning, match=warning_message):
15731577
gs.fit(X, y)
@@ -1598,9 +1602,12 @@ def get_cand_scores(i):
15981602
refit=False,
15991603
error_score=float("nan"),
16001604
)
1601-
warning_message = (
1602-
"Estimator fit failed. The score on this train-test partition "
1603-
"for these parameters will be set to nan."
1605+
warning_message = re.compile(
1606+
"5 fits failed.+total of 15.+The score on these"
1607+
r" train-test partitions for these parameters will be set to nan.+"
1608+
"5 fits failed with the following error.+ValueError.+Failing classifier failed"
1609+
" as required",
1610+
flags=re.DOTALL,
16041611
)
16051612
with pytest.warns(FitFailedWarning, match=warning_message):
16061613
gs.fit(X, y)
@@ -2112,7 +2119,12 @@ def custom_scorer(est, X, y):
21122119
error_score=0.1,
21132120
)
21142121

2115-
with pytest.warns(FitFailedWarning, match="Estimator fit failed"):
2122+
warning_message = re.compile(
2123+
"5 fits failed.+total of 15.+The score on these"
2124+
r" train-test partitions for these parameters will be set to 0\.1",
2125+
flags=re.DOTALL,
2126+
)
2127+
with pytest.warns(FitFailedWarning, match=warning_message):
21162128
gs.fit(X, y)
21172129

21182130
assert_allclose(gs.cv_results_["mean_test_acc"], [1, 1, 0.1])
@@ -2135,9 +2147,10 @@ def custom_scorer(est, X, y):
21352147
error_score=0.1,
21362148
)
21372149

2138-
with pytest.warns(FitFailedWarning, match="Estimator fit failed"), pytest.raises(
2139-
NotFittedError, match="All estimators failed to fit"
2140-
):
2150+
with pytest.warns(
2151+
FitFailedWarning,
2152+
match="15 fits failed.+total of 15",
2153+
), pytest.raises(NotFittedError, match="All estimators failed to fit"):
21412154
gs.fit(X, y)
21422155

21432156

sklearn/model_selection/tests/test_validation.py

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2082,37 +2082,6 @@ def test_fit_and_score_failing():
20822082
y = np.ones(9)
20832083
fit_and_score_args = [failing_clf, X, None, dict(), None, None, 0, None, None]
20842084
# passing error score to trigger the warning message
2085-
fit_and_score_kwargs = {"error_score": 0}
2086-
# check if the warning message type is as expected
2087-
warning_message = (
2088-
"Estimator fit failed. The score on this train-test partition for "
2089-
"these parameters will be set to %f." % (fit_and_score_kwargs["error_score"])
2090-
)
2091-
with pytest.warns(FitFailedWarning, match=warning_message):
2092-
_fit_and_score(*fit_and_score_args, **fit_and_score_kwargs)
2093-
# since we're using FailingClassfier, our error will be the following
2094-
error_message = "ValueError: Failing classifier failed as required"
2095-
# the warning message we're expecting to see
2096-
warning_message = (
2097-
"Estimator fit failed. The score on this train-test "
2098-
"partition for these parameters will be set to %f. "
2099-
"Details: \n%s" % (fit_and_score_kwargs["error_score"], error_message)
2100-
)
2101-
2102-
def test_warn_trace(msg):
2103-
assert "Traceback (most recent call last):\n" in msg
2104-
split = msg.splitlines() # note: handles more than '\n'
2105-
mtb = split[0] + "\n" + split[-1]
2106-
return warning_message in mtb
2107-
2108-
# check traceback is included
2109-
warning_message = (
2110-
"Estimator fit failed. The score on this train-test partition for "
2111-
"these parameters will be set to %f." % (fit_and_score_kwargs["error_score"])
2112-
)
2113-
with pytest.warns(FitFailedWarning, match=warning_message):
2114-
_fit_and_score(*fit_and_score_args, **fit_and_score_kwargs)
2115-
21162085
fit_and_score_kwargs = {"error_score": "raise"}
21172086
# check if exception was raised, with default error_score='raise'
21182087
with pytest.raises(ValueError, match="Failing classifier failed as required"):
@@ -2161,6 +2130,41 @@ def test_fit_and_score_working():
21612130
assert result["parameters"] == fit_and_score_kwargs["parameters"]
21622131

21632132

2133+
@pytest.mark.parametrize("error_score", [np.nan, 0])
2134+
def test_cross_validate_failing_fits_warnings(error_score):
2135+
# Create a failing classifier to deliberately fail
2136+
failing_clf = FailingClassifier(FailingClassifier.FAILING_PARAMETER)
2137+
# dummy X data
2138+
X = np.arange(1, 10)
2139+
y = np.ones(9)
2140+
# fit_and_score_args = [failing_clf, X, None, dict(), None, None, 0, None, None]
2141+
# passing error score to trigger the warning message
2142+
cross_validate_args = [failing_clf, X, y]
2143+
cross_validate_kwargs = {"cv": 7, "error_score": error_score}
2144+
# check if the warning message type is as expected
2145+
warning_message = re.compile(
2146+
"7 fits failed.+total of 7.+The score on these"
2147+
" train-test partitions for these parameters will be set to"
2148+
f" {cross_validate_kwargs['error_score']}.",
2149+
flags=re.DOTALL,
2150+
)
2151+
2152+
with pytest.warns(FitFailedWarning, match=warning_message):
2153+
cross_validate(*cross_validate_args, **cross_validate_kwargs)
2154+
2155+
# since we're using FailingClassfier, our error will be the following
2156+
error_message = "ValueError: Failing classifier failed as required"
2157+
2158+
# check traceback is included
2159+
warning_message = re.compile(
2160+
"The score on these train-test partitions for these parameters will be set"
2161+
f" to {cross_validate_kwargs['error_score']}.+{error_message}",
2162+
re.DOTALL,
2163+
)
2164+
with pytest.warns(FitFailedWarning, match=warning_message):
2165+
cross_validate(*cross_validate_args, **cross_validate_kwargs)
2166+
2167+
21642168
def _failing_scorer(estimator, X, y, error_msg):
21652169
raise ValueError(error_msg)
21662170

0 commit comments

Comments
 (0)
0