MAINT Parameters validation for metrics.check_scoring · Pull Request #26041 · scikit-learn/scikit-learn

MAINT Parameters validation for metrics.check_scoring #26041


Merged
merged 12 commits on Apr 27, 2023
10 changes: 2 additions & 8 deletions doc/modules/model_evaluation.rst
@@ -115,17 +115,11 @@ Usage examples:
>>> clf = svm.SVC(random_state=0)
>>> cross_val_score(clf, X, y, cv=5, scoring='recall_macro')
array([0.96..., 0.96..., 0.96..., 0.93..., 1. ])
>>> model = svm.SVC()
>>> cross_val_score(model, X, y, cv=5, scoring='wrong_choice')
Traceback (most recent call last):
ValueError: 'wrong_choice' is not a valid scoring value. Use
sklearn.metrics.get_scorer_names() to get valid options.

.. note::

The values listed by the ``ValueError`` exception correspond to the
functions measuring prediction accuracy described in the following
sections. You can retrieve the names of all available scorers by calling
If a wrong scoring name is passed, an ``InvalidParameterError`` is raised.
You can retrieve the names of all available scorers by calling
:func:`~sklearn.metrics.get_scorer_names`.

.. currentmodule:: sklearn.metrics
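The updated note only states the new behavior; here is a minimal sketch of what a caller now sees, assuming InvalidParameterError is imported from the private module sklearn.utils._param_validation (it subclasses ValueError, so catching ValueError works as well):

    from sklearn import datasets, svm
    from sklearn.model_selection import cross_val_score
    from sklearn.utils._param_validation import InvalidParameterError  # private path, assumed

    X, y = datasets.load_iris(return_X_y=True)
    model = svm.SVC()
    try:
        cross_val_score(model, X, y, cv=5, scoring="wrong_choice")
    except InvalidParameterError as exc:
        # The error message lists the accepted scoring names; get_scorer_names() does too.
        print(exc)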
140 changes: 65 additions & 75 deletions sklearn/metrics/_scorer.py
Expand Up @@ -18,7 +18,6 @@
# Arnaud Joly <arnaud.v.joly@gmail.com>
# License: Simplified BSD

from collections.abc import Iterable
from functools import partial
from collections import Counter
from traceback import format_exc
@@ -65,7 +64,7 @@

from ..utils.multiclass import type_of_target
from ..base import is_regressor
from ..utils._param_validation import validate_params
from ..utils._param_validation import HasMethods, StrOptions, validate_params


def _cached_call(cache, estimator, method, *args, **kwargs):
@@ -451,79 +450,6 @@ def _passthrough_scorer(estimator, *args, **kwargs):
return estimator.score(*args, **kwargs)


def check_scoring(estimator, scoring=None, *, allow_none=False):
"""Determine scorer from user options.

A TypeError will be thrown if the estimator cannot be scored.

Parameters
----------
estimator : estimator object implementing 'fit'
The object to use to fit the data.

scoring : str or callable, default=None
A string (see model evaluation documentation) or
a scorer callable object / function with signature
``scorer(estimator, X, y)``.
If None, the provided estimator object's `score` method is used.

allow_none : bool, default=False
If no scoring is specified and the estimator has no score function, we
can either return None or raise an exception.

Returns
-------
scoring : callable
A scorer callable object / function with signature
``scorer(estimator, X, y)``.
"""
if not hasattr(estimator, "fit"):
raise TypeError(
"estimator should be an estimator implementing 'fit' method, %r was passed"
% estimator
)
if isinstance(scoring, str):
return get_scorer(scoring)
elif callable(scoring):
# Heuristic to ensure user has not passed a metric
module = getattr(scoring, "__module__", None)
if (
hasattr(module, "startswith")
and module.startswith("sklearn.metrics.")
and not module.startswith("sklearn.metrics._scorer")
and not module.startswith("sklearn.metrics.tests.")
):
raise ValueError(
"scoring value %r looks like it is a metric "
"function rather than a scorer. A scorer should "
"require an estimator as its first parameter. "
"Please use `make_scorer` to convert a metric "
"to a scorer." % scoring
)
return get_scorer(scoring)
elif scoring is None:
if hasattr(estimator, "score"):
return _passthrough_scorer
elif allow_none:
return None
else:
raise TypeError(
"If no scoring is specified, the estimator passed should "
"have a 'score' method. The estimator %r does not." % estimator
)
elif isinstance(scoring, Iterable):
raise ValueError(
"For evaluating multiple scores, use "
"sklearn.model_selection.cross_validate instead. "
"{0} was passed.".format(scoring)
)
else:
raise ValueError(
"scoring value should either be a callable, string or None. %r was passed"
% scoring
)


def _check_multimetric_scoring(estimator, scoring):
"""Check the scoring parameter in cases when multiple metrics are allowed.

@@ -882,3 +808,67 @@ def get_scorer_names():
_SCORERS[qualified_name] = make_scorer(metric, pos_label=None, average=average)

SCORERS = _DeprecatedScorers(_SCORERS)


@validate_params(
{
"estimator": [HasMethods("fit")],
"scoring": [StrOptions(set(get_scorer_names())), callable, None],
"allow_none": ["boolean"],
}
)
def check_scoring(estimator, scoring=None, *, allow_none=False):
"""Determine scorer from user options.

A TypeError will be thrown if the estimator cannot be scored.

Parameters
----------
estimator : estimator object implementing 'fit'
The object to use to fit the data.

scoring : str or callable, default=None
A string (see model evaluation documentation) or
a scorer callable object / function with signature
``scorer(estimator, X, y)``.
If None, the provided estimator object's `score` method is used.

allow_none : bool, default=False
If no scoring is specified and the estimator has no score function, we
can either return None or raise an exception.

Returns
-------
scoring : callable
A scorer callable object / function with signature
``scorer(estimator, X, y)``.
"""
if isinstance(scoring, str):
return get_scorer(scoring)
if callable(scoring):
# Heuristic to ensure user has not passed a metric
module = getattr(scoring, "__module__", None)
if (
hasattr(module, "startswith")
and module.startswith("sklearn.metrics.")
and not module.startswith("sklearn.metrics._scorer")
and not module.startswith("sklearn.metrics.tests.")
):
raise ValueError(
"scoring value %r looks like it is a metric "
"function rather than a scorer. A scorer should "
"require an estimator as its first parameter. "
"Please use `make_scorer` to convert a metric "
"to a scorer." % scoring
)
return get_scorer(scoring)
if scoring is None:
if hasattr(estimator, "score"):
return _passthrough_scorer
elif allow_none:
return None
else:
raise TypeError(
"If no scoring is specified, the estimator passed should "
"have a 'score' method. The estimator %r does not." % estimator
)
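
A sketch (not part of the diff) of what the decorator now handles: HasMethods("fit") rejects a fit-less estimator and StrOptions(set(get_scorer_names())) rejects an unknown scoring string before the function body runs, replacing the hand-written TypeError/ValueError branches removed above. The exception class lives in a private module, so the import path is an assumption:

    from sklearn.metrics import check_scoring
    from sklearn.svm import SVC
    from sklearn.utils._param_validation import InvalidParameterError  # private path, assumed

    try:
        check_scoring(SVC(), scoring="not_a_registered_scorer")
    except InvalidParameterError as exc:
        print(exc)  # 'scoring' must be one of the known scorer names, a callable, or None

    class NoFit:
        pass

    try:
        check_scoring(NoFit())
    except InvalidParameterError as exc:
        print(exc)  # 'estimator' must be an object implementing 'fit'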
13 changes: 0 additions & 13 deletions sklearn/metrics/tests/test_score_objects.py
@@ -179,12 +179,6 @@ def teardown_module():
shutil.rmtree(TEMP_FOLDER)


class EstimatorWithoutFit:
"""Dummy estimator to test scoring validators"""

pass


class EstimatorWithFit(BaseEstimator):
"""Dummy estimator to test scoring validators"""

@@ -228,13 +222,6 @@ def test_all_scorers_repr():

def check_scoring_validator_for_single_metric_usecases(scoring_validator):
# Test all branches of single metric usecases
estimator = EstimatorWithoutFit()
pattern = (
r"estimator should be an estimator implementing 'fit' method," r" .* was passed"
)
with pytest.raises(TypeError, match=pattern):
scoring_validator(estimator)

estimator = EstimatorWithFitAndScore()
estimator.fit([[1]], [1])
scorer = scoring_validator(estimator)
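The deleted assertion checked that a fit-less estimator triggers the hand-written TypeError. With validation moved into the @validate_params decorator, an equivalent check (a sketch, not part of this PR; the real coverage comes from the common test registered in test_public_functions.py below) would target InvalidParameterError instead:

    import pytest

    from sklearn.metrics import check_scoring
    from sklearn.utils._param_validation import InvalidParameterError  # private path, assumed

    def test_check_scoring_rejects_estimator_without_fit():
        class EstimatorWithoutFit:
            pass

        with pytest.raises(InvalidParameterError):
            check_scoring(EstimatorWithoutFit())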
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Expand Up @@ -183,6 +183,7 @@ def _check_function_param_validation(
"sklearn.metrics.balanced_accuracy_score",
"sklearn.metrics.brier_score_loss",
"sklearn.metrics.calinski_harabasz_score",
"sklearn.metrics.check_scoring",
"sklearn.metrics.completeness_score",
"sklearn.metrics.class_likelihood_ratios",
"sklearn.metrics.classification_report",
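Adding sklearn.metrics.check_scoring to this list opts the function into the shared parameter-validation test. A generic sketch of that idea, using a hypothetical test name and only the constraints declared above (the real _check_function_param_validation helper is not shown in this diff):

    from importlib import import_module

    import pytest

    from sklearn.svm import SVC
    from sklearn.utils._param_validation import InvalidParameterError  # private path, assumed

    @pytest.mark.parametrize("qualname", ["sklearn.metrics.check_scoring"])
    def test_invalid_params_are_rejected(qualname):  # hypothetical test name
        module_name, func_name = qualname.rsplit(".", 1)
        func = getattr(import_module(module_name), func_name)
        # allow_none is constrained to "boolean", so a string must be rejected.
        with pytest.raises(InvalidParameterError):
            func(SVC(), allow_none="yes")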