FEA Adds `decision_threshold_curve` function by lucyleeow · Pull Request #31338 · scikit-learn/scikit-learn · GitHub
FEA Adds decision_threshold_curve function #31338

Open · wants to merge 53 commits into base: main

Commits (53)
172ac47
initial proposal with preliminary tests
vitaliset Feb 18, 2023
d038e11
removing check that validate_params already does
vitaliset Feb 18, 2023
322eccf
changelog and linting from CI
vitaliset Feb 18, 2023
7dbbec5
trying to resolve doc related ci
vitaliset Feb 18, 2023
2a0c6b3
duplicate label
vitaliset Feb 18, 2023
fbb9b9b
docstring example import error
vitaliset Feb 18, 2023
acb94be
docstring typo
vitaliset Feb 18, 2023
a5cd201
docstring typo
vitaliset Feb 18, 2023
253b3e2
docstring typo
vitaliset Feb 18, 2023
cb5fee1
docstring typo
vitaliset Feb 18, 2023
9e45e2e
change in doc order and typos
vitaliset Feb 18, 2023
ad901a2
removing example
vitaliset Feb 20, 2023
1a4ce1b
Merge branch 'main' into metric_threshold_curve
vitaliset May 14, 2023
9b4febb
Update import of _check_pos_label_consistency
vitaliset May 14, 2023
119db53
codecov
vitaliset May 14, 2023
347f524
Merge branch 'metric_threshold_curve' of https://github.com/vitaliset…
vitaliset May 14, 2023
be893c8
linting
vitaliset May 14, 2023
bd1e64f
correcting typo
vitaliset May 14, 2023
0318950
test typo
vitaliset May 14, 2023
efd6d72
add example again to check pytest
vitaliset May 14, 2023
1e500c0
Merge branch 'main' into metric_threshold_curve
vitaliset May 16, 2023
10ebc90
Merge remote-tracking branch 'origin/main' into pr/vitaliset/25639
glemaitre May 20, 2024
dfa66a5
fixing imports
glemaitre May 20, 2024
1fb1c13
towards glemaitre suggestions
vitaliset May 22, 2024
e7bb2a7
applying black suggestions
vitaliset Jun 8, 2024
5a8f0c5
update extra stuff for consistency
vitaliset Jun 8, 2024
4fab2a3
removing doc files for now as we need to adapt to pr 29038
vitaliset Jun 8, 2024
48a0055
Merge branch 'main' into metric_threshold_curve
vitaliset Jun 8, 2024
fbf1d2e
Merge branch 'main' into metric_threshold_curve
vitaliset Jul 25, 2024
98873e6
Merge branch 'main' into metric_threshold_curve
vitaliset Jul 30, 2024
f1dc0e8
Update _decision_threshold.py to add authors
vitaliset Jul 30, 2024
0284251
towards using _curvescorer in the new decision threshold function. mi…
vitaliset Jul 30, 2024
d46bc1a
correcting circular dependences
vitaliset Jul 30, 2024
0a06199
Merge branch 'main' into metric_threshold_curve
vitaliset Aug 23, 2024
a424c3e
trying to solve the circular imports. looks like the order of init is…
vitaliset Sep 30, 2024
09df5ae
merge main
lucyleeow May 5, 2025
bd256a8
first commit, original tests pass
lucyleeow May 8, 2025
a386ded
min test to check func runs
lucyleeow May 8, 2025
eae5846
nits
lucyleeow May 8, 2025
5e7fd49
amend whats new;
lucyleeow May 8, 2025
d23a25f
Merge branch 'main' into metric_threshold_curve
lucyleeow May 8, 2025
8efaaca
revert from scorer order
lucyleeow May 8, 2025
6e0b5e0
amend to method
lucyleeow May 8, 2025
0ac8d1d
fix param valid, use greater_is_better
lucyleeow May 8, 2025
8a8e240
fix example
lucyleeow May 12, 2025
6715d6d
Merge branch 'main' into metric_threshold_curve
lucyleeow May 12, 2025
bf887aa
fix examples
lucyleeow May 12, 2025
edf99f0
rm pos label as param
lucyleeow May 14, 2025
9c624c5
typo
lucyleeow May 14, 2025
7643942
pos label fixes
lucyleeow May 14, 2025
13c8545
fix kwargs
lucyleeow May 15, 2025
b589505
add user guide section
lucyleeow May 15, 2025
577ea24
change ref label
lucyleeow May 15, 2025
17 changes: 13 additions & 4 deletions doc/modules/classification_threshold.rst
@@ -1,6 +1,6 @@
.. currentmodule:: sklearn.model_selection

.. _TunedThresholdClassifierCV:
.. _threshold_tunning:
lucyleeow (Member Author) commented:

I referenced this page starting here for `decision_threshold_curve` because I thought this first paragraph was appropriate for context, not just the section I added. Not 100% sure on this, though, and happy to change.


==================================================
Tuning the decision threshold for class prediction
lucyleeow (Member Author) commented on May 15, 2025:

@glemaitre I can't comment where this is relevant (after L49), but I wonder if it would be interesting to add another scenario where threshold tuning may be of interest: imbalanced datasets?

@@ -63,7 +63,7 @@ Post-tuning the decision threshold

One solution to address the problem stated in the introduction is to tune the decision
threshold of the classifier once the model has been trained. The
:class:`~sklearn.model_selection.TunedThresholdClassifierCV` tunes this threshold using
:class:`TunedThresholdClassifierCV` tunes this threshold using
an internal cross-validation. The optimum threshold is chosen to maximize a given
metric.

@@ -80,6 +80,15 @@ a utility metric defined by the business (in this case an insurance company).
:target: ../auto_examples/model_selection/plot_cost_sensitive_learning.html
:align: center

Plotting metric across thresholds
---------------------------------

The final plot above shows the value of a utility metric of interest across a range
of threshold values. This can be a useful visualization when tuning the decision
threshold, especially if there is more than one metric of interest. The
:func:`decision_threshold_curve` function allows you to easily generate such plots, as
it computes the values required for each axis: the score at each threshold and the
threshold values themselves.
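For illustration, here is a minimal sketch of how such a plot could be generated.
The function name, argument order and return values follow the version added in
this pull request; the plotting calls are plain matplotlib::

    import matplotlib.pyplot as plt
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import balanced_accuracy_score, decision_threshold_curve
    from sklearn.model_selection import train_test_split

    X, y = make_classification(random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    model = LogisticRegression().fit(X_train, y_train)
    y_score = model.predict_proba(X_test)[:, 1]

    # Scores per threshold for one axis, threshold values for the other.
    scores, thresholds = decision_threshold_curve(
        balanced_accuracy_score, y_test, y_score, thresholds=25
    )
    plt.plot(thresholds, scores)
    plt.xlabel("Decision threshold")
    plt.ylabel("Balanced accuracy")
    plt.show()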

Options to tune the decision threshold
--------------------------------------

@@ -120,7 +129,7 @@ a meaningful metric for their use case.
Important notes regarding the internal cross-validation
-------------------------------------------------------

By default :class:`~sklearn.model_selection.TunedThresholdClassifierCV` uses a 5-fold
By default :class:`TunedThresholdClassifierCV` uses a 5-fold
stratified cross-validation to tune the decision threshold. The parameter `cv` can be
used to control the cross-validation strategy. It is possible to bypass cross-validation by
setting `cv="prefit"` and providing a fitted classifier. In this case, the decision
@@ -143,7 +152,7 @@ Manually setting the decision threshold

The previous sections discussed strategies to find an optimal decision threshold. It is
also possible to manually set the decision threshold using the class
:class:`~sklearn.model_selection.FixedThresholdClassifier`. In case that you don't want
:class:`FixedThresholdClassifier`. In case that you don't want
to refit the model when calling `fit`, wrap your sub-estimator with a
:class:`~sklearn.frozen.FrozenEstimator` and do
``FixedThresholdClassifier(FrozenEstimator(estimator), ...)``.
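A short sketch of this pattern, assuming a scikit-learn version that provides
:class:`~sklearn.frozen.FrozenEstimator` (1.6 or later)::

    from sklearn.datasets import make_classification
    from sklearn.frozen import FrozenEstimator
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import FixedThresholdClassifier

    X, y = make_classification(random_state=0)
    classifier = LogisticRegression().fit(X, y)

    # `fit` does not refit the frozen sub-estimator; only the decision
    # threshold used by `predict` changes.
    low_threshold_classifier = FixedThresholdClassifier(
        FrozenEstimator(classifier), threshold=0.25
    ).fit(X, y)
    predictions = low_threshold_classifier.predict(X)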
2 changes: 1 addition & 1 deletion doc/modules/model_evaluation.rst
@@ -63,7 +63,7 @@ The most common decisions are done on binary classification tasks, where the res
probability of rain a decision is made on how to act (whether to take mitigating
measures like an umbrella or not).
For classifiers, this is what :term:`predict` returns.
See also :ref:`TunedThresholdClassifierCV`.
See also :ref:`threshold_tunning`.
There are many scoring functions which measure different aspects of such a
decision; most of them are covered by or derived from the
:func:`metrics.confusion_matrix`.
@@ -0,0 +1,4 @@
- :func:`metrics.decision_threshold_curve` has been added to
  assess performance across decision thresholds by computing a
  threshold-dependent metric of interest at each threshold. By
  :user:`Carlo Lemos <vitaliset>` and :user:`Lucy Liu <lucyleeow>`.
2 changes: 2 additions & 0 deletions sklearn/metrics/__init__.py
@@ -26,6 +26,7 @@
recall_score,
zero_one_loss,
)
from ._decision_threshold import decision_threshold_curve
from ._dist_metrics import DistanceMetric
from ._plot.confusion_matrix import ConfusionMatrixDisplay
from ._plot.det_curve import DetCurveDisplay
@@ -124,6 +125,7 @@
"d2_tweedie_score",
"davies_bouldin_score",
"dcg_score",
"decision_threshold_curve",
"det_curve",
"euclidean_distances",
"explained_variance_score",
117 changes: 117 additions & 0 deletions sklearn/metrics/_decision_threshold.py
@@ -0,0 +1,117 @@
"""Metric per threshold curve to assess binary classification performance.

Compute the metric per threshold over a range of threshold values, to aid
visualization of threshold-dependent metric behavior.

Utilizes `_CurveScorer` methods to do all the computation.
"""

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

from numbers import Integral

from ..utils._param_validation import Interval, validate_params


@validate_params(
{
"score_func": [callable],
"y_true": ["array-like"],
"y_score": ["array-like"],
"thresholds": [
Interval(Integral, 2, None, closed="left"),
"array-like",
],
"greater_is_better": ["boolean"],
"labels": ["array-like", None],
},
prefer_skip_nested_validation=True,
)
def decision_threshold_curve(
score_func,
y_true,
y_score,
# Should the two parameters below have default values?
thresholds=20,
greater_is_better=True,
labels=None,
**kwargs,
):
"""Compute threshold-dependent metric of interest per threshold.

Note: this implementation is restricted to the binary classification task.

Read more in the :ref:`User Guide <threshold_tunning>`.

.. versionadded:: 1.8

Parameters
----------
score_func : callable
The score function to use. It will be called as
`score_func(y_true, y_pred, **kwargs)`.

y_true : array-like of shape (n_samples,)
Ground truth (correct) target labels.

y_score : array-like of shape (n_samples,)
Continuous response scores.

thresholds : int or array-like, default=20
Specifies the number of decision thresholds to compute the score for. If an
integer, it will be used to generate `thresholds` thresholds uniformly
distributed between the minimum and maximum of `y_score`. If an array-like, it
will be used as the thresholds.

greater_is_better : bool, default=True
lucyleeow (Member Author) commented:

Again, this is consistent with the term used in other metrics, so I avoided using "sign".

Whether `score_func` is a score function (default), meaning high is
good, or a loss function, meaning low is good. In the latter case, the
output of `score_func` will be sign-flipped.

labels : array-like, default=None
lucyleeow (Member Author) commented:

`_CurveScorer` uses the term "classes", but "labels" is consistent with what is used for other classification metrics, so I chose this.

Class labels. If `None`, inferred from `y_true`.
TODO: used `labels` instead of `classes` to be consistent with other metrics.

**kwargs : dict
Parameters to pass to `score_func`.

Returns
-------
score_thresholds : ndarray of shape (n_thresholds,)
The scores associated with each threshold.

thresholds : ndarray of shape (n_thresholds,)
The thresholds used to compute the scores.

See Also
--------
precision_recall_curve : Compute precision-recall pairs for different
probability thresholds.
det_curve : Compute error rates for different probability thresholds.
roc_curve : Compute Receiver operating characteristic (ROC) curve.

Examples
--------
>>> import numpy as np
>>> from sklearn.metrics import accuracy_score, decision_threshold_curve
>>> y_true = np.array([0, 0, 1, 1])
>>> y_score = np.array([0.1, 0.4, 0.35, 0.8])
>>> score_thresholds, thresholds = decision_threshold_curve(
... accuracy_score, y_true, y_score, thresholds=4)
>>> thresholds
array([0.1, 0.33333333, 0.56666667, 0.8 ])
>>> score_thresholds
array([0.5, 0.75, 0.75, 0.75])
"""
# To prevent circular import
from ._scorer import _CurveScorer

sign = 1 if greater_is_better else -1
curve_scorer = _CurveScorer(score_func, sign, {}, thresholds)
return curve_scorer._scores_from_predictions(
y_true,
y_score,
labels,
**kwargs,
)
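To make the docstring example concrete, the same numbers can be re-derived with
plain NumPy, assuming each threshold is applied as `y_score >= threshold` (a sketch
for illustration, not part of the diff):

import numpy as np
from sklearn.metrics import accuracy_score

y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

# Four thresholds evenly spaced between the min and max of y_score.
thresholds = np.linspace(y_score.min(), y_score.max(), 4)
scores = [accuracy_score(y_true, (y_score >= th).astype(int)) for th in thresholds]
# thresholds -> [0.1, 0.333..., 0.566..., 0.8]
# scores     -> [0.5, 0.75, 0.75, 0.75], matching the docstring output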
123 changes: 93 additions & 30 deletions sklearn/metrics/_scorer.py
@@ -30,6 +30,7 @@
from ..utils import Bunch
from ..utils._param_validation import HasMethods, Hidden, StrOptions, validate_params
from ..utils._response import _get_response_values
from ..utils._unique import cached_unique
from ..utils.metadata_routing import (
MetadataRequest,
MetadataRouter,
@@ -1071,23 +1072,24 @@
`score_func(y_true, y_pred, **kwargs)`.

sign : int
Either 1 or -1 to returns the score with `sign * score_func(estimator, X, y)`.
Thus, `sign` defined if higher scores are better or worse.
Either 1 or -1. Score is returned as `sign * score_func(estimator, X, y)`.
Thus, `sign` defines whether higher scores are better or worse.

kwargs : dict
Additional parameters to pass to the score function.

thresholds : int or array-like
Related to the number of decision thresholds for which we want to compute the
score. If an integer, it will be used to generate `thresholds` thresholds
uniformly distributed between the minimum and maximum predicted scores. If an
array-like, it will be used as the thresholds.
Specifies the number of decision thresholds to compute the score for. If an
integer, it will be used to generate `thresholds` thresholds uniformly
distributed between the minimum and maximum of `y_score`. If an array-like, it
will be used as the thresholds.

response_method : str
response_method : str, default=None
The method to call on the estimator to get the response values.
If set to `None`, the `_score` method (which requires an estimator) cannot be used.
"""

def __init__(self, score_func, sign, kwargs, thresholds, response_method):
def __init__(self, score_func, sign, kwargs, thresholds, response_method=None):
super().__init__(
score_func=score_func,
sign=sign,
@@ -1110,8 +1112,75 @@
instance._metadata_request = scorer._get_metadata_request()
return instance

def _scores_from_predictions(
self,
y_true,
y_score,
classes=None,
**kwargs,
):
"""Computes scores per threshold, given continuous response and true labels.

Parameters
----------
y_true : array-like of shape (n_samples,)
Ground truth (correct) target labels.

y_score : array-like of shape (n_samples,)
Continuous response scores.

classes : array-like, default=None
Class labels. If `None`, inferred from `y_true`.

**kwargs : dict
Parameters to pass to `self.score_func`.

Returns
-------
score_thresholds : ndarray of shape (thresholds,)
The scores associated with each threshold.

thresholds : ndarray of shape (thresholds,)
The thresholds used to compute the scores.
"""
# This could also be done in `decision_threshold_curve`, not sure which
# is better
y_true_unique = cached_unique(y_true)
if classes is None:
classes = y_true_unique
# Not sure if this separate error message is needed; there is the possibility
# that the `set(classes) != set(y_true_unique)` check fails only because
# `y_true` contains a single class.
if len(y_true_unique) == 1:
raise ValueError("`y_true` only contains one class label.")

if set(classes) != set(y_true_unique):
raise ValueError(
f"`classes` ({classes}) is not equal to the unique values found in "
f"`y_true` ({y_true_unique})."
)
lucyleeow (Member Author) commented on lines +1146 to +1160:

These checks could be done in `decision_threshold_curve` instead; not sure which is better.


if isinstance(self._thresholds, Integral):
potential_thresholds = np.linspace(
np.min(y_score), np.max(y_score), self._thresholds
)
else:
potential_thresholds = np.asarray(self._thresholds)

score_thresholds = [
self._sign
* self._score_func(
y_true,
_threshold_scores_to_class_labels(
y_score, th, classes, self._get_pos_label()
),
**{**self._kwargs, **kwargs},
)
for th in potential_thresholds
]
return np.array(score_thresholds), potential_thresholds
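For reference, a minimal sketch of calling this private helper directly, which is
essentially what `decision_threshold_curve` does under the hood (this relies on the
`_CurveScorer` internals introduced in this diff and is for illustration only):

import numpy as np
from sklearn.metrics import recall_score
from sklearn.metrics._scorer import _CurveScorer

y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

# sign=1 because higher recall is better; evaluate three evenly spaced cut-offs.
curve_scorer = _CurveScorer(recall_score, 1, {}, 3)
scores, thresholds = curve_scorer._scores_from_predictions(y_true, y_score)
# thresholds -> [0.1, 0.45, 0.8]; scores are the recall at each cut-off
# (about [1.0, 0.5, 0.5] under these assumptions).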

def _score(self, method_caller, estimator, X, y_true, **kwargs):
"""Evaluate predicted target values for X relative to y_true.
"""Computes scores per threshold, given estimator, X and true labels.

Parameters
----------
@@ -1140,27 +1209,21 @@
potential_thresholds : ndarray of shape (thresholds,)
The potential thresholds used to compute the scores.
"""
pos_label = self._get_pos_label()
if self._response_method is None:
raise ValueError(
"This method cannot be used when `_CurveScorer` initialized with "
"`response_method=None`"
)

y_score = method_caller(
estimator, self._response_method, X, pos_label=pos_label
estimator, self._response_method, X, pos_label=self._get_pos_label()
)

scoring_kwargs = {**self._kwargs, **kwargs}
if isinstance(self._thresholds, Integral):
potential_thresholds = np.linspace(
np.min(y_score), np.max(y_score), self._thresholds
)
else:
potential_thresholds = np.asarray(self._thresholds)
score_thresholds = [
self._sign
* self._score_func(
y_true,
_threshold_scores_to_class_labels(
y_score, th, estimator.classes_, pos_label
),
**scoring_kwargs,
)
for th in potential_thresholds
]
return np.array(score_thresholds), potential_thresholds
lucyleeow (Member Author) commented:

Just for my education, why use the term "potential" in "potential_thresholds"? Is it because there is a possibility that a threshold is redundant, because the predicted labels are the same for adjacent thresholds?

score_thresholds, potential_thresholds = self._scores_from_predictions(
y_true,
y_score,
estimator.classes_,
**kwargs,
)
return score_thresholds, potential_thresholds