8000 [MRG] Ensure that ROC curve starts at (0, 0) (#10093) · maskani-moh/scikit-learn@ddb9d09 · GitHub
[go: up one dir, main page]

Skip to content

Commit ddb9d09

Browse files
qinhanmin2014 authored and maskani-moh committed
[MRG] Ensure that ROC curve starts at (0, 0) (scikit-learn#10093)
1 parent c263eb4 commit ddb9d09

File tree

4 files changed

+28
-25
lines changed

4 files changed

+28
-25
lines changed

doc/modules/model_evaluation.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,11 +1138,11 @@ Here is a small example of how to use the :func:`roc_curve` function::
11381138
>>> scores = np.array([0.1, 0.4, 0.35, 0.8])
11391139
>>> fpr, tpr, thresholds = roc_curve(y, scores, pos_label=2)
11401140
>>> fpr
1141-
array([ 0. , 0.5, 0.5, 1. ])
1141+
array([ 0. , 0. , 0.5, 0.5, 1. ])
11421142
>>> tpr
1143-
array([ 0.5, 0.5, 1. , 1. ])
1143+
array([ 0. , 0.5, 0.5, 1. , 1. ])
11441144
>>> thresholds
1145-
array([ 0.8 , 0.4 , 0.35, 0.1 ])
1145+
array([ 1.8 , 0.8 , 0.4 , 0.35, 0.1 ])
11461146

11471147
This figure shows an example of such an ROC curve:
11481148

doc/whats_new/v0.20.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ random sampling procedures.
1818
- :class:`decomposition.IncrementalPCA` in Python 2 (bug fix)
1919
- :class:`isotonic.IsotonicRegression` (bug fix)
2020
- :class:`metrics.roc_auc_score` (bug fix)
21+
- :class:`metrics.roc_curve` (bug fix)
2122
- :class:`neural_network.BaseMultilayerPerceptron` (bug fix)
2223
- :class:`neural_network.MLPRegressor` (bug fix)
2324
- :class:`neural_network.MLPClassifier` (bug fix)
@@ -160,6 +161,12 @@ Metrics
160161
- Fixed a bug due to floating point error in :func:`metrics.roc_auc_score` with
161162
non-integer sample weights. :issue:`9786` by :user:`Hanmin Qin <qinhanmin2014>`.
162163

164+
- Fixed a bug where :func:`metrics.roc_curve` sometimes starts on y-axis instead
165+
of (0, 0), which is inconsistent with the document and other implementations.
166+
Note that this will not influence the result from :func:`metrics.roc_auc_score`
167+
:issue:`10093` by :user:`alexryndin <alexryndin>`
168+
and :user:`Hanmin Qin <qinhanmin2014>`.
169+
163170
API changes summary
164171
-------------------
165172

sklearn/metrics/ranking.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -227,18 +227,13 @@ def roc_auc_score(y_true, y_score, multiclass="ovr", average="macro",
227227
Parameters
228228
----------
229229
y_true : array, shape = [n_samples] or [n_samples, n_classes]
230-
<<<<<<< 68c38761be8d86c944012b67d8d84feb3606ce6f
231230
True binary labels in binary label indicators.
232231
The multiclass case expects shape = [n_samples] and labels
233232
with values from 0 to (n_classes-1), inclusive.
234-
=======
235-
True binary labels or binary label indicators.
236-
>>>>>>> [MRG+1] Completely support binary y_true in roc_auc_score (#9828)
237233
238234
y_score : array, shape = [n_samples] or [n_samples, n_classes]
239235
Target scores, can either be probability estimates of the positive
240236
class, confidence values, or non-thresholded measure of decisions
241-
<<<<<<< 68c38761be8d86c944012b67d8d84feb3606ce6f
242237
(as returned by "decision_function" on some classifiers).
243238
The multiclass case expects shape = [n_samples, n_classes]
244239
where the scores correspond to probability estimates.
@@ -253,11 +248,6 @@ def roc_auc_score(y_true, y_score, multiclass="ovr", average="macro",
253248
``'ovo'``:
254249
Calculate metrics for the multiclass case using the one-vs-one
255250
approach.
256-
=======
257-
(as returned by "decision_function" on some classifiers). For binary
258-
y_true, y_score is supposed to be the score of the class with greater
259-
label.
260-
>>>>>>> [MRG+1] Completely support binary y_true in roc_auc_score (#9828)
261251
262252
average : string, [None, 'micro', 'macro' (default), 'samples', 'weighted']
263253
If ``None``, the scores for each class are returned. Otherwise,
@@ -287,6 +277,9 @@ def roc_auc_score(y_true, y_score, multiclass="ovr", average="macro",
287277
.. [1] `Wikipedia entry for the Receiver operating characteristic
288278
<https://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_
289279
280+
.. [2] Fawcett T. An introduction to ROC analysis[J]. Pattern Recognition
281+
Letters, 2006, 27(8):861-874.
282+
290283
See also
291284
--------
292285
average_precision_score : Area under the precision-recall curve
@@ -589,6 +582,8 @@ def roc_curve(y_true, y_score, pos_label=None, sample_weight=None,
589582
.. [1] `Wikipedia entry for the Receiver operating characteristic
590583
<https://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_
591584
585+
.. [2] Fawcett T. An introduction to ROC analysis[J]. Pattern Recognition
586+
Letters, 2006, 27(8):861-874.
592587
593588
Examples
594589
--------
@@ -598,11 +593,11 @@ def roc_curve(y_true, y_score, pos_label=None, sample_weight=None,
598593
>>> scores = np.array([0.1, 0.4, 0.35, 0.8])
599594
>>> fpr, tpr, thresholds = metrics.roc_curve(y, scores, pos_label=2)
600595
>>> fpr
601-
array([ 0. , 0.5, 0.5, 1. ])
596+
array([ 0. , 0. , 0.5, 0.5, 1. ])
602597
>>> tpr
603-
array([ 0.5, 0.5, 1. , 1. ])
598+
array([ 0. , 0.5, 0.5, 1. , 1. ])
604599
>>> thresholds
605-
array([ 0.8 , 0.4 , 0.35, 0.1 ])
600+
array([ 1.8 , 0.8 , 0.4 , 0.35, 0.1 ])
606601
607602
"""
608603
fps, tps, thresholds = _binary_clf_curve(
@@ -626,8 +621,9 @@ def roc_curve(y_true, y_score, pos_label=None, sample_weight=None,
626621
tps = tps[optimal_idxs]
627622
thresholds = thresholds[optimal_idxs]
628623

629-
if tps.size == 0 or fps[0] != 0:
624+
if tps.size == 0 or fps[0] != 0 or tps[0] != 0:
630625
# Add an extra threshold position if necessary
626+
# to make sure that the curve starts at (0, 0)
631627
tps = np.r_[0, tps]
632628
fps = np.r_[0, fps]
633629
thresholds = np.r_[thresholds[0] + 1, thresholds]

sklearn/metrics/tests/test_ranking.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,8 @@ def test_roc_curve_toydata():
270270
y_score = [0, 1]
271271
tpr, fpr, _ = roc_curve(y_true, y_score)
272272
roc_auc = roc_auc_score(y_true, y_score)
273-
assert_array_almost_equal(tpr, [0, 1])
274-
assert_array_almost_equal(fpr, [1, 1])
273+
assert_array_almost_equal(tpr, [0, 0, 1])
274+
assert_array_almost_equal(fpr, [0, 1, 1])
275275
assert_almost_equal(roc_auc, 1.)
276276

277277
y_true = [0, 1]
@@ -294,8 +294,8 @@ def test_roc_curve_toydata():
294294
y_score = [1, 0]
295295
tpr, fpr, _ = roc_curve(y_true, y_score)
296296
roc_auc = roc_auc_score(y_true, y_score)
297-
assert_array_almost_equal(tpr, [0, 1])
298-
assert_array_almost_equal(fpr, [1, 1])
297+
assert_array_almost_equal(tpr, [0, 0, 1])
298+
assert_array_almost_equal(fpr, [0, 1, 1])
299299
assert_almost_equal(roc_auc, 1.)
300300

301301
y_true = [1, 0]
@@ -319,8 +319,8 @@ def test_roc_curve_toydata():
319319
# assert UndefinedMetricWarning because of no negative sample in y_true
320320
tpr, fpr, _ = assert_warns(UndefinedMetricWarning, roc_curve, y_true, y_score)
321321
assert_raises(ValueError, roc_auc_score, y_true, y_score)
322-
assert_array_almost_equal(tpr, [np.nan, np.nan])
323-
assert_array_almost_equal(fpr, [0.5, 1.])
322+
assert_array_almost_equal(tpr, [np.nan, np.nan, np.nan])
323+
assert_array_almost_equal(fpr, [0., 0.5, 1.])
324324

325325
# Multi-label classification task
326326
y_true = np.array([[0, 1], [0, 1]])
@@ -359,7 +359,7 @@ def test_roc_curve_drop_intermediate():
359359
y_true = [0, 0, 0, 0, 1, 1]
360360
y_score = [0., 0.2, 0.5, 0.6, 0.7, 1.0]
361361
tpr, fpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=True)
362-
assert_array_almost_equal(thresholds, [1., 0.7, 0.])
362+
assert_array_almost_equal(thresholds, [2., 1., 0.7, 0.])
363363

364364
# Test dropping thresholds with repeating scores
365365
y_true = [0, 0, 0, 0, 0, 0, 0,
@@ -368,7 +368,7 @@ def test_roc_curve_drop_intermediate():
368368
0.6, 0.7, 0.8, 0.9, 0.9, 1.0]
369369
tpr, fpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=True)
370370
assert_array_almost_equal(thresholds,
371-
[1.0, 0.9, 0.7, 0.6, 0.])
371+
[2.0, 1.0, 0.9, 0.7, 0.6, 0.])
372372

373373

374374
def test_roc_curve_fpr_tpr_increasing():

0 commit comments

Comments (0)