8000 fixed bug for precision recall curve when all labels are negative · scikit-learn/scikit-learn@24e9a11 · GitHub
[go: up one dir, main page]

Skip to content

Commit 24e9a11

Browse files
committed
fixed bug for precision recall curve when all labels are negative
1 parent 542c02b commit 24e9a11

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

sklearn/metrics/ranking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def precision_recall_curve(y_true, probas_pred, pos_label=None,
412412
sample_weight=sample_weight)
413413

414414
precision = tps / (tps + fps)
415-
recall = tps / tps[-1]
415+
recall = np.ones(tps.size) if tps[-1] == 0 else tps / tps[-1]
416416

417417
# stop when full recall attained
418418
# and reverse the outputs so recall is decreasing

sklearn/metrics/tests/test_ranking.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -283,15 +283,17 @@ def test_roc_curve_toydata():
283283
y_true = [0, 0]
284284
y_score = [0.25, 0.75]
285285
# assert UndefinedMetricWarning because of no positive sample in y_true
286-
tpr, fpr, _ = assert_warns(UndefinedMetricWarning, roc_curve, y_true, y_score)
286+
tpr, fpr, _ = assert_warns(UndefinedMetricWarning, roc_curve,
287+
y_true, y_score)
287288
assert_raises(ValueError, roc_auc_score, y_true, y_score)
288289
assert_array_almost_equal(tpr, [0., 0.5, 1.])
289290
assert_array_almost_equal(fpr, [np.nan, np.nan, np.nan])
290291

291292
y_true = [1, 1]
292293
y_score = [0.25, 0.75]
293294
# assert UndefinedMetricWarning because of no negative sample in y_true
294-
tpr, fpr, _ = assert_warns(UndefinedMetricWarning, roc_curve, y_true, y_score)
295+
tpr, fpr, _ = assert_warns(UndefinedMetricWarning, roc_curve,
296+
y_true, y_score)
295297
assert_raises(ValueError, roc_auc_score, y_true, y_score)
296298
assert_array_almost_equal(tpr, [np.nan, np.nan])
297299
assert_array_almost_equal(fpr, [0.5, 1.])
@@ -538,8 +540,9 @@ def test_precision_recall_curve_toydata():
538540

539541
y_true = [0, 0]
540542
y_score = [0.25, 0.75]
541-
assert_raises(Exception, precision_recall_curve, y_true, y_score)
542-
assert_raises(Exception, average_precision_score, y_true, y_score)
543+
p, r, _ = precision_recall_curve(y_true, y_score)
544+
assert_array_equal(p, np.array([0.0, 1.0]))
545+
assert_array_equal(r, np.array([1.0, 0.0]))
543546

544547
y_true = [1, 1]
545548
y_score = [0.25, 0.75]
@@ -551,21 +554,21 @@ def test_precision_recall_curve_toydata():
551554
# Multi-label classification task
552555
y_true = np.array([[0, 1], [0, 1]])
553556
y_score = np.array([[0, 1], [0, 1]])
554-
assert_raises(Exception, average_precision_score, y_true, y_score,
555-
average="macro")
556-
assert_raises(Exception, average_precision_score, y_true, y_score,
557-
average="weighted")
557+
assert_almost_equal(average_precision_score(y_true, y_score,
558+
average="macro"), 0.75)
559+
assert_almost_equal(average_precision_score(y_true, y_score,
560+
average="weighted"), 1.0)
558561
assert_almost_equal(average_precision_score(y_true, y_score,
559562
average="samples"), 1.)
560563
assert_almost_equal(average_precision_score(y_true, y_score,
561564
average="micro"), 1.)
562565

563566
y_true = np.array([[0, 1], [0, 1]])
564567
y_score = np.array([[0, 1], [1, 0]])
565-
assert_raises(Exception, average_precision_score, y_true, y_score,
566-
average="macro")
567-
assert_raises(Exception, average_precision_score, y_true, y_score,
568-
average="weighted")
568+
assert_almost_equal(average_precision_score(y_true, y_score,
569+
average="macro"), 0.75)
570+
assert_almost_equal(average_precision_score(y_true, y_score,
571+
average="weighted"), 1.0)
569572
assert_almost_equal(average_precision_score(y_true, y_score,
570573
average="samples"), 0.625)
571574
assert_almost_equal(average_precision_score(y_true, y_score,

0 commit comments

Comments (0)