[MRG] Ensure that ROC curve starts at (0, 0) (#10093) · jwjohnson314/scikit-learn@3ed115b

Commit 3ed115b

qinhanmin2014 authored and Jeremiah Johnson committed
[MRG] Ensure that ROC curve starts at (0, 0) (scikit-learn#10093)
1 parent ce35f2b commit 3ed115b

File tree

4 files changed, +28 -16 lines

doc/modules/model_evaluation.rst

Lines changed: 3 additions & 3 deletions
@@ -1137,11 +1137,11 @@ Here is a small example of how to use the :func:`roc_curve` function::
     >>> scores = np.array([0.1, 0.4, 0.35, 0.8])
     >>> fpr, tpr, thresholds = roc_curve(y, scores, pos_label=2)
     >>> fpr
-    array([ 0. , 0.5, 0.5, 1. ])
+    array([ 0. , 0. , 0.5, 0.5, 1. ])
     >>> tpr
-    array([ 0.5, 0.5, 1. , 1. ])
+    array([ 0. , 0.5, 0.5, 1. , 1. ])
     >>> thresholds
-    array([ 0.8 , 0.4 , 0.35, 0.1 ])
+    array([ 1.8 , 0.8 , 0.4 , 0.35, 0.1 ])

 This figure shows an example of such an ROC curve:
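For reference, the updated doctest above can be reproduced directly. This is a small illustrative snippet, not part of the commit, and it assumes a scikit-learn build that contains this fix (earlier releases omit the leading point):

    import numpy as np
    from sklearn.metrics import roc_curve

    y = np.array([1, 1, 2, 2])
    scores = np.array([0.1, 0.4, 0.35, 0.8])
    fpr, tpr, thresholds = roc_curve(y, scores, pos_label=2)

    # The curve now starts at (0, 0); the extra first threshold is
    # max(scores) + 1, so that no sample is predicted positive there.
    print(fpr)         # [0.   0.   0.5  0.5  1.  ]
    print(tpr)         # [0.   0.5  0.5  1.   1.  ]
    print(thresholds)  # [1.8  0.8  0.4  0.35 0.1 ]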

doc/whats_new/v0.20.rst

Lines changed: 7 additions & 0 deletions
@@ -18,6 +18,7 @@ random sampling procedures.
 - :class:`decomposition.IncrementalPCA` in Python 2 (bug fix)
 - :class:`isotonic.IsotonicRegression` (bug fix)
 - :class:`metrics.roc_auc_score` (bug fix)
+- :class:`metrics.roc_curve` (bug fix)
 - :class:`neural_network.BaseMultilayerPerceptron` (bug fix)
 - :class:`neural_network.MLPRegressor` (bug fix)
 - :class:`neural_network.MLPClassifier` (bug fix)

@@ -160,6 +161,12 @@ Metrics
 - Fixed a bug due to floating point error in :func:`metrics.roc_auc_score` with
   non-integer sample weights. :issue:`9786` by :user:`Hanmin Qin <qinhanmin2014>`.

+- Fixed a bug where :func:`metrics.roc_curve` sometimes started on the y-axis
+  instead of at (0, 0), which was inconsistent with the documentation and with
+  other implementations. Note that this does not affect the result of
+  :func:`metrics.roc_auc_score`. :issue:`10093` by :user:`alexryndin <alexryndin>`
+  and :user:`Hanmin Qin <qinhanmin2014>`.
+
 API changes summary
 -------------------
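The note that the fix does not change :func:`metrics.roc_auc_score` follows from the geometry: the prepended point only adds a segment at fpr = 0, which has zero width and so contributes no area under trapezoidal integration. A quick sanity check, offered as a sketch rather than as part of the commit:

    import numpy as np
    from sklearn.metrics import auc, roc_auc_score, roc_curve

    y_true = np.array([1, 1, 2, 2])
    scores = np.array([0.1, 0.4, 0.35, 0.8])

    fpr, tpr, _ = roc_curve(y_true, scores, pos_label=2)
    # Area under the now-complete curve still equals the direct AUC computation.
    assert np.isclose(auc(fpr, tpr), roc_auc_score(y_true == 2, scores))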

sklearn/metrics/ranking.py

Lines changed: 10 additions & 5 deletions
@@ -217,7 +217,6 @@ def _binary_uninterpolated_average_precision(
                                      sample_weight=sample_weight)


-
 def roc_auc_score(y_true, y_score, average="macro", sample_weight=None):
     """Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC)
     from prediction scores.

@@ -267,6 +266,9 @@ def roc_auc_score(y_true, y_score, average="macro", sample_weight=None):
     .. [1] `Wikipedia entry for the Receiver operating characteristic
            <https://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_

+    .. [2] Fawcett T. An introduction to ROC analysis[J]. Pattern Recognition
+           Letters, 2006, 27(8):861-874.
+
     See also
     --------
     average_precision_score : Area under the precision-recall curve

@@ -541,6 +543,8 @@ def roc_curve(y_true, y_score, pos_label=None, sample_weight=None,
     .. [1] `Wikipedia entry for the Receiver operating characteristic
            <https://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_

+    .. [2] Fawcett T. An introduction to ROC analysis[J]. Pattern Recognition
+           Letters, 2006, 27(8):861-874.

     Examples
     --------

@@ -550,11 +554,11 @@ def roc_curve(y_true, y_score, pos_label=None, sample_weight=None,
     >>> scores = np.array([0.1, 0.4, 0.35, 0.8])
     >>> fpr, tpr, thresholds = metrics.roc_curve(y, scores, pos_label=2)
     >>> fpr
-    array([ 0. , 0.5, 0.5, 1. ])
+    array([ 0. , 0. , 0.5, 0.5, 1. ])
     >>> tpr
-    array([ 0.5, 0.5, 1. , 1. ])
+    array([ 0. , 0.5, 0.5, 1. , 1. ])
     >>> thresholds
-    array([ 0.8 , 0.4 , 0.35, 0.1 ])
+    array([ 1.8 , 0.8 , 0.4 , 0.35, 0.1 ])

     """
     fps, tps, thresholds = _binary_clf_curve(

@@ -578,8 +582,9 @@ def roc_curve(y_true, y_score, pos_label=None, sample_weight=None,
         tps = tps[optimal_idxs]
         thresholds = thresholds[optimal_idxs]

-    if tps.size == 0 or fps[0] != 0:
+    if tps.size == 0 or fps[0] != 0 or tps[0] != 0:
         # Add an extra threshold position if necessary
+        # to make sure that the curve starts at (0, 0)
         tps = np.r_[0, tps]
         fps = np.r_[0, fps]
         thresholds = np.r_[thresholds[0] + 1, thresholds]
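The essential change is the extra "or tps[0] != 0" condition: padding is now applied whenever the curve would otherwise start above the origin, not only when it would start to the right of it. Below is a minimal sketch of that padding step on plain NumPy arrays; the values of fps, tps and thresholds are made up for illustration, standing in for the cumulative counts returned by _binary_clf_curve:

    import numpy as np

    # Hypothetical cumulative false/true positive counts per threshold,
    # shaped like the output of _binary_clf_curve (scores in decreasing order).
    fps = np.array([0, 1, 2])
    tps = np.array([1, 2, 2])        # tps[0] != 0: curve would begin on the y-axis
    thresholds = np.array([0.8, 0.4, 0.1])

    if tps.size == 0 or fps[0] != 0 or tps[0] != 0:
        # Add an extra threshold position to make sure the curve starts at (0, 0).
        tps = np.r_[0, tps]
        fps = np.r_[0, fps]
        thresholds = np.r_[thresholds[0] + 1, thresholds]

    fpr = fps / fps[-1]  # [0.   0.   0.5  1. ]
    tpr = tps / tps[-1]  # [0.   0.5  1.   1. ]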

sklearn/metrics/tests/test_ranking.py

Lines changed: 8 additions & 8 deletions
@@ -270,8 +270,8 @@ def test_roc_curve_toydata():
     y_score = [0, 1]
     tpr, fpr, _ = roc_curve(y_true, y_score)
     roc_auc = roc_auc_score(y_true, y_score)
-    assert_array_almost_equal(tpr, [0, 1])
-    assert_array_almost_equal(fpr, [1, 1])
+    assert_array_almost_equal(tpr, [0, 0, 1])
+    assert_array_almost_equal(fpr, [0, 1, 1])
     assert_almost_equal(roc_auc, 1.)

     y_true = [0, 1]

@@ -294,8 +294,8 @@ def test_roc_curve_toydata():
     y_score = [1, 0]
     tpr, fpr, _ = roc_curve(y_true, y_score)
     roc_auc = roc_auc_score(y_true, y_score)
-    assert_array_almost_equal(tpr, [0, 1])
-    assert_array_almost_equal(fpr, [1, 1])
+    assert_array_almost_equal(tpr, [0, 0, 1])
+    assert_array_almost_equal(fpr, [0, 1, 1])
     assert_almost_equal(roc_auc, 1.)

     y_true = [1, 0]

@@ -319,8 +319,8 @@ def test_roc_curve_toydata():
     # assert UndefinedMetricWarning because of no negative sample in y_true
     tpr, fpr, _ = assert_warns(UndefinedMetricWarning, roc_curve, y_true, y_score)
     assert_raises(ValueError, roc_auc_score, y_true, y_score)
-    assert_array_almost_equal(tpr, [np.nan, np.nan])
-    assert_array_almost_equal(fpr, [0.5, 1.])
+    assert_array_almost_equal(tpr, [np.nan, np.nan, np.nan])
+    assert_array_almost_equal(fpr, [0., 0.5, 1.])

     # Multi-label classification task
     y_true = np.array([[0, 1], [0, 1]])

@@ -359,7 +359,7 @@ def test_roc_curve_drop_intermediate():
     y_true = [0, 0, 0, 0, 1, 1]
     y_score = [0., 0.2, 0.5, 0.6, 0.7, 1.0]
     tpr, fpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=True)
-    assert_array_almost_equal(thresholds, [1., 0.7, 0.])
+    assert_array_almost_equal(thresholds, [2., 1., 0.7, 0.])

     # Test dropping thresholds with repeating scores
     y_true = [0, 0, 0, 0, 0, 0, 0,

@@ -368,7 +368,7 @@ def test_roc_curve_drop_intermediate():
                0.6, 0.7, 0.8, 0.9, 0.9, 1.0]
     tpr, fpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=True)
     assert_array_almost_equal(thresholds,
-                              [1.0, 0.9, 0.7, 0.6, 0.])
+                              [2.0, 1.0, 0.9, 0.7, 0.6, 0.])


 def test_roc_curve_fpr_tpr_increasing():
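The updated expectations in test_roc_curve_toydata can be checked interactively; the sketch below assumes a build that includes this commit. Note that roc_curve returns (fpr, tpr, thresholds), so the test's local names tpr and fpr are effectively swapped relative to the return order, and its first expected array really holds false positive rates:

    import numpy as np
    from sklearn.metrics import roc_curve

    fpr, tpr, thresholds = roc_curve([0, 1], [0, 1])
    # With the fix, a (0, 0) point is prepended to the formerly two-point curve.
    print(fpr)         # [0. 0. 1.]
    print(tpr)         # [0. 1. 1.]
    print(thresholds)  # [2. 1. 0.]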
