fixed bug for precision recall curve when all labels are negative · scikit-learn/scikit-learn@b0b9bca · GitHub

Commit b0b9bca

committed
fixed bug for precision recall curve when all labels are negative
1 parent cb1b6c4 commit b0b9bca

2 files changed, +16 -13 lines

sklearn/metrics/ranking.py

Lines changed: 1 addition & 1 deletion
@@ -417,7 +417,7 @@ def precision_recall_curve(y_true, probas_pred, pos_label=None,
                                              sample_weight=sample_weight)

     precision = tps / (tps + fps)
-    recall = tps / tps[-1]
+    recall = np.ones(tps.size) if tps[-1] == 0 else tps / tps[-1]

     # stop when full recall attained
     # and reverse the outputs so recall is decreasing
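
For context, a minimal sketch of what the one-line change does, using hypothetical cumulative counts (the real tps/fps arrays are computed earlier in the function): when y_true contains no positive samples, tps ends at zero, so the old expression divided 0 by 0 and produced NaN; the new expression defines recall as 1 everywhere in that degenerate case.

    import numpy as np

    # Hypothetical counts for two thresholds when y_true has no positive
    # samples: true positives stay at zero, false positives accumulate.
    tps = np.array([0.0, 0.0])
    fps = np.array([1.0, 2.0])

    precision = tps / (tps + fps)                                  # [0., 0.]
    # Old behaviour: tps / tps[-1] -> 0/0 -> NaN.
    recall = np.ones(tps.size) if tps[-1] == 0 else tps / tps[-1]  # [1., 1.]
    print(precision, recall)

As the context comments indicate, the surrounding code then stops at full recall, reverses the outputs, and adds the final (precision=1, recall=0) point, which matches the [0., 1.] / [1., 0.] expectations in the updated test below.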

sklearn/metrics/tests/test_ranking.py

Lines changed: 15 additions & 12 deletions
@@ -309,15 +309,17 @@ def test_roc_curve_toydata():
     y_true = [0, 0]
     y_score = [0.25, 0.75]
     # assert UndefinedMetricWarning because of no positive sample in y_true
-    tpr, fpr, _ = assert_warns(UndefinedMetricWarning, roc_curve, y_true, y_score)
+    tpr, fpr, _ = assert_warns(UndefinedMetricWarning, roc_curve,
+                               y_true, y_score)
     assert_raises(ValueError, roc_auc_score, y_true, y_score)
     assert_array_almost_equal(tpr, [0., 0.5, 1.])
     assert_array_almost_equal(fpr, [np.nan, np.nan, np.nan])

     y_true = [1, 1]
     y_score = [0.25, 0.75]
     # assert UndefinedMetricWarning because of no negative sample in y_true
-    tpr, fpr, _ = assert_warns(UndefinedMetricWarning, roc_curve, y_true, y_score)
+    tpr, fpr, _ = assert_warns(UndefinedMetricWarning, roc_curve,
+                               y_true, y_score)
     assert_raises(ValueError, roc_auc_score, y_true, y_score)
     assert_array_almost_equal(tpr, [np.nan, np.nan])
     assert_array_almost_equal(fpr, [0.5, 1.])
@@ -565,8 +567,9 @@ def test_precision_recall_curve_toydata():

     y_true = [0, 0]
     y_score = [0.25, 0.75]
-    assert_raises(Exception, precision_recall_curve, y_true, y_score)
-    assert_raises(Exception, average_precision_score, y_true, y_score)
+    p, r, _ = precision_recall_curve(y_true, y_score)
+    assert_array_equal(p, np.array([0.0, 1.0]))
+    assert_array_equal(r, np.array([1.0, 0.0]))

     y_true = [1, 1]
     y_score = [0.25, 0.75]
@@ -578,21 +581,21 @@ def test_precision_recall_curve_toydata():
     # Multi-label classification task
     y_true = np.array([[0, 1], [0, 1]])
     y_score = np.array([[0, 1], [0, 1]])
-    assert_raises(Exception, average_precision_score, y_true, y_score,
-                  average="macro")
-    assert_raises(Exception, average_precision_score, y_true, y_score,
-                  average="weighted")
+    assert_almost_equal(average_precision_score(y_true, y_score,
+                        average="macro"), 0.75)
+    assert_almost_equal(average_precision_score(y_true, y_score,
+                        average="weighted"), 1.0)
     assert_almost_equal(average_precision_score(y_true, y_score,
                         average="samples"), 1.)
     assert_almost_equal(average_precision_score(y_true, y_score,
                         average="micro"), 1.)

     y_true = np.array([[0, 1], [0, 1]])
     y_score = np.array([[0, 1], [1, 0]])
-    assert_raises(Exception, average_precision_score, y_true, y_score,
-                  average="macro")
-    assert_raises(Exception, average_precision_score, y_true, y_score,
-                  average="weighted")
+    assert_almost_equal(average_precision_score(y_true, y_score,
+                        average="macro"), 0.75)
+    assert_almost_equal(average_precision_score(y_true, y_score,
+                        average="weighted"), 1.0)
     assert_almost_equal(average_precision_score(y_true, y_score,
                         average="samples"), 0.75)
     assert_almost_equal(average_precision_score(y_true, y_score,
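
The updated expectations can be reproduced through the public API; a small sketch, assuming a scikit-learn build that includes this commit:

    import numpy as np
    from sklearn.metrics import precision_recall_curve, average_precision_score

    # All-negative labels no longer raise: recall is defined as 1 everywhere,
    # so the curve collapses to the two points asserted in the test above.
    y_true = [0, 0]
    y_score = [0.25, 0.75]
    p, r, _ = precision_recall_curve(y_true, y_score)
    print(p, r)  # expected per the test: [0. 1.] and [1. 0.]

    # Multi-label toy data: macro and weighted averages now return finite
    # values instead of raising, matching the assertions above.
    y_true = np.array([[0, 1], [0, 1]])
    y_score = np.array([[0, 1], [0, 1]])
    print(average_precision_score(y_true, y_score, average="macro"))     # 0.75 per the test
    print(average_precision_score(y_true, y_score, average="weighted"))  # 1.0 per the test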
