MAINT Parameters validation for metrics.check_scoring (#26041) · scikit-learn/scikit-learn@4af3087 · GitHub

Commit 4af3087

Théophile Baranger and jeremiedbb authored
MAINT Parameters validation for metrics.check_scoring (#26041)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 092caed commit 4af3087

File tree

4 files changed: +68 −96 lines
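At its core, the commit swaps hand-rolled type checks for scikit-learn's declarative `@validate_params` decorator, which checks each argument against a list of constraints before the function body runs. A rough standalone sketch of that pattern (the `run` function and its parameters are made up for illustration; the decorator and constraint helpers live in the private `sklearn.utils._param_validation` module, whose exact signature may differ between versions):

from sklearn.utils._param_validation import HasMethods, StrOptions, validate_params


@validate_params(
    {
        "estimator": [HasMethods("fit")],  # any object exposing a fit() method
        "mode": [StrOptions({"fast", "exact"})],  # a closed set of string choices
        "verbose": ["boolean"],  # plain bools only
    }
)
def run(estimator, mode="fast", *, verbose=False):
    # By the time we get here, every argument has passed its constraints;
    # a bad value raises InvalidParameterError with the options spelled out.
    return estimator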

doc/modules/model_evaluation.rst

Lines changed: 2 additions & 8 deletions

@@ -115,17 +115,11 @@ Usage examples:
     >>> clf = svm.SVC(random_state=0)
     >>> cross_val_score(clf, X, y, cv=5, scoring='recall_macro')
     array([0.96..., 0.96..., 0.96..., 0.93..., 1. ])
-    >>> model = svm.SVC()
-    >>> cross_val_score(model, X, y, cv=5, scoring='wrong_choice')
-    Traceback (most recent call last):
-    ValueError: 'wrong_choice' is not a valid scoring value. Use
-    sklearn.metrics.get_scorer_names() to get valid options.
 
 .. note::
 
-    The values listed by the ``ValueError`` exception correspond to the
-    functions measuring prediction accuracy described in the following
-    sections. You can retrieve the names of all available scorers by calling
+    If a wrong scoring name is passed, an ``InvalidParameterError`` is raised.
+    You can retrieve the names of all available scorers by calling
     :func:`~sklearn.metrics.get_scorer_names`.
 
 .. currentmodule:: sklearn.metrics
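As a quick check of the behavior the rewritten note describes, something like the following should reproduce it; note that `InvalidParameterError` subclasses `ValueError`, so catching the latter stays compatible with the old behavior:

from sklearn import datasets, svm
from sklearn.metrics import get_scorer_names
from sklearn.model_selection import cross_val_score

X, y = datasets.load_iris(return_X_y=True)
clf = svm.SVC(random_state=0)

try:
    cross_val_score(clf, X, y, cv=5, scoring="wrong_choice")
except ValueError as exc:
    # Under parameter validation this is an InvalidParameterError whose
    # message enumerates the accepted names.
    print(type(exc).__name__)

# The full list of valid scorer names remains one call away:
print(sorted(get_scorer_names())[:5])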

sklearn/metrics/_scorer.py

Lines changed: 65 additions & 75 deletions

@@ -18,7 +18,6 @@
 # Arnaud Joly <arnaud.v.joly@gmail.com>
 # License: Simplified BSD
 
-from collections.abc import Iterable
 from functools import partial
 from collections import Counter
 from traceback import format_exc
@@ -65,7 +64,7 @@
 
 from ..utils.multiclass import type_of_target
 from ..base import is_regressor
-from ..utils._param_validation import validate_params
+from ..utils._param_validation import HasMethods, StrOptions, validate_params
 
 
 def _cached_call(cache, estimator, method, *args, **kwargs):
@@ -451,79 +450,6 @@ def _passthrough_scorer(estimator, *args, **kwargs):
     return estimator.score(*args, **kwargs)
 
 
-def check_scoring(estimator, scoring=None, *, allow_none=False):
-    """Determine scorer from user options.
-
-    A TypeError will be thrown if the estimator cannot be scored.
-
-    Parameters
-    ----------
-    estimator : estimator object implementing 'fit'
-        The object to use to fit the data.
-
-    scoring : str or callable, default=None
-        A string (see model evaluation documentation) or
-        a scorer callable object / function with signature
-        ``scorer(estimator, X, y)``.
-        If None, the provided estimator object's `score` method is used.
-
-    allow_none : bool, default=False
-        If no scoring is specified and the estimator has no score function, we
-        can either return None or raise an exception.
-
-    Returns
-    -------
-    scoring : callable
-        A scorer callable object / function with signature
-        ``scorer(estimator, X, y)``.
-    """
-    if not hasattr(estimator, "fit"):
-        raise TypeError(
-            "estimator should be an estimator implementing 'fit' method, %r was passed"
-            % estimator
-        )
-    if isinstance(scoring, str):
-        return get_scorer(scoring)
-    elif callable(scoring):
-        # Heuristic to ensure user has not passed a metric
-        module = getattr(scoring, "__module__", None)
-        if (
-            hasattr(module, "startswith")
-            and module.startswith("sklearn.metrics.")
-            and not module.startswith("sklearn.metrics._scorer")
-            and not module.startswith("sklearn.metrics.tests.")
-        ):
-            raise ValueError(
-                "scoring value %r looks like it is a metric "
-                "function rather than a scorer. A scorer should "
-                "require an estimator as its first parameter. "
-                "Please use `make_scorer` to convert a metric "
-                "to a scorer." % scoring
-            )
-        return get_scorer(scoring)
-    elif scoring is None:
-        if hasattr(estimator, "score"):
-            return _passthrough_scorer
-        elif allow_none:
-            return None
-        else:
-            raise TypeError(
-                "If no scoring is specified, the estimator passed should "
-                "have a 'score' method. The estimator %r does not." % estimator
-            )
-    elif isinstance(scoring, Iterable):
-        raise ValueError(
-            "For evaluating multiple scores, use "
-            "sklearn.model_selection.cross_validate instead. "
-            "{0} was passed.".format(scoring)
-        )
-    else:
-        raise ValueError(
-            "scoring value should either be a callable, string or None. %r was passed"
-            % scoring
-        )
-
-
 def _check_multimetric_scoring(estimator, scoring):
     """Check the scoring parameter in cases when multiple metrics are allowed.
 
@@ -882,3 +808,67 @@ def get_scorer_names():
         _SCORERS[qualified_name] = make_scorer(metric, pos_label=None, average=average)
 
 SCORERS = _DeprecatedScorers(_SCORERS)
+
+
+@validate_params(
+    {
+        "estimator": [HasMethods("fit")],
+        "scoring": [StrOptions(set(get_scorer_names())), callable, None],
+        "allow_none": ["boolean"],
+    }
+)
+def check_scoring(estimator, scoring=None, *, allow_none=False):
+    """Determine scorer from user options.
+
+    A TypeError will be thrown if the estimator cannot be scored.
+
+    Parameters
+    ----------
+    estimator : estimator object implementing 'fit'
+        The object to use to fit the data.
+
+    scoring : str or callable, default=None
+        A string (see model evaluation documentation) or
+        a scorer callable object / function with signature
+        ``scorer(estimator, X, y)``.
+        If None, the provided estimator object's `score` method is used.
+
+    allow_none : bool, default=False
+        If no scoring is specified and the estimator has no score function, we
+        can either return None or raise an exception.
+
+    Returns
+    -------
+    scoring : callable
+        A scorer callable object / function with signature
+        ``scorer(estimator, X, y)``.
+    """
+    if isinstance(scoring, str):
+        return get_scorer(scoring)
+    if callable(scoring):
+        # Heuristic to ensure user has not passed a metric
+        module = getattr(scoring, "__module__", None)
+        if (
+            hasattr(module, "startswith")
+            and module.startswith("sklearn.metrics.")
+            and not module.startswith("sklearn.metrics._scorer")
+            and not module.startswith("sklearn.metrics.tests.")
+        ):
+            raise ValueError(
+                "scoring value %r looks like it is a metric "
+                "function rather than a scorer. A scorer should "
+                "require an estimator as its first parameter. "
+                "Please use `make_scorer` to convert a metric "
+                "to a scorer." % scoring
+            )
+        return get_scorer(scoring)
+    if scoring is None:
+        if hasattr(estimator, "score"):
+            return _passthrough_scorer
+        elif allow_none:
+            return None
+        else:
+            raise TypeError(
+                "If no scoring is specified, the estimator passed should "
+                "have a 'score' method. The estimator %r does not." % estimator
+            )
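A small usage sketch of the relocated function, assuming it behaves as the added hunk suggests:

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import check_scoring

est = LogisticRegression()

# A registered name satisfies the StrOptions constraint and resolves to a scorer.
accuracy_scorer = check_scoring(est, scoring="accuracy")

# scoring=None falls back to the estimator's own score method.
default_scorer = check_scoring(est)

# An unknown name is now rejected by @validate_params before the function
# body runs, which is why the old manual string/Iterable checks could go.
try:
    check_scoring(est, scoring="not_a_scorer")
except ValueError as exc:  # InvalidParameterError is a ValueError subclass
    print(type(exc).__name__)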
sklearn/metrics/tests/test_score_objects.py

Lines changed: 0 additions & 13 deletions

@@ -179,12 +179,6 @@ def teardown_module():
     shutil.rmtree(TEMP_FOLDER)
 
 
-class EstimatorWithoutFit:
-    """Dummy estimator to test scoring validators"""
-
-    pass
-
-
 class EstimatorWithFit(BaseEstimator):
     """Dummy estimator to test scoring validators"""
 
@@ -228,13 +222,6 @@ def test_all_scorers_repr():
 
 def check_scoring_validator_for_single_metric_usecases(scoring_validator):
     # Test all branches of single metric usecases
-    estimator = EstimatorWithoutFit()
-    pattern = (
-        r"estimator should be an estimator implementing 'fit' method," r" .* was passed"
-    )
-    with pytest.raises(TypeError, match=pattern):
-        scoring_validator(estimator)
-
     estimator = EstimatorWithFitAndScore()
     estimator.fit([[1]], [1])
     scorer = scoring_validator(estimator)

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions

@@ -183,6 +183,7 @@ def _check_function_param_validation(
     "sklearn.metrics.balanced_accuracy_score",
     "sklearn.metrics.brier_score_loss",
     "sklearn.metrics.calinski_harabasz_score",
+    "sklearn.metrics.check_scoring",
     "sklearn.metrics.completeness_score",
     "sklearn.metrics.class_likelihood_ratios",
     "sklearn.metrics.classification_report",

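Registering `sklearn.metrics.check_scoring` in this list lets the shared parameter-validation test exercise its constraints automatically. A hand-rolled equivalent might look like the sketch below (the test name is invented, and `InvalidParameterError` is imported from a private module whose path may change):

import pytest

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import check_scoring
from sklearn.utils._param_validation import InvalidParameterError


def test_check_scoring_param_validation():
    # An object without fit() violates the HasMethods("fit") constraint.
    with pytest.raises(InvalidParameterError):
        check_scoring(object(), scoring="accuracy")

    # A non-boolean allow_none violates the "boolean" constraint.
    with pytest.raises(InvalidParameterError):
        check_scoring(LogisticRegression(), allow_none="yes")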
0 commit comments