diff --git a/doc/glossary.rst b/doc/glossary.rst index 86cb3c06f5634..42e746c38b9ec 100644 --- a/doc/glossary.rst +++ b/doc/glossary.rst @@ -1583,10 +1583,10 @@ functions or non-estimator constructors. in the User Guide. Where multiple metrics can be evaluated, ``scoring`` may be given - either as a list of unique strings or a dictionary with names as keys - and callables as values. Note that this does *not* specify which score - function is to be maximized, and another parameter such as ``refit`` - maybe used for this purpose. + either as a list of unique strings, a dictionary with names as keys and + callables as values or a callable that returns a dictionary. Note that + this does *not* specify which score function is to be maximized, and + another parameter such as ``refit`` maybe used for this purpose. The ``scoring`` parameter is validated and interpreted using diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index bb8b59889a3f5..f8874869a0274 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -249,7 +249,7 @@ Using multiple metric evaluation Scikit-learn also permits evaluation of multiple metrics in ``GridSearchCV``, ``RandomizedSearchCV`` and ``cross_validate``. -There are two ways to specify multiple scoring metrics for the ``scoring`` +There are three ways to specify multiple scoring metrics for the ``scoring`` parameter: - As an iterable of string metrics:: @@ -261,25 +261,23 @@ parameter: >>> scoring = {'accuracy': make_scorer(accuracy_score), ... 'prec': 'precision'} -Note that the dict values can either be scorer functions or one of the -predefined metric strings. + Note that the dict values can either be scorer functions or one of the + predefined metric strings. -Currently only those scorer functions that return a single score can be passed -inside the dict. Scorer functions that return multiple values are not -permitted and will require a wrapper to return a single metric:: +- As a callable that returns a dictionary of scores:: >>> from sklearn.model_selection import cross_validate >>> from sklearn.metrics import confusion_matrix >>> # A sample toy binary classification dataset >>> X, y = datasets.make_classification(n_classes=2, random_state=0) >>> svm = LinearSVC(random_state=0) - >>> def tn(y_true, y_pred): return confusion_matrix(y_true, y_pred)[0, 0] - >>> def fp(y_true, y_pred): return confusion_matrix(y_true, y_pred)[0, 1] - >>> def fn(y_true, y_pred): return confusion_matrix(y_true, y_pred)[1, 0] - >>> def tp(y_true, y_pred): return confusion_matrix(y_true, y_pred)[1, 1] - >>> scoring = {'tp': make_scorer(tp), 'tn': make_scorer(tn), - ... 'fp': make_scorer(fp), 'fn': make_scorer(fn)} - >>> cv_results = cross_validate(svm, X, y, cv=5, scoring=scoring) + >>> def confusion_matrix_scorer(clf, X, y): + ... y_pred = clf.predict(X) + ... cm = confusion_matrix(y, y_pred) + ... return {'tn': cm[0, 0], 'fp': cm[0, 1], + ... 'fn': cm[1, 0], 'tp': cm[1, 1]} + >>> cv_results = cross_validate(svm, X, y, cv=5, + ... 
scoring=confusion_matrix_scorer) >>> # Getting the test set true positive scores >>> print(cv_results['test_tp']) [10 9 8 7 8] diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index be28005631963..36a54b88a50cc 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -74,6 +74,7 @@ from ._scorer import check_scoring from ._scorer import make_scorer from ._scorer import SCORERS +from ._scorer import get_applicable_scorers from ._scorer import get_scorer from ._plot.roc_curve import plot_roc_curve @@ -109,6 +110,7 @@ 'f1_score', 'fbeta_score', 'fowlkes_mallows_score', + 'get_applicable_scorers', 'get_scorer', 'hamming_loss', 'hinge_loss', diff --git a/sklearn/metrics/_scorer.py b/sklearn/metrics/_scorer.py index 6098eae2d68a0..4f56f88ad3d23 100644 --- a/sklearn/metrics/_scorer.py +++ b/sklearn/metrics/_scorer.py @@ -18,9 +18,12 @@ # Arnaud Joly # License: Simplified BSD +from collections import Counter +from collections import namedtuple from collections.abc import Iterable +from copy import deepcopy +from inspect import signature from functools import partial -from collections import Counter import numpy as np @@ -423,7 +426,7 @@ def check_scoring(estimator, scoring=None, *, allow_none=False): " None. %r was passed" % scoring) -def _check_multimetric_scoring(estimator, scoring=None): +def _check_multimetric_scoring(estimator, scoring): """Check the scoring parameter in cases when multiple metrics are allowed Parameters @@ -431,91 +434,66 @@ def _check_multimetric_scoring(estimator, scoring=None): estimator : sklearn estimator instance The estimator for which the scoring will be applied. - scoring : str, callable, list, tuple or dict, default=None + scoring : list, tuple or dict A single string (see :ref:`scoring_parameter`) or a callable (see :ref:`scoring`) to evaluate the predictions on the test set. For evaluating multiple metrics, either give a list of (unique) strings or a dict with names as keys and callables as values. - NOTE that when using custom scorers, each scorer should return a single - value. Metric functions returning a list/array of values can be wrapped - into multiple scorers that return one value each. - See :ref:`multimetric_grid_search` for an example. - If None the estimator's score method is used. - The return value in that case will be ``{'score': }``. - If the estimator's score method is not available, a ``TypeError`` - is raised. - Returns ------- scorers_dict : dict A dict mapping each scorer name to its validated scorer. - - is_multimetric : bool - True if scorer is a list/tuple or dict of callables - False if scorer is None/str/callable """ - if callable(scoring) or scoring is None or isinstance(scoring, - str): - scorers = {"score": check_scoring(estimator, scoring=scoring)} - return scorers, False - else: - err_msg_generic = ("scoring should either be a single string or " - "callable for single metric evaluation or a " - "list/tuple of strings or a dict of scorer name " - "mapped to the callable for multiple metric " - "evaluation. Got %s of type %s" - % (repr(scoring), type(scoring))) - - if isinstance(scoring, (list, tuple, set)): - err_msg = ("The list/tuple elements must be unique " - "strings of predefined scorers. ") - invalid = False - try: - keys = set(scoring) - except TypeError: - invalid = True - if invalid: - raise ValueError(err_msg) - - if len(keys) != len(scoring): - raise ValueError(err_msg + "Duplicate elements were found in" - " the given list. 
%r" % repr(scoring)) - elif len(keys) > 0: - if not all(isinstance(k, str) for k in keys): - if any(callable(k) for k in keys): - raise ValueError(err_msg + - "One or more of the elements were " - "callables. Use a dict of score name " - "mapped to the scorer callable. " - "Got %r" % repr(scoring)) - else: - raise ValueError(err_msg + - "Non-string types were found in " - "the given list. Got %r" - % repr(scoring)) - scorers = {scorer: check_scoring(estimator, scoring=scorer) - for scorer in scoring} - else: - raise ValueError(err_msg + - "Empty list was given. %r" % repr(scoring)) - - elif isinstance(scoring, dict): + err_msg_generic = ( + f"scoring is invalid (got {scoring!r}). Refer to the " + "scoring glossary for details: " + "https://scikit-learn.org/stable/glossary.html#term-scoring") + + if isinstance(scoring, (list, tuple, set)): + err_msg = ("The list/tuple elements must be unique " + "strings of predefined scorers. ") + invalid = False + try: keys = set(scoring) + except TypeError: + invalid = True + if invalid: + raise ValueError(err_msg) + + if len(keys) != len(scoring): + raise ValueError(f"{err_msg} Duplicate elements were found in" + f" the given list. {scoring!r}") + elif len(keys) > 0: if not all(isinstance(k, str) for k in keys): - raise ValueError("Non-string types were found in the keys of " - "the given dict. scoring=%r" % repr(scoring)) - if len(keys) == 0: - raise ValueError("An empty dict was passed. %r" - % repr(scoring)) - scorers = {key: check_scoring(estimator, scoring=scorer) - for key, scorer in scoring.items()} + if any(callable(k) for k in keys): + raise ValueError(f"{err_msg} One or more of the elements " + "were callables. Use a dict of score " + "name mapped to the scorer callable. " + f"Got {scoring!r}") + else: + raise ValueError(f"{err_msg} Non-string types were found " + f"in the given list. Got {scoring!r}") + scorers = {scorer: check_scoring(estimator, scoring=scorer) + for scorer in scoring} else: - raise ValueError(err_msg_generic) - return scorers, True + raise ValueError(f"{err_msg} Empty list was given. {scoring!r}") + + elif isinstance(scoring, dict): + keys = set(scoring) + if not all(isinstance(k, str) for k in keys): + raise ValueError("Non-string types were found in the keys of " + f"the given dict. scoring={scoring!r}") + if len(keys) == 0: + raise ValueError(f"An empty dict was passed. 
{scoring!r}") + scorers = {key: check_scoring(estimator, scoring=scorer) + for key, scorer in scoring.items()} + else: + raise ValueError(err_msg_generic) + return scorers @_deprecate_positional_args @@ -711,3 +689,153 @@ def make_scorer(score_func, *, greater_is_better=True, needs_proba=False, qualified_name = '{0}_{1}'.format(name, average) SCORERS[qualified_name] = make_scorer(metric, pos_label=None, average=average) + +ScorerProperty = namedtuple( + "ScorerProperty", ["scorer", "target_type_supported"], +) + +SCORERS_PROPERTY = dict( + explained_variance=ScorerProperty( + scorer=explained_variance_scorer, + target_type_supported=("continuous", "continuous-multioutput"), + ), + r2=ScorerProperty( + scorer=r2_scorer, + target_type_supported=("continuous", "continuous-multioutput"), + ), + max_error=ScorerProperty( + scorer=max_error_scorer, + target_type_supported=("continuous",), + ), + neg_median_absolute_error=ScorerProperty( + scorer=neg_median_absolute_error_scorer, + target_type_supported=("continuous", "continuous-multioutput"), + ), + neg_mean_absolute_error=ScorerProperty( + scorer=neg_mean_absolute_error_scorer, + target_type_supported=("continuous", "continuous-multioutput"), + ), + neg_mean_absolute_percentage_error=ScorerProperty( + scorer=neg_mean_absolute_percentage_error_scorer, + target_type_supported=("continuous", "continuous-multioutput"), + ), + neg_mean_squared_error=ScorerProperty( + scorer=neg_mean_squared_error_scorer, + target_type_supported=("continuous", "continuous-multioutput"), + ), + neg_mean_squared_log_error=ScorerProperty( + scorer=neg_mean_squared_log_error_scorer, + target_type_supported=("continuous", "continuous-multioutput"), + ), + neg_root_mean_squared_error=ScorerProperty( + scorer=neg_root_mean_squared_error_scorer, + target_type_supported=("continuous", "continuous-multioutput"), + ), + neg_mean_poisson_deviance=ScorerProperty( + scorer=neg_mean_poisson_deviance_scorer, + target_type_supported=("continuous",), + ), + neg_mean_gamma_deviance=ScorerProperty( + scorer=neg_mean_gamma_deviance_scorer, + target_type_supported=("continuous",), + ), + accuracy=ScorerProperty( + scorer=accuracy_scorer, + target_type_supported=("binary", "multiclass", "multilabel-indicator"), + ), + roc_auc=ScorerProperty( + scorer=roc_auc_scorer, + target_type_supported=("binary", "multilabel-indicator"), + ), + roc_auc_ovr=ScorerProperty( + scorer=roc_auc_ovr_scorer, + target_type_supported=("multiclass"), + ), + roc_auc_ovo=ScorerProperty( + scorer=roc_auc_ovo_scorer, + target_type_supported=("multiclass"), + ), + roc_auc_ovr_weighted=ScorerProperty( + scorer=roc_auc_ovr_weighted_scorer, + target_type_supported=("multiclass"), + ), + roc_auc_ovo_weighted=ScorerProperty( + scorer=roc_auc_ovo_weighted_scorer, + target_type_supported=("multiclass"), + ), + balanced_accuracy=ScorerProperty( + scorer=balanced_accuracy_scorer, + target_type_supported=("binary", "multiclass"), + ), + jaccard=ScorerProperty( + scorer=make_scorer(jaccard_score), + target_type_supported=("binary", "multilabel-indicator"), + ), + average_precision=ScorerProperty( + scorer=average_precision_scorer, + target_type_supported=("binary", "multilabel-indicator"), + ), + neg_log_loss=ScorerProperty( + scorer=neg_log_loss_scorer, + target_type_supported=("binary", "multiclass"), + ), + neg_brier_score=ScorerProperty( + scorer=neg_brier_score_scorer, + target_type_supported=("binary"), + ), +) + +for name, metric in [('precision', precision_score), + ('recall', recall_score), ('f1', f1_score), + 
('jaccard', jaccard_score)]: + SCORERS_PROPERTY[name] = ScorerProperty( + scorer=make_scorer(metric, average='binary'), + target_type_supported=("binary",), + ) + for average in ['macro', 'micro', 'samples', 'weighted']: + qualified_name = f'{name}_{average}' + SCORERS_PROPERTY[qualified_name] = ScorerProperty( + scorer=make_scorer(metric, pos_label=None, average=average), + target_type_supported=("multilabel-indicator"), + ) + + +def get_applicable_scorers(y, **scorers_params): + """Utility providing scorers to be used on `y`. + + This utility creates a dictionary containing the scorers which can be used + on `y`. The dictionary returned can be used directly in a + :class:`~sklearn.model_selection.GridSearchCV`. + + Additional parameters taken by the different metrics can be passed as + keyword argument. + + Parameters + ---------- + y : array-like + The target used to infer the metrics which can be used. + + **scorers_params + Additional parameters to be passed to the scorers when present in their + signature. + + Returns + ------- + scorers : dict + A dictionary containing the scorer name as key and a scorer callable as + value. + """ + target_type = type_of_target(y) + + scorers = {} + for scorer_name, scorer_property in SCORERS_PROPERTY.items(): + if target_type in scorer_property.target_type_supported: + scorers[scorer_name] = deepcopy(scorer_property.scorer) + scorer_sig = signature(scorers[scorer_name]._score_func) + for param_name, param_value in scorers_params.items(): + if param_name in scorer_sig.parameters: + scorers[scorer_name]._kwargs[param_name] = param_value + + if not scorers: + raise ValueError("No compatible scorer with the target 'y' was found.") + return scorers diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 227fd8bbadee9..22d1f2d971f93 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -25,7 +25,13 @@ _MultimetricScorer, _check_multimetric_scoring) from sklearn.metrics import accuracy_score -from sklearn.metrics import make_scorer, get_scorer, SCORERS +from sklearn.metrics import average_precision_score +from sklearn.metrics import ( + get_applicable_scorers, + get_scorer, + make_scorer, + SCORERS +) from sklearn.neighbors import KNeighborsClassifier from sklearn.svm import LinearSVC from sklearn.pipeline import make_pipeline @@ -35,6 +41,7 @@ from sklearn.datasets import make_blobs from sklearn.datasets import make_classification, make_regression from sklearn.datasets import make_multilabel_classification +from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_diabetes from sklearn.model_selection import train_test_split, cross_val_score from sklearn.model_selection import GridSearchCV @@ -206,30 +213,10 @@ def check_scoring_validator_for_single_metric_usecases(scoring_validator): assert scorer is None -def check_multimetric_scoring_single_metric_wrapper(*args, **kwargs): - # This wraps the _check_multimetric_scoring to take in - # single metric scoring parameter so we can run the tests - # that we will run for check_scoring, for check_multimetric_scoring - # too for single-metric usecases - - scorers, is_multi = _check_multimetric_scoring(*args, **kwargs) - # For all single metric use cases, it should register as not multimetric - assert not is_multi - if args[0] is not None: - assert scorers is not None - names, scorers = zip(*scorers.items()) - assert len(scorers) == 1 - assert names[0] == 'score' - scorers = 
scorers[0] - return scorers - - def test_check_scoring_and_check_multimetric_scoring(): check_scoring_validator_for_single_metric_usecases(check_scoring) # To make sure the check_scoring is correctly applied to the constituent # scorers - check_scoring_validator_for_single_metric_usecases( - check_multimetric_scoring_single_metric_wrapper) # For multiple metric use cases # Make sure it works for the valid cases @@ -241,8 +228,7 @@ def test_check_scoring_and_check_multimetric_scoring(): estimator = LinearSVC(random_state=0) estimator.fit([[1], [2], [3]], [1, 1, 0]) - scorers, is_multi = _check_multimetric_scoring(estimator, scoring) - assert is_multi + scorers = _check_multimetric_scoring(estimator, scoring) assert isinstance(scorers, dict) assert sorted(scorers.keys()) == sorted(list(scoring)) assert all([isinstance(scorer, _PredictScorer) @@ -622,7 +608,7 @@ def test_multimetric_scorer_calls_method_once(scorers, expected_predict_count, mock_est.predict_proba = predict_proba_func mock_est.decision_function = decision_function_func - scorer_dict, _ = _check_multimetric_scoring(LogisticRegression(), scorers) + scorer_dict = _check_multimetric_scoring(LogisticRegression(), scorers) multi_scorer = _MultimetricScorer(**scorer_dict) results = multi_scorer(mock_est, X, y) @@ -649,7 +635,7 @@ def predict_proba(self, X): clf.fit(X, y) scorers = ['roc_auc', 'neg_log_loss'] - scorer_dict, _ = _check_multimetric_scoring(clf, scorers) + scorer_dict = _check_multimetric_scoring(clf, scorers) scorer = _MultimetricScorer(**scorer_dict) scorer(clf, X, y) @@ -672,7 +658,7 @@ def predict(self, X): clf.fit(X, y) scorers = {'neg_mse': 'neg_mean_squared_error', 'r2': 'roc_auc'} - scorer_dict, _ = _check_multimetric_scoring(clf, scorers) + scorer_dict = _check_multimetric_scoring(clf, scorers) scorer = _MultimetricScorer(**scorer_dict) scorer(clf, X, y) @@ -690,7 +676,7 @@ def test_multimetric_scorer_sanity_check(): clf = DecisionTreeClassifier() clf.fit(X, y) - scorer_dict, _ = _check_multimetric_scoring(clf, scorers) + scorer_dict = _check_multimetric_scoring(clf, scorers) multi_scorer = _MultimetricScorer(**scorer_dict) result = multi_scorer(clf, X, y) @@ -750,3 +736,88 @@ def test_multiclass_roc_no_proba_scorer_errors(scorer_name): msg = "'Perceptron' object has no attribute 'predict_proba'" with pytest.raises(AttributeError, match=msg): scorer(lr, X, y) + + +def _parametrize_scorers_from_target(estimator_data_ids): + check_scorers, check_scorers_ids = zip(*[ + ((Estimator, X, np.abs(y) - np.min(y), scorer), + f"{scorer_name}-{problem_id}") + for problem_id, Estimator, X, y in estimator_data_ids + for scorer_name, scorer in get_applicable_scorers(y).items() + ]) + + return pytest.mark.parametrize( + "Estimator, X, y, scorer", check_scorers, ids=check_scorers_ids, + ) + + +@pytest.mark.filterwarnings( + "ignore::sklearn.exceptions.UndefinedMetricWarning" +) +@_parametrize_scorers_from_target( + [("binary", LogisticRegression, *make_classification(n_classes=2)), + ("multiclass", LogisticRegression, + *make_classification(n_classes=3, n_clusters_per_class=1)), + ("multilabel", DecisionTreeClassifier, *make_multilabel_classification()), + ("continuous", Ridge, *make_regression(n_targets=1)), + ("continuous-multioutput", Ridge, *make_regression(n_targets=2))] +) +def test_get_applicable_scorers_smoke_test(Estimator, X, y, scorer): + # smoke test to check that we can use the score on the registered problem + estimator = Estimator().fit(X, y) + scorer(estimator, X, y) + + +@pytest.mark.filterwarnings( + 
"ignore::sklearn.exceptions.UndefinedMetricWarning" +) +@pytest.mark.parametrize( + "Estimator, X, y", + [(LogisticRegression, *make_classification(n_classes=2)), + (LogisticRegression, + *make_classification(n_classes=3, n_clusters_per_class=1)), + (DecisionTreeClassifier, *make_multilabel_classification()), + (Ridge, *make_regression(n_targets=1)), + (Ridge, *make_regression(n_targets=2))] +) +def test_get_applicable_scorers_with_grid_search_smoke_test(Estimator, X, y): + # smoke test to check that scorers can be used directly inside a + # grid-search + if issubclass(Estimator, LogisticRegression): + param_grid = {"C": [0.1, 1]} + elif issubclass(Estimator, DecisionTreeClassifier): + param_grid = {"max_depth": [3, 5]} + elif issubclass(Estimator, Ridge): + y = np.abs(y) - np.min(y) + param_grid = {"alpha": [1, 10]} + + scorers = get_applicable_scorers(y) + estimator = GridSearchCV( + Estimator(), param_grid=param_grid, scoring=scorers, n_jobs=-1, + refit=list(scorers.keys())[0], + ) + estimator.fit(X, y) + + +def test_get_applicable_scorers_passing_scoring_params(): + # check that we can pass scoring parameters when getting the score + breast_cancer = load_breast_cancer() + X = breast_cancer.data + y = breast_cancer.target_names[breast_cancer.target].astype("object") + + scorers = get_applicable_scorers(y, pos_label="malignant") + average_precision_scorer = scorers["average_precision"] + assert "pos_label" in average_precision_scorer._kwargs + assert average_precision_scorer._kwargs["pos_label"] == "malignant" + + estimator = GridSearchCV( + DecisionTreeClassifier(), param_grid={"max_depth": [3, 5]}, + scoring=average_precision_scorer, + ) + estimator.fit(X, y) + + # check that if we don't provide any pos_label, the grid-search will raise + # an error + with pytest.raises(ValueError, match="pos_label=1 is invalid"): + estimator.set_params(scoring=make_scorer(average_precision_score)) + estimator.fit(X, y) diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index bcdbcdbc498fb..0eeabbd25a1ef 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -29,6 +29,8 @@ from ._split import check_cv from ._validation import _fit_and_score from ._validation import _aggregate_score_dicts +from ._validation import _insert_error_scores +from ._validation import _normalize_score_results from ..exceptions import NotFittedError from joblib import Parallel, delayed from ..utils import check_random_state @@ -453,8 +455,18 @@ def score(self, X, y=None): raise ValueError("No score function explicitly defined, " "and the estimator doesn't provide one %s" % self.best_estimator_) - score = self.scorer_[self.refit] if self.multimetric_ else self.scorer_ - return score(self.best_estimator_, X, y) + if isinstance(self.scorer_, dict): + if self.multimetric_: + scorer = self.scorer_[self.refit] + else: + scorer = self.scorer_ + return scorer(self.best_estimator_, X, y) + + # callable + score = self.scorer_(self.best_estimator_, X, y) + if self.multimetric_: + score = score[self.refit] + return score @if_delegate_has_method(delegate=('best_estimator_', 'estimator')) def score_samples(self, X): @@ -643,6 +655,23 @@ def _run_search(self, evaluate_candidates): """ raise NotImplementedError("_run_search not implemented.") + def _check_refit_for_multimetric(self, scores_dict): + """Check score contains the string in refit""" + multimetric_refit_msg = ("For multi-metric scoring, the parameter " + "refit must be set to a scorer key or a " + "callable to 
refit an estimator with the " + "best parameter setting on the whole " + "data and make the best_* attributes " + "available for that metric. If this is " + "not needed, refit should be set to " + "False explicitly. %r was passed." + % self.refit) + if self.refit is not False and ( + not isinstance(self.refit, str) or + # This will work for both dict / list (tuple) + self.refit not in scores_dict) and not callable(self.refit): + raise ValueError(multimetric_refit_msg) + @_deprecate_positional_args def fit(self, X, y=None, *, groups=None, **fit_params): """Run fit with all sets of parameters. @@ -670,27 +699,16 @@ def fit(self, X, y=None, *, groups=None, **fit_params): estimator = self.estimator cv = check_cv(self.cv, y, classifier=is_classifier(estimator)) - scorers, self.multimetric_ = _check_multimetric_scoring( - self.estimator, scoring=self.scoring) + refit_metric = "score" - if self.multimetric_: - if self.refit is not False and ( - not isinstance(self.refit, str) or - # This will work for both dict / list (tuple) - self.refit not in scorers) and not callable(self.refit): - raise ValueError("For multi-metric scoring, the parameter " - "refit must be set to a scorer key or a " - "callable to refit an estimator with the " - "best parameter setting on the whole " - "data and make the best_* attributes " - "available for that metric. If this is " - "not needed, refit should be set to " - "False explicitly. %r was passed." - % self.refit) - else: - refit_metric = self.refit + if callable(self.scoring): + scorers = self.scoring + elif self.scoring is None or isinstance(self.scoring, str): + scorers = check_scoring(self.estimator, self.scoring) else: - refit_metric = 'score' + scorers = _check_multimetric_scoring(self.estimator, self.scoring) + self._check_refit_for_multimetric(scorers) + refit_metric = self.refit X, y, groups = indexable(X, y, groups) fit_params = _check_fit_params(X, fit_params) @@ -751,16 +769,31 @@ def evaluate_candidates(candidate_params): .format(n_splits, len(out) // n_candidates)) + # For callabe self.scoring, the return type is only know after + # calling. If the return type is a dictionary, the error scores + # can now be inserted with the correct key. 
+ if callable(self.scoring): + _insert_error_scores(out, self.error_score) all_candidate_params.extend(candidate_params) all_out.extend(out) nonlocal results results = self._format_results( - all_candidate_params, scorers, n_splits, all_out) + all_candidate_params, n_splits, all_out) return results self._run_search(evaluate_candidates) + # multimetric is determined here because in the case of a callable + # self.scoring the return type is only known after calling + first_test_score = all_out[0]['test_scores'] + self.multimetric_ = isinstance(first_test_score, dict) + + # check refit_metric now for a callabe scorer that is multimetric + if callable(self.scoring) and self.multimetric_: + self._check_refit_for_multimetric(first_test_score) + refit_metric = self.refit + # For multi-metric evaluation, store the best_index_, best_params_ and # best_score_ iff refit is one of the scorer names # In single metric evaluation, refit_metric is "score" @@ -795,14 +828,14 @@ def evaluate_candidates(candidate_params): self.refit_time_ = refit_end_time - refit_start_time # Store the only scorer not as a dict for single metric evaluation - self.scorer_ = scorers if self.multimetric_ else scorers['score'] + self.scorer_ = scorers self.cv_results_ = results self.n_splits_ = n_splits return self - def _format_results(self, candidate_params, scorers, n_splits, out): + def _format_results(self, candidate_params, n_splits, out): n_candidates = len(candidate_params) out = _aggregate_score_dicts(out) @@ -852,17 +885,18 @@ def _store(key_name, array, weights=None, splits=False, rank=False): # Store a list of param dicts at the key 'params' results['params'] = candidate_params - test_scores = _aggregate_score_dicts(out["test_scores"]) + test_scores_dict = _normalize_score_results(out["test_scores"]) if self.return_train_score: - train_scores = _aggregate_score_dicts(out["train_scores"]) + train_scores_dict = _normalize_score_results(out["train_scores"]) - for scorer_name in test_scores: + for scorer_name in test_scores_dict: # Computed the (weighted) mean and std for test scores alone - _store('test_%s' % scorer_name, test_scores[scorer_name], + _store('test_%s' % scorer_name, test_scores_dict[scorer_name], splits=True, rank=True, weights=None) if self.return_train_score: - _store('train_%s' % scorer_name, train_scores[scorer_name], + _store('train_%s' % scorer_name, + train_scores_dict[scorer_name], splits=True) return results diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index b1cdd8748eb8d..eac1082a97e4f 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -28,7 +28,7 @@ from ..utils.metaestimators import _safe_split from ..metrics import check_scoring from ..metrics._scorer import _check_multimetric_scoring, _MultimetricScorer -from ..exceptions import FitFailedWarning +from ..exceptions import FitFailedWarning, NotFittedError from ._split import check_cv from ..preprocessing import LabelEncoder @@ -233,7 +233,13 @@ def cross_validate(estimator, X, y=None, *, groups=None, scoring=None, cv=None, X, y, groups = indexable(X, y, groups) cv = check_cv(cv, y, classifier=is_classifier(estimator)) - scorers, _ = _check_multimetric_scoring(estimator, scoring=scoring) + + if callable(scoring): + scorers = scoring + elif scoring is None or isinstance(scoring, str): + scorers = check_scoring(estimator, scoring) + else: + scorers = _check_multimetric_scoring(estimator, scoring) # We clone the estimator to make sure that all the 
folds are # independent, and that it is pickle-able. @@ -247,6 +253,12 @@ def cross_validate(estimator, X, y=None, *, groups=None, scoring=None, cv=None, error_score=error_score) for train, test in cv.split(X, y, groups)) + # For callabe scoring, the return type is only know after calling. If the + # return type is a dictionary, the error scores can now be inserted with + # the correct key. + if callable(scoring): + _insert_error_scores(results, error_score) + results = _aggregate_score_dicts(results) ret = {} @@ -256,19 +268,52 @@ def cross_validate(estimator, X, y=None, *, groups=None, scoring=None, cv=None, if return_estimator: ret['estimator'] = results["estimator"] - test_scores = _aggregate_score_dicts(results["test_scores"]) + test_scores_dict = _normalize_score_results(results["test_scores"]) if return_train_score: - train_scores = _aggregate_score_dicts(results["train_scores"]) + train_scores_dict = _normalize_score_results(results["train_scores"]) - for name in test_scores: - ret['test_%s' % name] = test_scores[name] + for name in test_scores_dict: + ret['test_%s' % name] = test_scores_dict[name] if return_train_score: key = 'train_%s' % name - ret[key] = train_scores[name] + ret[key] = train_scores_dict[name] return ret +def _insert_error_scores(results, error_score): + """Insert error in results by replacing them with `error_score`. + + This only applies to multimetric scores because `_fit_and_score` will + handle the single metric case.""" + successful_score = None + failed_indices = [] + for i, result in enumerate(results): + if result["fit_failed"]: + failed_indices.append(i) + elif successful_score is None: + successful_score = result["test_scores"] + + if successful_score is None: + raise NotFittedError("All estimators failed to fit") + + if isinstance(successful_score, dict): + formatted_error = {name: error_score for name in successful_score} + for i in failed_indices: + results[i]["test_scores"] = formatted_error.copy() + if "train_scores" in results[i]: + results[i]["train_scores"] = formatted_error.copy() + + +def _normalize_score_results(scores, scaler_score_key='score'): + """Creates a scoring dictionary based on the type of `scores`""" + if isinstance(scores[0], dict): + # multimetric scoring + return _aggregate_score_dicts(scores) + # scaler + return {scaler_score_key: scores} + + @_deprecate_positional_args def cross_val_score(estimator, X, y=None, *, groups=None, scoring=None, cv=None, n_jobs=None, verbose=0, fit_params=None, @@ -497,6 +542,8 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, The parameters that have been evaluated. estimator : estimator object The fitted estimator. + fit_failed : bool + The estimator failed to fit. """ progress_msg = "" if verbose > 2: @@ -567,7 +614,10 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, raise ValueError("error_score must be the string 'raise' or a" " numeric value. 
(Hint: if using 'raise', please" " make sure that it has been spelled correctly.)") + result["fit_failed"] = True else: + result["fit_failed"] = False + fit_time = time.time() - start_time test_scores = _score(estimator, X_test, y_test, scorer) score_time = time.time() - start_time - fit_time diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 7c4f5a2ee9b1e..6007db07a5b38 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -62,6 +62,7 @@ from sklearn.metrics import accuracy_score from sklearn.metrics import make_scorer from sklearn.metrics import roc_auc_score +from sklearn.metrics import confusion_matrix from sklearn.metrics.pairwise import euclidean_distances from sklearn.impute import SimpleImputer from sklearn.pipeline import Pipeline @@ -1748,6 +1749,119 @@ def get_n_splits(self, *args, **kw): ridge.fit(X[:train_size], y[:train_size]) +def test_callable_multimetric_confusion_matrix(): + def custom_scorer(clf, X, y): + y_pred = clf.predict(X) + cm = confusion_matrix(y, y_pred) + return {'tn': cm[0, 0], 'fp': cm[0, 1], 'fn': cm[1, 0], 'tp': cm[1, 1]} + + X, y = make_classification(n_samples=40, n_features=4, + random_state=42) + est = LinearSVC(random_state=42) + search = GridSearchCV(est, {'C': [0.1, 1]}, scoring=custom_scorer, + refit='fp') + + search.fit(X, y) + + score_names = ['tn', 'fp', 'fn', 'tp'] + for name in score_names: + assert "mean_test_{}".format(name) in search.cv_results_ + + y_pred = search.predict(X) + cm = confusion_matrix(y, y_pred) + assert search.score(X, y) == pytest.approx(cm[0, 1]) + + +def test_callable_multimetric_same_as_list_of_strings(): + def custom_scorer(est, X, y): + y_pred = est.predict(X) + return {'recall': recall_score(y, y_pred), + 'accuracy': accuracy_score(y, y_pred)} + + X, y = make_classification(n_samples=40, n_features=4, + random_state=42) + est = LinearSVC(random_state=42) + search_callable = GridSearchCV(est, {'C': [0.1, 1]}, + scoring=custom_scorer, refit='recall') + search_str = GridSearchCV(est, {'C': [0.1, 1]}, + scoring=['recall', 'accuracy'], refit='recall') + + search_callable.fit(X, y) + search_str.fit(X, y) + + assert search_callable.best_score_ == pytest.approx(search_str.best_score_) + assert search_callable.best_index_ == search_str.best_index_ + assert search_callable.score(X, y) == pytest.approx(search_str.score(X, y)) + + +def test_callable_single_metric_same_as_single_string(): + def custom_scorer(est, X, y): + y_pred = est.predict(X) + return recall_score(y, y_pred) + + X, y = make_classification(n_samples=40, n_features=4, + random_state=42) + est = LinearSVC(random_state=42) + search_callable = GridSearchCV(est, {'C': [0.1, 1]}, + scoring=custom_scorer, refit=True) + search_str = GridSearchCV(est, {'C': [0.1, 1]}, + scoring='recall', refit='recall') + + search_callable.fit(X, y) + search_str.fit(X, y) + + assert search_callable.best_score_ == pytest.approx(search_str.best_score_) + assert search_callable.best_index_ == search_str.best_index_ + assert search_callable.score(X, y) == pytest.approx(search_str.score(X, y)) + + +def test_callable_multimetric_error_on_invalid_key(): + def bad_scorer(est, X, y): + return {'bad_name': 1} + + X, y = make_classification(n_samples=40, n_features=4, + random_state=42) + clf = GridSearchCV(LinearSVC(random_state=42), {'C': [0.1, 1]}, + scoring=bad_scorer, refit='good_name') + + msg = ('For multi-metric scoring, the parameter refit must be set to a ' + 'scorer key or 
a callable to refit') + with pytest.raises(ValueError, match=msg): + clf.fit(X, y) + + +def test_callable_multimetric_error_failing_clf(): + def custom_scorer(est, X, y): + return {'acc': 1} + + X, y = make_classification(n_samples=20, n_features=10, random_state=0) + + clf = FailingClassifier() + gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring=custom_scorer, + refit=False, error_score=0.1) + + with pytest.warns(FitFailedWarning, match='Estimator fit failed'): + gs.fit(X, y) + + assert_allclose(gs.cv_results_['mean_test_acc'], [1, 1, 0.1]) + + +def test_callable_multimetric_clf_all_fails(): + def custom_scorer(est, X, y): + return {'acc': 1} + X, y = make_classification(n_samples=20, n_features=10, random_state=0) + + clf = FailingClassifier() + + gs = GridSearchCV(clf, [{'parameter': [2, 2, 2]}], scoring=custom_scorer, + refit=False, error_score=0.1) + + with pytest.warns(FitFailedWarning, match='Estimator fit failed'), \ + pytest.raises(NotFittedError, + match="All estimators failed to fit"): + gs.fit(X, y) + + def test_n_features_in(): # make sure grid search and random search delegate n_features_in to the # best estimator diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index 0f9238d63ec64..4250eb8af8748 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -1546,7 +1546,7 @@ def test_nested_cv(): StratifiedShuffleSplit(n_splits=3, random_state=0)] for inner_cv, outer_cv in combinations_with_replacement(cvs, 2): - gs = GridSearchCV(Ridge(solver="eigen"), param_grid={'alpha': [1, .1]}, + gs = GridSearchCV(Ridge(), param_grid={'alpha': [1, .1]}, cv=inner_cv, error_score='raise') cross_val_score(gs, X=X, y=y, groups=groups, cv=outer_cv, fit_params={'groups': groups}) diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 1cb08cc13f767..6e1faa1088075 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -52,13 +52,14 @@ from sklearn.metrics import precision_recall_fscore_support from sklearn.metrics import precision_score from sklearn.metrics import r2_score +from sklearn.metrics import mean_squared_error from sklearn.metrics import check_scoring from sklearn.linear_model import Ridge, LogisticRegression, SGDClassifier from sklearn.linear_model import PassiveAggressiveClassifier, RidgeClassifier from sklearn.ensemble import RandomForestClassifier from sklearn.neighbors import KNeighborsClassifier -from sklearn.svm import SVC +from sklearn.svm import SVC, LinearSVC from sklearn.cluster import KMeans from sklearn.impute import SimpleImputer @@ -317,8 +318,8 @@ def test_cross_validate_invalid_scoring_param(): cross_validate, estimator, X, y, scoring=[[make_scorer(precision_score)]]) - error_message_regexp = (".*should either be.*string or callable.*for " - "single.*.*dict.*for multi.*") + error_message_regexp = (".*scoring is invalid.*Refer to the scoring " + "glossary for details:.*") # Empty dict should raise invalid scoring error assert_raises_regex(ValueError, "An empty dict", @@ -342,16 +343,6 @@ def test_cross_validate_invalid_scoring_param(): cross_validate, estimator, X, y, scoring={"foo": multiclass_scorer}) - multivalued_scorer = make_scorer(confusion_matrix) - - # Multiclass Scorers that return multiple values are not supported yet - assert_raises_regex(ValueError, "scoring must return a number, got", - cross_validate, SVC(), X, y, - 
scoring=multivalued_scorer) - assert_raises_regex(ValueError, "scoring must return a number, got", - cross_validate, SVC(), X, y, - scoring={"foo": multivalued_scorer}) - assert_raises_regex(ValueError, "'mse' is not a valid scoring value.", cross_validate, SVC(), X, y, scoring="mse") @@ -463,9 +454,16 @@ def check_cross_validate_multi_metric(clf, X, y, scores): # Test multimetric evaluation when scoring is a list / dict (train_mse_scores, test_mse_scores, train_r2_scores, test_r2_scores, fitted_estimators) = scores + + def custom_scorer(clf, X, y): + y_pred = clf.predict(X) + return {'r2': r2_score(y, y_pred), + 'neg_mean_squared_error': -mean_squared_error(y, y_pred)} + all_scoring = (('r2', 'neg_mean_squared_error'), {'r2': make_scorer(r2_score), - 'neg_mean_squared_error': 'neg_mean_squared_error'}) + 'neg_mean_squared_error': 'neg_mean_squared_error'}, + custom_scorer) keys_sans_train = {'test_r2', 'test_neg_mean_squared_error', 'fit_time', 'score_time'} @@ -1767,3 +1765,20 @@ def two_params_scorer(estimator, X_test): fit_and_score_args = [None, None, None, two_params_scorer] assert_raise_message(ValueError, error_message, _score, *fit_and_score_args) + + +def test_callable_multimetric_confusion_matrix_cross_validate(): + def custom_scorer(clf, X, y): + y_pred = clf.predict(X) + cm = confusion_matrix(y, y_pred) + return {'tn': cm[0, 0], 'fp': cm[0, 1], 'fn': cm[1, 0], 'tp': cm[1, 1]} + + X, y = make_classification(n_samples=40, n_features=4, + random_state=42) + est = LinearSVC(random_state=42) + est.fit(X, y) + cv_results = cross_validate(est, X, y, cv=5, scoring=custom_scorer) + + score_names = ['tn', 'fp', 'fn', 'tp'] + for name in score_names: + assert "test_{}".format(name) in cv_results
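
Editor's note (not part of the patch): a minimal end-to-end sketch of the behaviour this diff enables, adapted from the doctest in ``model_evaluation.rst`` and the new tests above, assuming the patch is applied. A single callable returning a dict of scores can now be passed as ``scoring`` to both ``cross_validate`` and ``GridSearchCV``, with ``refit`` naming one of the returned keys::

    # Sketch of callable-multimetric scoring (adapted from this diff's
    # doctest and tests; only works once the patch is applied).
    from sklearn.datasets import make_classification
    from sklearn.metrics import confusion_matrix
    from sklearn.model_selection import GridSearchCV, cross_validate
    from sklearn.svm import LinearSVC


    def confusion_matrix_scorer(clf, X, y):
        # A scorer callable may now return a dict of scores instead of a
        # single number.
        y_pred = clf.predict(X)
        cm = confusion_matrix(y, y_pred)
        return {'tn': cm[0, 0], 'fp': cm[0, 1],
                'fn': cm[1, 0], 'tp': cm[1, 1]}


    X, y = make_classification(n_samples=40, n_features=4, random_state=42)

    # cross_validate exposes one 'test_<key>' entry per returned key.
    cv_results = cross_validate(LinearSVC(random_state=42), X, y, cv=5,
                                scoring=confusion_matrix_scorer)
    print(cv_results['test_tp'])

    # In a search, refit must name one of the returned keys; search.score
    # then reports that metric for the refitted best estimator.
    search = GridSearchCV(LinearSVC(random_state=42), {'C': [0.1, 1]},
                          scoring=confusion_matrix_scorer, refit='tp')
    search.fit(X, y)
    print(search.cv_results_['mean_test_tp'], search.score(X, y))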
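
Likewise, a hedged sketch of the proposed ``get_applicable_scorers`` helper, an API introduced by this diff (so it too only exists with the patch applied), adapted from its docstring and ``test_get_applicable_scorers_passing_scoring_params``. It infers the target type via ``type_of_target`` and returns a dict of compatible scorers that can be passed directly to a search; extra keyword arguments are attached only to scorers whose metric accepts them::

    # Sketch of the proposed get_applicable_scorers helper (API added by
    # this patch; adapted from its docstring and the new tests).
    from sklearn.datasets import load_breast_cancer
    from sklearn.metrics import get_applicable_scorers
    from sklearn.model_selection import GridSearchCV
    from sklearn.tree import DecisionTreeClassifier

    X, y = load_breast_cancer(return_X_y=True)

    # Only scorers whose supported target types include type_of_target(y)
    # ('binary' here) are returned; pos_label is forwarded only to scorers
    # whose underlying metric has it in its signature.
    scorers = get_applicable_scorers(y, pos_label=1)
    print(sorted(scorers))  # e.g. 'accuracy', 'average_precision', 'f1', ...

    # The returned dict plugs straight into a multi-metric grid search.
    search = GridSearchCV(DecisionTreeClassifier(), {'max_depth': [3, 5]},
                          scoring=scorers, refit='accuracy')
    search.fit(X, y)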