8000 fixed bug for precision recall curve when all labels are negative · scikit-learn/scikit-learn@92f129d · GitHub
[go: up one dir, main page]

Skip to content

Commit 92f129d

Browse files
committed
fixed a bug in the precision-recall curve when all labels are negative
1 parent 53f8082 commit 92f129d

File tree

2 files changed

+46
-20
lines changed

2 files changed

+46
-20
lines changed

sklearn/metrics/ranking.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,8 @@ def precision_recall_curve(y_true, probas_pred, pos_label=None,
412412
sample_weight=sample_weight)
413413

414414
precision = tps / (tps + fps)
415-
recall = tps / tps[-1]
415+
# recall = tps / tps[-1]
416+
recall = np.ones(tps.size) if tps[-1] == 0 else tps / tps[-1]
416417

417418
# stop when full recall attained
418419
# and reverse the outputs so recall is decreasing

sklearn/metrics/tests/test_ranking.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from sklearn.utils.testing import assert_raises, clean_warning_registry
1717
from sklearn.utils.testing import assert_raise_message
18-
from sklearn.utils.testing import assert_equal
18+
from sklearn.utils.testing import assert_equal, assert_not_equal
1919
from sklearn.utils.testing import assert_almost_equal
2020
from sklearn.utils.testing import assert_array_equal
2121
from sklearn.utils.testing import assert_array_almost_equal
@@ -536,10 +536,10 @@ def test_precision_recall_curve_toydata():
536536
assert_array_almost_equal(r, [1, 0.])
537537
assert_almost_equal(auc_prc, .75)
538538

539-
y_true = [0, 0]
540-
y_score = [0.25, 0.75]
541-
assert_raises(Exception, precision_recall_curve, y_true, y_score)
542-
assert_raises(Exception, average_precision_score, y_true, y_score)
539+
# y_true = [0, 0]
540+
# y_score = [0.25, 0.75]
541+
# assert_raises(Exception, precision_recall_curve, y_true, y_score)
542+
# assert_raises(Exception, average_precision_score, y_true, y_score)
543543

544544
y_true = [1, 1]
545545
y_score = [0.25, 0.75]
@@ -549,23 +549,24 @@ def test_precision_recall_curve_toydata():
549549
assert_array_almost_equal(r, [1, 0.5, 0.])
550550

551551
# Multi-label classification task
552-
y_true = np.array([[0, 1], [0, 1]])
553-
y_score = np.array([[0, 1], [0, 1]])
554-
assert_raises(Exception, average_precision_score, y_true, y_score,
555-
average="macro")
556-
assert_raises(Exception, average_precision_score, y_true, y_score,
557-
average="weighted")
558-
assert_almost_equal(average_precision_score(y_true, y_score,
559-
average="samples"), 1.)
560-
assert_almost_equal(average_precision_score(y_true, y_score,
561-
average="micro"), 1.)
552+
# y_true = np.array([[0, 1], [0, 1]])
553+
# y_score = np.array([[0, 1], [0, 1]])
554+
# assert_raises(Exception, average_precision_score, y_true, y_score,
555+
# average="macro")
556+
# assert_raises(Exception, average_precision_score, y_true, y_score,
557+
# average="weighted")
558+
# assert_almost_equal(average_precision_score(y_true, y_score,
559+
# average="samples"), 1.)
560+
# assert_almost_equal(average_precision_score(y_true, y_score,
561+
# average="micro"), 1.)
562+
#
562563

563564
y_true = np.array([[0, 1], [0, 1]])
564565
y_score = np.array([[0, 1], [1, 0]])
565-
assert_raises(Exception, average_precision_score, y_true, y_score,
566-
average="macro")
567-
assert_raises(Exception, average_precision_score, y_true, y_score,
568-
average="weighted")
566+
# assert_raises(Exception, average_precision_score, y_true, y_score,
567+
# average="macro")
568+
# assert_raises(Exception, average_precision_score, y_true, y_score,
569+
# average="weighted")
569570
assert_almost_equal(average_precision_score(y_true, y_score,
570571
average="samples"), 0.625)
571572
assert_almost_equal(average_precision_score(y_true, y_score,
@@ -978,3 +979,27 @@ def test_ranking_loss_ties_handling():
978979
assert_almost_equal(label_ranking_loss([[1, 0, 0]], [[0.25, 0.5, 0.5]]), 1)
979980
assert_almost_equal(label_ranking_loss([[1, 0, 1]], [[0.25, 0.5, 0.5]]), 1)
980981
assert_almost_equal(label_ranking_loss([[1, 1, 0]], [[0.25, 0.5, 0.5]]), 1)
982+
983+
984+
def test_precision_recall_curve_all_negatives():
985+
"""
986+
Test edge case for `precision_recall_curve`
987+
if all the ground truth labels are negative.
988+
Recall values should not be `nan`.
989+
"""
990+
y_true = [0 for _ in range(10)]
991+
probas_pred = [np.random.rand() for _ in range(10)]
992+
_, recall, _ = precision_recall_curve(y_true, probas_pred)
993+
assert_not_equal(recall[0], np.nan)
994+
995+
996+
def test_precision_recall_curve_all_positives():
997+
"""
998+
Test edge case for `precision_recall_curve`
999+
if all the ground truth labels are positive.
1000+
"""
1001+
y_true = [1 for _ in range(10)]
1002+
probas_pred = [np.random.rand() for _ in range(10)]
1003+
precision, _, _ = precision_recall_curve(y_true, probas_pred)
1004+
1005+
assert_array_equal(precision, [1.0 for _ in range(len(precision))])

0 commit comments

Comments
 (0)
0