FIX thresholds should not exceed 1.0 with probabilities in `roc_curve` · scikit-learn/scikit-learn@31c8c75 · GitHub

Commit 31c8c75

glemaitre and ogrisel authored
FIX thresholds should not exceed 1.0 with probabilities in roc_curve (#26194)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 3d0df7b commit 31c8c75

File tree

4 files changed (+31, -7 lines)

doc/modules/model_evaluation.rst

Lines changed: 1 addition & 1 deletion
@@ -1360,7 +1360,7 @@ function::
     >>> tpr
     array([0. , 0.5, 0.5, 1. , 1. ])
     >>> thresholds
-    array([1.8 , 0.8 , 0.4 , 0.35, 0.1 ])
+    array([ inf, 0.8 , 0.4 , 0.35, 0.1 ])
 
 Compared to metrics such as the subset accuracy, the Hamming loss, or the
 F1 score, ROC doesn't require optimizing a threshold for each label.
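
A minimal doctest-style sketch of the updated example (only the outputs appear in the diff above; the `y` and `scores` inputs are assumed from the surrounding section of model_evaluation.rst):

>>> import numpy as np
>>> from sklearn.metrics import roc_curve
>>> y = np.array([1, 1, 2, 2])
>>> scores = np.array([0.1, 0.4, 0.35, 0.8])
>>> fpr, tpr, thresholds = roc_curve(y, scores, pos_label=2)
>>> thresholds  # the first threshold is now inf instead of max(scores) + 1
array([ inf, 0.8 , 0.4 , 0.35, 0.1 ])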

doc/whats_new/v1.3.rst

Lines changed: 5 additions & 0 deletions
@@ -406,6 +406,11 @@ Changelog
 - |API| The `eps` parameter of the :func:`log_loss` has been deprecated and will be
   removed in 1.5. :pr:`25299` by :user:`Omar Salman <OmarManzoor>`.
 
+- |Fix| In :func:`metrics.roc_curve`, use the threshold value `np.inf` instead of the
+  arbitrary `max(y_score) + 1`. This threshold is associated with the ROC curve point
+  `tpr=0` and `fpr=0`.
+  :pr:`26194` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 :mod:`sklearn.model_selection`
 ..............................

sklearn/metrics/_ranking.py

Lines changed: 9 additions & 4 deletions
@@ -1016,10 +1016,10 @@ def roc_curve(
         Increasing true positive rates such that element `i` is the true
         positive rate of predictions with score >= `thresholds[i]`.
 
-    thresholds : ndarray of shape = (n_thresholds,)
+    thresholds : ndarray of shape (n_thresholds,)
         Decreasing thresholds on the decision function used to compute
         fpr and tpr. `thresholds[0]` represents no instances being predicted
-        and is arbitrarily set to `max(y_score) + 1`.
+        and is arbitrarily set to `np.inf`.
 
     See Also
     --------
@@ -1036,6 +1036,10 @@ def roc_curve(
     are reversed upon returning them to ensure they correspond to both ``fpr``
     and ``tpr``, which are sorted in reversed order during their calculation.
 
+    An arbitrary threshold is added for the case `tpr=0` and `fpr=0` to
+    ensure that the curve starts at `(0, 0)`. This threshold corresponds to
+    `np.inf`.
+
     References
     ----------
     .. [1] `Wikipedia entry for the Receiver operating characteristic
@@ -1056,7 +1060,7 @@ def roc_curve(
     >>> tpr
     array([0. , 0.5, 0.5, 1. , 1. ])
     >>> thresholds
-    array([1.8 , 0.8 , 0.4 , 0.35, 0.1 ])
+    array([ inf, 0.8 , 0.4 , 0.35, 0.1 ])
     """
     fps, tps, thresholds = _binary_clf_curve(
         y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
@@ -1083,7 +1087,8 @@ def roc_curve(
     # to make sure that the curve starts at (0, 0)
     tps = np.r_[0, tps]
     fps = np.r_[0, fps]
-    thresholds = np.r_[thresholds[0] + 1, thresholds]
+    # get dtype of `y_score` even if it is an array-like
+    thresholds = np.r_[np.inf, thresholds]
 
     if fps[-1] <= 0:
         warnings.warn(
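
To make the change concrete, here is a small standalone sketch of what the touched block does (the cumulative counts below are made up and merely stand in for the output of `_binary_clf_curve`):

import numpy as np

# hypothetical cumulative true/false positive counts at three distinct score values
tps = np.array([1, 2, 2])
fps = np.array([0, 1, 2])
thresholds = np.array([0.9, 0.7, 0.2])

# prepend the "no instance predicted positive" point so the curve starts at (0, 0);
# its threshold is now np.inf rather than the old thresholds[0] + 1
tps = np.r_[0, tps]
fps = np.r_[0, fps]
thresholds = np.r_[np.inf, thresholds]

fpr = fps / fps[-1]  # array([0. , 0. , 0.5, 1. ])
tpr = tps / tps[-1]  # array([0. , 0.5, 1. , 1. ])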

sklearn/metrics/tests/test_ranking.py

Lines changed: 16 additions & 2 deletions
@@ -418,13 +418,13 @@ def test_roc_curve_drop_intermediate():
     y_true = [0, 0, 0, 0, 1, 1]
     y_score = [0.0, 0.2, 0.5, 0.6, 0.7, 1.0]
     tpr, fpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=True)
-    assert_array_almost_equal(thresholds, [2.0, 1.0, 0.7, 0.0])
+    assert_array_almost_equal(thresholds, [np.inf, 1.0, 0.7, 0.0])
 
     # Test dropping thresholds with repeating scores
     y_true = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
     y_score = [0.0, 0.1, 0.6, 0.6, 0.7, 0.8, 0.9, 0.6, 0.7, 0.8, 0.9, 0.9, 1.0]
     tpr, fpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=True)
-    assert_array_almost_equal(thresholds, [2.0, 1.0, 0.9, 0.7, 0.6, 0.0])
+    assert_array_almost_equal(thresholds, [np.inf, 1.0, 0.9, 0.7, 0.6, 0.0])
 
 
 def test_roc_curve_fpr_tpr_increasing():
@@ -2199,3 +2199,17 @@ def test_ranking_metric_pos_label_types(metric, classes):
     assert not np.isnan(metric_1).any()
     assert not np.isnan(metric_2).any()
     assert not np.isnan(thresholds).any()
+
+
+def test_roc_curve_with_probablity_estimates(global_random_seed):
+    """Check that thresholds do not exceed 1.0 when `y_score` is a probability
+    estimate.
+
+    Non-regression test for:
+    https://github.com/scikit-learn/scikit-learn/issues/26193
+    """
+    rng = np.random.RandomState(global_random_seed)
+    y_true = rng.randint(0, 2, size=10)
+    y_score = rng.rand(10)
+    _, _, thresholds = roc_curve(y_true, y_score)
+    assert np.isinf(thresholds[0])
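
Outside the test suite, the regression scenario can be checked directly; the arrays below are illustrative probability-like scores rather than data from the original report:

import numpy as np
from sklearn.metrics import roc_curve

y_true = np.array([0, 1, 0, 1, 1, 0, 0, 1, 1, 0])
y_score = np.array([0.1, 0.8, 0.3, 0.9, 0.65, 0.2, 0.45, 0.7, 0.55, 0.05])

_, _, thresholds = roc_curve(y_true, y_score)
assert np.isinf(thresholds[0])        # previously this was max(y_score) + 1 = 1.9
assert np.all(thresholds[1:] <= 1.0)  # remaining thresholds are actual scores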
