10000 updated tests · scikit-learn/scikit-learn@50fbd04 · GitHub
[go: up one dir, main page]

Skip to content

Commit 50fbd04

Browse files
committed
updated tests
1 parent 92f129d commit 50fbd04

File tree

2 files changed

+17
-43
lines changed

2 files changed

+17
-43
lines changed

sklearn/metrics/ranking.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,6 @@ def precision_recall_curve(y_true, probas_pred, pos_label=None,
412412
sample_weight=sample_weight)
413413

414414
precision = tps / (tps + fps)
415-
# recall = tps / tps[-1]
416415
recall = np.ones(tps.size) if tps[-1] == 0 else tps / tps[-1]
417416

418417
# stop when full recall attained

sklearn/metrics/tests/test_ranking.py

Lines changed: 17 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -536,10 +536,11 @@ def test_precision_recall_curve_toydata():
536536
assert_array_almost_equal(r, [1, 0.])
537537
assert_almost_equal(auc_prc, .75)
538538

539-
# y_true = [0, 0]
540-
# y_score = [0.25, 0.75]
541-
# assert_raises(Exception, precision_recall_curve, y_true, y_score)
542-
# assert_raises(Exception, average_precision_score, y_true, y_score)
539+
y_true = [0, 0]
540+
y_score = [0.25, 0.75]
541+
p, r, _ = precision_recall_curve(y_true, y_score)
542+
assert_array_equal(p, np.array([0.0, 1.0]))
543+
assert_array_equal(r, np.array([1.0, 0.0]))
543544

544545
y_true = [1, 1]
545546
y_score = [0.25, 0.75]
@@ -549,24 +550,21 @@ def test_precision_recall_curve_toydata():
549550
assert_array_almost_equal(r, [1, 0.5, 0.])
550551

551552
# Multi-label classification task
552-
# y_true = np.array([[0, 1], [0, 1]])
553-
# y_score = np.array([[0, 1], [0, 1]])
554-
# assert_raises(Exception, average_precision_score, y_true, y_score,
555-
# average="macro")
556-
# assert_raises(Exception, average_precision_score, y_true, y_score,
557-
# average="weighted")
558-
# assert_almost_equal(average_precision_score(y_true, y_score,
559-
# average="samples"), 1.)
560-
# assert_almost_equal(average_precision_score(y_true, y_score,
561-
# average="micro"), 1.)
562-
#
553+
y_true = np.array([[0, 1], [0, 1]])
554+
y_score = np.array([[0, 1], [0, 1]])
555+
assert_almost_equal(average_precision_score(y_true, y_score,
556+
average="macro"), 0.75)
557+
assert_almost_equal(average_precision_score(y_true, y_score,
558+
average="weighted"), 1.0)
559+
assert_almost_equal(average_precision_score(y_true, y_score,
560+
average="samples"), 1.)
561+
assert_almost_equal(average_precision_score(y_true, y_score,
562+
average="micro"), 1.)
563563

564564
y_true = np.array([[0, 1], [0, 1]])
565565
y_score = np.array([[0, 1], [1, 0]])
566-
# assert_raises(Exception, average_precision_score, y_true, y_score,
567-
# average="macro")
568-
# assert_raises(Exception, average_precision_score, y_true, y_score,
569-
# average="weighted")
566+
assert_almost_equal(average_precision_score(y_true, y_score, average="macro"), 0.75)
567+
assert_almost_equal(average_precision_score(y_true, y_score, average="weighted"), 1.0)
570568
assert_almost_equal(average_precision_score(y_true, y_score,
571569
average="samples"), 0.625)
572570
assert_almost_equal(average_precision_score(y_true, y_score,
@@ -980,26 +978,3 @@ def test_ranking_loss_ties_handling():
980978
assert_almost_equal(label_ranking_loss([[1, 0, 1]], [[0.25, 0.5, 0.5]]), 1)
981979
assert_almost_equal(label_ranking_loss([[1, 1, 0]], [[0.25, 0.5, 0.5]]), 1)
982980

983-
984-
def test_precision_recall_curve_all_negatives():
985-
"""
986-
Test edge case for `precision_recall_curve`
987-
if all the ground truth labels are negative.
988-
Precision values should not be `nan`.
989-
"""
990-
y_true = [0 for _ in range(10)]
991-
probas_pred = [np.random.rand() for _ in range(10)]
992-
_, recall, _ = precision_recall_curve(y_true, probas_pred)
993-
assert_not_equal(recall[0], np.nan)
994-
995-
996-
def test_precision_recall_curve_all_positives():
997-
"""
998 5E5D -
Test edge case for `precision_recall_curve`
999-
if all the ground truth labels are positive.
1000-
"""
1001-
y_true = [1 for _ in range(10)]
1002-
probas_pred = [np.random.rand() for _ in range(10)]
1003-
precision, _, _ = precision_recall_curve(y_true, probas_pred)
1004-
1005-
assert_array_equal(precision, [1.0 for _ in range(len(precision))])

0 commit comments

Comments
 (0)
0