MNT Moving `_CurveScorer` from `model_selection` to `metrics` (#29216) · scikit-learn/scikit-learn@3b7734e · GitHub

Commit 3b7734e

vitaliset and glemaitre authored

MNT Moving _CurveScorer from model_selection to metrics (#29216)
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai>
1 parent 2ffe206 · commit 3b7734e

File tree

4 files changed, +217 -211 lines changed

sklearn/metrics/_scorer.py

Lines changed: 120 additions & 0 deletions
```diff
@@ -21,8 +21,11 @@
 from collections import Counter
 from functools import partial
 from inspect import signature
+from numbers import Integral
 from traceback import format_exc
 
+import numpy as np
+
 from ..base import is_regressor
 from ..utils import Bunch
 from ..utils._param_validation import HasMethods, Hidden, StrOptions, validate_params
```
```diff
@@ -1083,3 +1086,120 @@ def check_scoring(estimator=None, scoring=None, *, allow_none=False, raise_exc=T
             "If no scoring is specified, the estimator passed should "
             "have a 'score' method. The estimator %r does not." % estimator
         )
+
+
+def _threshold_scores_to_class_labels(y_score, threshold, classes, pos_label):
+    """Threshold `y_score` and return the associated class labels."""
+    if pos_label is None:
+        map_thresholded_score_to_label = np.array([0, 1])
+    else:
+        pos_label_idx = np.flatnonzero(classes == pos_label)[0]
+        neg_label_idx = np.flatnonzero(classes != pos_label)[0]
+        map_thresholded_score_to_label = np.array([neg_label_idx, pos_label_idx])
+
+    return classes[map_thresholded_score_to_label[(y_score >= threshold).astype(int)]]
```
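A minimal sketch of what this helper does. Note that it is a private function, so the import path below is an implementation detail of this commit and may change; the inputs are arbitrary illustrative values:

```python
import numpy as np

from sklearn.metrics._scorer import _threshold_scores_to_class_labels

classes = np.array(["neg", "pos"])
y_score = np.array([0.2, 0.6, 0.9])  # e.g. predict_proba(X)[:, 1]

# Scores >= threshold map to `pos_label`, the rest to the other class.
_threshold_scores_to_class_labels(y_score, 0.5, classes, pos_label="pos")
# array(['neg', 'pos', 'pos'], dtype='<U3')

# Swapping `pos_label` flips the mapping.
_threshold_scores_to_class_labels(y_score, 0.5, classes, pos_label="neg")
# array(['pos', 'neg', 'neg'], dtype='<U3')
```

The `_CurveScorer` class, moved in the same hunk, follows: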
```diff
+class _CurveScorer(_BaseScorer):
+    """Scorer that takes a continuous response and outputs a score for each threshold.
+
+    Parameters
+    ----------
+    score_func : callable
+        The score function to use. It will be called as
+        `score_func(y_true, y_pred, **kwargs)`.
+
+    sign : int
+        Either 1 or -1 to return the score as `sign * score_func(estimator, X, y)`.
+        Thus, `sign` defines whether higher scores are better or worse.
+
+    kwargs : dict
+        Additional parameters to pass to the score function.
+
+    thresholds : int or array-like
+        Related to the number of decision thresholds for which we want to compute the
+        score. If an integer, it will be used to generate `thresholds` thresholds
+        uniformly distributed between the minimum and maximum predicted scores. If an
+        array-like, it will be used as the thresholds.
+
+    response_method : str
+        The method to call on the estimator to get the response values.
+    """
+
+    def __init__(self, score_func, sign, kwargs, thresholds, response_method):
+        super().__init__(
+            score_func=score_func,
+            sign=sign,
+            kwargs=kwargs,
+            response_method=response_method,
+        )
+        self._thresholds = thresholds
+
+    @classmethod
+    def from_scorer(cls, scorer, response_method, thresholds):
+        """Create a continuous scorer from a normal scorer."""
+        instance = cls(
+            score_func=scorer._score_func,
+            sign=scorer._sign,
+            response_method=response_method,
+            thresholds=thresholds,
+            kwargs=scorer._kwargs,
+        )
+        # transfer the metadata request
+        instance._metadata_request = scorer._get_metadata_request()
+        return instance
+
+    def _score(self, method_caller, estimator, X, y_true, **kwargs):
+        """Evaluate predicted target values for X relative to y_true.
+
+        Parameters
+        ----------
+        method_caller : callable
+            Returns predictions given an estimator, method name, and other
+            arguments, potentially caching results.
+
+        estimator : object
+            Trained estimator to use for scoring.
+
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            Test data that will be fed to estimator.predict.
+
+        y_true : array-like of shape (n_samples,)
+            Gold standard target values for X.
+
+        **kwargs : dict
+            Other parameters passed to the scorer. Refer to
+            :func:`set_score_request` for more details.
+
+        Returns
+        -------
+        scores : ndarray of shape (thresholds,)
+            The scores associated with each threshold.
+
+        potential_thresholds : ndarray of shape (thresholds,)
+            The potential thresholds used to compute the scores.
+        """
+        pos_label = self._get_pos_label()
+        y_score = method_caller(
+            estimator, self._response_method, X, pos_label=pos_label
+        )
+
+        scoring_kwargs = {**self._kwargs, **kwargs}
+        if isinstance(self._thresholds, Integral):
+            potential_thresholds = np.linspace(
+                np.min(y_score), np.max(y_score), self._thresholds
+            )
+        else:
+            potential_thresholds = np.asarray(self._thresholds)
+        score_thresholds = [
+            self._sign
+            * self._score_func(
+                y_true,
+                _threshold_scores_to_class_labels(
+                    y_score, th, estimator.classes_, pos_label
+                ),
+                **scoring_kwargs,
+            )
+            for th in potential_thresholds
+        ]
+        return np.array(score_thresholds), potential_thresholds
```
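As a rough usage sketch of the code above (all private API, shown only for illustration; the dataset and estimator are arbitrary stand-ins, not part of the commit):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import get_scorer
from sklearn.metrics._scorer import _CurveScorer

X, y = make_classification(random_state=0)
estimator = LogisticRegression().fit(X, y)

# Wrap an existing scorer; `thresholds=100` asks for 100 candidate thresholds
# spread uniformly over the observed predicted probabilities.
curve_scorer = _CurveScorer.from_scorer(
    get_scorer("balanced_accuracy"), response_method="predict_proba", thresholds=100
)
scores, thresholds = curve_scorer(estimator, X, y)
assert scores.shape == thresholds.shape == (100,)
```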

sklearn/metrics/tests/test_score_objects.py

Lines changed: 92 additions & 0 deletions
```diff
@@ -43,6 +43,7 @@
 from sklearn.metrics import cluster as cluster_module
 from sklearn.metrics._scorer import (
     _check_multimetric_scoring,
+    _CurveScorer,
     _MultimetricScorer,
     _PassthroughScorer,
     _Scorer,
@@ -1605,3 +1606,94 @@ def test_metadata_routing_multimetric_metadata_routing(enable_metadata_routing):
     multimetric_scorer = _MultimetricScorer(scorers={"acc": get_scorer("accuracy")})
     with config_context(enable_metadata_routing=enable_metadata_routing):
         multimetric_scorer(estimator, X, y)
+
+
+def test_curve_scorer():
+    """Check the behaviour of the `_CurveScorer` class."""
+    X, y = make_classification(random_state=0)
+    estimator = LogisticRegression().fit(X, y)
+    curve_scorer = _CurveScorer(
+        balanced_accuracy_score,
+        sign=1,
+        response_method="predict_proba",
+        thresholds=10,
+        kwargs={},
+    )
+    scores, thresholds = curve_scorer(estimator, X, y)
+
+    assert thresholds.shape == scores.shape
+    # check that the thresholds are probabilities with extreme values close to 0 and 1.
+    # they are not exactly 0 and 1 because they are the extrema of the
+    # `estimator.predict_proba(X)` values.
+    assert 0 <= thresholds.min() <= 0.01
+    assert 0.99 <= thresholds.max() <= 1
+    # balanced accuracy should be between 0.5 and 1 when it is not adjusted
+    assert 0.5 <= scores.min() <= 1
+
+    # check that passing kwargs to the scorer works
+    curve_scorer = _CurveScorer(
+        balanced_accuracy_score,
+        sign=1,
+        response_method="predict_proba",
+        thresholds=10,
+        kwargs={"adjusted": True},
+    )
+    scores, thresholds = curve_scorer(estimator, X, y)
+
+    # balanced accuracy should be between 0 and 0.5 when it is adjusted
+    assert 0 <= scores.min() <= 0.5
+
+    # check that we can invert the sign of the score when dealing with `neg_*` scorers
+    curve_scorer = _CurveScorer(
+        balanced_accuracy_score,
+        sign=-1,
+        response_method="predict_proba",
+        thresholds=10,
+        kwargs={"adjusted": True},
+    )
+    scores, thresholds = curve_scorer(estimator, X, y)
+
+    assert all(scores <= 0)
+
+
+def test_curve_scorer_pos_label(global_random_seed):
+    """Check that we properly propagate the `pos_label` parameter to the scorer."""
+    n_samples = 30
+    X, y = make_classification(
+        n_samples=n_samples, weights=[0.9, 0.1], random_state=global_random_seed
+    )
+    estimator = LogisticRegression().fit(X, y)
+
+    curve_scorer = _CurveScorer(
+        recall_score,
+        sign=1,
+        response_method="predict_proba",
+        thresholds=10,
+        kwargs={"pos_label": 1},
+    )
+    scores_pos_label_1, thresholds_pos_label_1 = curve_scorer(estimator, X, y)
+
+    curve_scorer = _CurveScorer(
+        recall_score,
+        sign=1,
+        response_method="predict_proba",
+        thresholds=10,
+        kwargs={"pos_label": 0},
+    )
+    scores_pos_label_0, thresholds_pos_label_0 = curve_scorer(estimator, X, y)
+
+    # Since `pos_label` is forwarded to the curve_scorer, the thresholds are not equal.
+    assert not (thresholds_pos_label_1 == thresholds_pos_label_0).all()
+    # The min-max range of the thresholds is defined by the probabilities of the
+    # `pos_label` class (the corresponding column of `predict_proba`).
+    y_pred = estimator.predict_proba(X)
+    assert thresholds_pos_label_0.min() == pytest.approx(y_pred.min(axis=0)[0])
+    assert thresholds_pos_label_0.max() == pytest.approx(y_pred.max(axis=0)[0])
+    assert thresholds_pos_label_1.min() == pytest.approx(y_pred.min(axis=0)[1])
+    assert thresholds_pos_label_1.max() == pytest.approx(y_pred.max(axis=0)[1])
+
+    # The recall cannot be negative, and `pos_label=1` should have a higher recall
+    # since there are fewer samples to be considered.
+    assert 0.0 < scores_pos_label_0.min() < scores_pos_label_1.min()
+    assert scores_pos_label_0.max() == pytest.approx(1.0)
+    assert scores_pos_label_1.max() == pytest.approx(1.0)
```
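The tests above always pass an integer `thresholds`; for completeness, a small sketch of the array-like branch of `_score` (again private API, with the same illustrative dataset assumptions as before):

```python
import numpy as np

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics._scorer import _CurveScorer

X, y = make_classification(random_state=0)
estimator = LogisticRegression().fit(X, y)

grid = np.linspace(0.1, 0.9, 5)
curve_scorer = _CurveScorer(
    balanced_accuracy_score,
    sign=1,
    response_method="predict_proba",
    thresholds=grid,  # array-like: used verbatim instead of a linspace
    kwargs={},
)
scores, thresholds = curve_scorer(estimator, X, y)
assert np.allclose(thresholds, grid)
```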

sklearn/model_selection/_classification_threshold.py

Lines changed: 4 additions & 118 deletions
```diff
@@ -18,7 +18,10 @@
     check_scoring,
     get_scorer_names,
 )
-from ..metrics._scorer import _BaseScorer
+from ..metrics._scorer import (
+    _CurveScorer,
+    _threshold_scores_to_class_labels,
+)
 from ..utils import _safe_indexing
 from ..utils._param_validation import HasMethods, Interval, RealNotInt, StrOptions
 from ..utils._response import _get_response_values_binary
```
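After this change, code importing `_CurveScorer` must use the `metrics` path, as this hunk does. A sketch of the effect (both modules are private, so this is an internal detail only):

```python
# New location as of this commit:
from sklearn.metrics._scorer import _CurveScorer

# The old module no longer defines it, so this would now raise ImportError:
# from sklearn.model_selection._classification_threshold import _CurveScorer
```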
```diff
@@ -57,18 +60,6 @@ def check(self):
     return check
 
 
-def _threshold_scores_to_class_labels(y_score, threshold, classes, pos_label):
-    """Threshold `y_score` and return the associated class labels."""
-    if pos_label is None:
-        map_thresholded_score_to_label = np.array([0, 1])
-    else:
-        pos_label_idx = np.flatnonzero(classes == pos_label)[0]
-        neg_label_idx = np.flatnonzero(classes != pos_label)[0]
-        map_thresholded_score_to_label = np.array([neg_label_idx, pos_label_idx])
-
-    return classes[map_thresholded_score_to_label[(y_score >= threshold).astype(int)]]
-
-
 class BaseThresholdClassifier(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
     """Base class for binary classifiers that set a non-default decision threshold.
```
```diff
@@ -429,111 +420,6 @@ def get_metadata_routing(self):
         return router
```

The removed lines are not repeated here: this commit is a pure move, so the deleted body is the same `_CurveScorer` implementation (docstring, `__init__`, `from_scorer`, and `_score`) added to `sklearn/metrics/_scorer.py` above. The hunk ends where the next function begins:

```diff
 def _fit_and_score_over_thresholds(
     classifier,
     X,
```