FEA Adds `decision_threshold_curve` function #31338
@@ -1,6 +1,6 @@
 .. currentmodule:: sklearn.model_selection

-.. _TunedThresholdClassifierCV:
+.. _threshold_tunning:

 ==================================================
 Tuning the decision threshold for class prediction

Review comment: @glemaitre Can't comment where this is relevant (after L49), but I wonder if it would be interesting to add another scenario where threshold tuning may be of interest - imbalanced datasets?
@@ -63,7 +63,7 @@ Post-tuning the decision threshold

 One solution to address the problem stated in the introduction is to tune the decision
 threshold of the classifier once the model has been trained. The
-:class:`~sklearn.model_selection.TunedThresholdClassifierCV` tunes this threshold using
+:class:`TunedThresholdClassifierCV` tunes this threshold using
 an internal cross-validation. The optimum threshold is chosen to maximize a given
 metric.
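For context (not part of the diff), a minimal sketch of the tuning workflow this paragraph describes; the dataset and the choice of scoring metric are illustrative assumptions:

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import TunedThresholdClassifierCV

# Imbalanced toy problem; tune the cut-off on balanced accuracy using the
# default 5-fold stratified internal cross-validation.
X, y = make_classification(n_samples=1_000, weights=[0.9, 0.1], random_state=0)
model = TunedThresholdClassifierCV(
    LogisticRegression(), scoring="balanced_accuracy"
).fit(X, y)
print(model.best_threshold_)  # threshold that maximized the metric during CV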
@@ -80,6 +80,15 @@ a utility metric defined by the business (in this case an insurance company).
    :target: ../auto_examples/model_selection/plot_cost_sensitive_learning.html
    :align: center

+Plotting metric across thresholds
+---------------------------------
+
+The final plot above shows the value of a utility metric of interest across a range
+of threshold values. This can be a useful visualization when tuning the decision
+threshold, especially if there is more than one metric of interest. The
+:func:`decision_threshold_curve` function makes it easy to generate such plots, as it
+computes the values required for each axis: the score at each threshold, and the thresholds.
+
 Options to tune the decision threshold
 --------------------------------------
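A rough sketch of the plot this new section describes, using the function proposed in this PR; the import path follows the docstring example further down (sklearn.metrics) and the data is made up:

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import balanced_accuracy_score, decision_threshold_curve

y_true = np.array([0, 0, 0, 1, 1, 1])
y_score = np.array([0.15, 0.3, 0.6, 0.4, 0.7, 0.9])
# Scores for 50 thresholds evenly spaced between min and max of y_score.
score_thresholds, thresholds = decision_threshold_curve(
    balanced_accuracy_score, y_true, y_score, thresholds=50
)
plt.plot(thresholds, score_thresholds)
plt.xlabel("decision threshold")
plt.ylabel("balanced accuracy")
plt.show()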
@@ -120,7 +129,7 @@ a meaningful metric for their use case.
 Important notes regarding the internal cross-validation
 -------------------------------------------------------

-By default :class:`~sklearn.model_selection.TunedThresholdClassifierCV` uses a 5-fold
+By default :class:`TunedThresholdClassifierCV` uses a 5-fold
 stratified cross-validation to tune the decision threshold. The parameter `cv` allows to
 control the cross-validation strategy. It is possible to bypass cross-validation by
 setting `cv="prefit"` and providing a fitted classifier. In this case, the decision
@@ -143,7 +152,7 @@ Manually setting the decision threshold

 The previous sections discussed strategies to find an optimal decision threshold. It is
 also possible to manually set the decision threshold using the class
-:class:`~sklearn.model_selection.FixedThresholdClassifier`. In case that you don't want
+:class:`FixedThresholdClassifier`. In case that you don't want
 to refit the model when calling `fit`, wrap your sub-estimator with a
 :class:`~sklearn.frozen.FrozenEstimator` and do
 ``FixedThresholdClassifier(FrozenEstimator(estimator), ...)``.
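A minimal sketch (not part of the diff) of the ``FixedThresholdClassifier(FrozenEstimator(estimator), ...)`` pattern quoted above; the classifier and the threshold value are illustrative assumptions:

from sklearn.datasets import make_classification
from sklearn.frozen import FrozenEstimator
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import FixedThresholdClassifier

X, y = make_classification(random_state=0)
clf = LogisticRegression().fit(X, y)
# FrozenEstimator prevents refitting `clf` when the wrapper's fit is called.
fixed = FixedThresholdClassifier(FrozenEstimator(clf), threshold=0.3).fit(X, y)
print(fixed.predict(X[:5]))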
@@ -0,0 +1,4 @@
- :func:`inspection.metric_threshold_curve` has been added to
  assess performance over thresholds by computing a threshold-dependent
  metric of interest per threshold. By
  :user:`Carlo Lemos <vitaliset>` and :user:`Lucy Liu <lucyleeow>`.
@@ -0,0 +1,117 @@
"""Metric per threshold curve to assess binary classification performance. | ||
|
||
Compute metric per threshold, over a range of threshold values to aid visualization | ||
of threshold-dependent metric behavior. | ||
|
||
Utilizes `_CurveScorer` methods to do all the computation. | ||
""" | ||
|
||
# Authors: The scikit-learn developers | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
from numbers import Integral | ||
|
||
from ..utils._param_validation import Interval, validate_params | ||
|
||
|
||
@validate_params(
    {
        "score_func": [callable],
        "y_true": ["array-like"],
        "y_score": ["array-like"],
        "thresholds": [
            Interval(Integral, 2, None, closed="left"),
            "array-like",
        ],
        "greater_is_better": ["boolean"],
        "labels": ["array-like", None],
    },
    prefer_skip_nested_validation=True,
)
def decision_threshold_curve(
    score_func,
    y_true,
    y_score,
    # Should the two parameters below have default values?
    thresholds=20,
    greater_is_better=True,
    labels=None,
    **kwargs,
):
"""Compute threshold-dependent metric of interest per threshold. | ||
|
||
Note: this implementation is restricted to the binary classification task. | ||
|
||
Read more in the :ref:`User Guide <threshold_tunning>`. | ||
|
||
.. versionadded:: 1.8 | ||
|
||
Parameters | ||
---------- | ||
score_func : callable | ||
The score function to use. It will be called as | ||
`score_func(y_true, y_pred, **kwargs)`. | ||
|
||
y_true : array-like of shape (n_samples,) | ||
Ground truth (correct) target labels. | ||
|
||
y_score : array-like of shape (n_samples,) | ||
Continuous response scores. | ||
|
||
thresholds : int or array-like, default=20 | ||
Specifies number of decision thresholds to compute score for. If an integer, | ||
it will be used to generate `thresholds` thresholds uniformly distributed | ||
between the minimum and maximum of `y_score`. If an array-like, it will be | ||
used as the thresholds. | ||
|
||
greater_is_better : bool, default=True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, this is consistent, with term used in other metrics, so avoided using "sign". |
||
Whether `score_func` is a score function (default), meaning high is | ||
good, or a loss function, meaning low is good. In the latter case, the | ||
the output of `score_func` will be sign-flipped. | ||
|
||
labels : array-like, default=None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
Class labels. If `None`, inferred from `y_true`. | ||
TODO: used `labels` instead of `classes` to be consistent with other metrics. | ||
|
||
**kwargs : dict | ||
Parameters to pass to `score_func`. | ||
|
||
    Returns
    -------
    score_thresholds : ndarray of shape (n_thresholds,)
        The scores associated with each threshold.

    thresholds : ndarray of shape (n_thresholds,)
        The thresholds used to compute the scores.

    See Also
    --------
    precision_recall_curve : Compute precision-recall pairs for different
        probability thresholds.
    det_curve : Compute error rates for different probability thresholds.
    roc_curve : Compute Receiver operating characteristic (ROC) curve.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.metrics import accuracy_score, decision_threshold_curve
    >>> y_true = np.array([0, 0, 1, 1])
    >>> y_score = np.array([0.1, 0.4, 0.35, 0.8])
    >>> score_thresholds, thresholds = decision_threshold_curve(
    ...     accuracy_score, y_true, y_score, thresholds=4)
    >>> thresholds
    array([0.1, 0.33333333, 0.56666667, 0.8 ])
    >>> score_thresholds
    array([0.5, 0.75, 0.75, 0.75])
""" | ||
# To prevent circular import | ||
from ._scorer import _CurveScorer | ||
|
||
sign = 1 if greater_is_better else -1 | ||
curve_scorer = _CurveScorer(score_func, sign, {}, thresholds) | ||
return curve_scorer._scores_from_predictions( | ||
y_true, | ||
y_score, | ||
labels, | ||
**kwargs, | ||
) |
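A usage sketch of the `**kwargs` plumbing above: extra keyword arguments are merged into `self._kwargs` and forwarded to `score_func` at every threshold. The import path mirrors the docstring example and the data is made up:

import numpy as np
from sklearn.metrics import decision_threshold_curve, fbeta_score

y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])
# `beta=2` is passed through **kwargs to fbeta_score for each threshold.
score_thresholds, thresholds = decision_threshold_curve(
    fbeta_score, y_true, y_score, thresholds=5, beta=2
)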
@@ -30,6 +30,7 @@
 from ..utils import Bunch
 from ..utils._param_validation import HasMethods, Hidden, StrOptions, validate_params
 from ..utils._response import _get_response_values
+from ..utils._unique import cached_unique
 from ..utils.metadata_routing import (
     MetadataRequest,
     MetadataRouter,
@@ -1071,23 +1072,24 @@
         `score_func(y_true, y_pred, **kwargs)`.

     sign : int
-        Either 1 or -1 to returns the score with `sign * score_func(estimator, X, y)`.
-        Thus, `sign` defined if higher scores are better or worse.
+        Either 1 or -1. Score is returned as `sign * score_func(estimator, X, y)`.
+        Thus, `sign` defines whether higher scores are better or worse.

     kwargs : dict
         Additional parameters to pass to the score function.

     thresholds : int or array-like
-        Related to the number of decision thresholds for which we want to compute the
-        score. If an integer, it will be used to generate `thresholds` thresholds
-        uniformly distributed between the minimum and maximum predicted scores. If an
-        array-like, it will be used as the thresholds.
+        Specifies the number of decision thresholds to compute the score for. If an
+        integer, it will be used to generate `thresholds` thresholds uniformly
+        distributed between the minimum and maximum of `y_score`. If an array-like,
+        it will be used as the thresholds.

-    response_method : str
+    response_method : str, default=None
         The method to call on the estimator to get the response values.
+        If set to `None`, the `_scores_from_estimator` method cannot be used.
     """

-    def __init__(self, score_func, sign, kwargs, thresholds, response_method):
+    def __init__(self, score_func, sign, kwargs, thresholds, response_method=None):
         super().__init__(
             score_func=score_func,
             sign=sign,
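An illustrative sketch of the behavior the new `response_method=None` default enables (private API; the module path and call pattern are assumed from this diff, not confirmed): a `_CurveScorer` built without a response method can still score pre-computed predictions via `_scores_from_predictions`, while `_score`, which needs to call the estimator, now raises a `ValueError`:

import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics._scorer import _CurveScorer  # private module, path assumed

# No response_method: only scoring of pre-computed y_score is supported.
scorer = _CurveScorer(accuracy_score, sign=1, kwargs={}, thresholds=5)
y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.2, 0.4, 0.6, 0.8])
score_thresholds, thresholds = scorer._scores_from_predictions(y_true, y_score)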
@@ -1110,8 +1112,75 @@
         instance._metadata_request = scorer._get_metadata_request()
         return instance

+    def _scores_from_predictions(
+        self,
+        y_true,
+        y_score,
+        classes=None,
+        **kwargs,
+    ):
+        """Computes scores per threshold, given continuous response and true labels.
+
+        Parameters
+        ----------
+        y_true : array-like of shape (n_samples,)
+            Ground truth (correct) target labels.
+
+        y_score : array-like of shape (n_samples,)
+            Continuous response scores.
+
+        classes : array-like, default=None
+            Class labels. If `None`, inferred from `y_true`.
+
+        **kwargs : dict
+            Parameters to pass to `self.score_func`.
+
+        Returns
+        -------
+        score_thresholds : ndarray of shape (thresholds,)
+            The scores associated with each threshold.
+
+        thresholds : ndarray of shape (thresholds,)
+            The thresholds used to compute the scores.
+        """
+        # This could also be done in `decision_threshold_curve`, not sure which
+        # is better
+        y_true_unique = cached_unique(y_true)
+        if classes is None:
+            classes = y_true_unique
+        # not sure if this separate error msg needed.
+        # there is the possibility that set(classes) != set(y_true_unique) fails
+        # because `y_true` only contains one class.
+        if len(y_true_unique) == 1:
+            raise ValueError("`y_true` only contains one class label.")
+        if set(classes) != set(y_true_unique):
+            raise ValueError(
+                f"`classes` ({classes}) is not equal to the unique values found in "
+                f"`y_true` ({y_true_unique})."
+            )
Review comment on lines +1146 to +1160: These checks could be done in
+        if isinstance(self._thresholds, Integral):
+            potential_thresholds = np.linspace(
+                np.min(y_score), np.max(y_score), self._thresholds
+            )
+        else:
+            potential_thresholds = np.asarray(self._thresholds)
+
+        score_thresholds = [
+            self._sign
+            * self._score_func(
+                y_true,
+                _threshold_scores_to_class_labels(
+                    y_score, th, classes, self._get_pos_label()
+                ),
+                **{**self._kwargs, **kwargs},
+            )
+            for th in potential_thresholds
+        ]
+        return np.array(score_thresholds), potential_thresholds
+
     def _score(self, method_caller, estimator, X, y_true, **kwargs):
-        """Evaluate predicted target values for X relative to y_true.
+        """Computes scores per threshold, given estimator, X and true labels.

         Parameters
         ----------
@@ -1140,27 +1209,21 @@
         potential_thresholds : ndarray of shape (thresholds,)
             The potential thresholds used to compute the scores.
         """
-        pos_label = self._get_pos_label()
+        if self._response_method is None:
+            raise ValueError(
+                "This method cannot be used when `_CurveScorer` initialized with "
+                "`response_method=None`"
+            )
+
         y_score = method_caller(
-            estimator, self._response_method, X, pos_label=pos_label
+            estimator, self._response_method, X, pos_label=self._get_pos_label()
         )

-        scoring_kwargs = {**self._kwargs, **kwargs}
-        if isinstance(self._thresholds, Integral):
-            potential_thresholds = np.linspace(
-                np.min(y_score), np.max(y_score), self._thresholds
-            )
-        else:
-            potential_thresholds = np.asarray(self._thresholds)
-        score_thresholds = [
-            self._sign
-            * self._score_func(
-                y_true,
-                _threshold_scores_to_class_labels(
-                    y_score, th, estimator.classes_, pos_label
-                ),
-                **scoring_kwargs,
-            )
-            for th in potential_thresholds
-        ]
-        return np.array(score_thresholds), potential_thresholds
# why 'potential' ?
Review comment: Just for my education, why use the term "potential" in "potential_thresholds"? Is it because there is a possibility that a threshold is redundant because the predicted labels are the same for adjacent thresholds?

+        score_thresholds, potential_thresholds = self._scores_from_predictions(
+            y_true,
+            y_score,
+            estimator.classes_,
+            **kwargs,
+        )
+        return score_thresholds, potential_thresholds
Review comment: I referenced this page starting here for `decision_threshold_curve` because I thought this first paragraph was appropriate, for context, not just the section I added. Not 100% on this though and happy to change.