@@ -536,10 +536,11 @@ def test_precision_recall_curve_toydata():
536
536
assert_array_almost_equal (r , [1 , 0. ])
537
537
assert_almost_equal (auc_prc , .75 )
538
538
539
- # y_true = [0, 0]
540
- # y_score = [0.25, 0.75]
541
- # assert_raises(Exception, precision_recall_curve, y_true, y_score)
542
- # assert_raises(Exception, average_precision_score, y_true, y_score)
539
+ y_true = [0 , 0 ]
540
+ y_score = [0.25 , 0.75 ]
541
+ p , r , _ = precision_recall_curve (y_true , y_score )
542
+ assert_array_equal (p , np .array ([0.0 , 1.0 ]))
543
+ assert_array_equal (r , np .array ([1.0 , 0.0 ]))
543
544
544
545
y_true = [1 , 1 ]
545
546
y_score = [0.25 , 0.75 ]
@@ -549,24 +550,21 @@ def test_precision_recall_curve_toydata():
549
550
assert_array_almost_equal (r , [1 , 0.5 , 0. ])
550
551
551
552
# Multi-label classification task
552
- # y_true = np.array([[0, 1], [0, 1]])
553
- # y_score = np.array([[0, 1], [0, 1]])
554
- # assert_raises(Exception, average_precision_score, y_true, y_score,
555
- # average="macro")
556
- # assert_raises(Exception, average_precision_score, y_true, y_score,
557
- # average="weighted")
558
- # assert_almost_equal(average_precision_score(y_true, y_score,
559
- # average="samples"), 1.)
560
- # assert_almost_equal(average_precision_score(y_true, y_score,
561
- # average="micro"), 1.)
562
- #
553
+ y_true = np .array ([[0 , 1 ], [0 , 1 ]])
554
+ y_score = np .array ([[0 , 1 ], [0 , 1 ]])
555
+ assert_almost_equal (average_precision_score (y_true , y_score ,
556
+ average = "macro" ), 0.75 )
557
+ assert_almost_equal (average_precision_score (y_true , y_score ,
558
+ average = "weighted" ), 1.0 )
559
+ assert_almost_equal (average_precision_score (y_true , y_score ,
560
+ average = "samples" ), 1. )
561
+ assert_almost_equal (average_precision_score (y_true , y_score ,
562
+ average = "micro" ), 1. )
563
563
564
564
y_true = np .array ([[0 , 1 ], [0 , 1 ]])
565
565
y_score = np .array ([[0 , 1 ], [1 , 0 ]])
566
- # assert_raises(Exception, average_precision_score, y_true, y_score,
567
- # average="macro")
568
- # assert_raises(Exception, average_precision_score, y_true, y_score,
569
- # average="weighted")
566
+ assert_almost_equal (average_precision_score (y_true , y_score , average = "macro" ), 0.75 )
567
+ assert_almost_equal (average_precision_score (y_true , y_score , average = "weighted" ), 1.0 )
570
568
assert_almost_equal (average_precision_score (y_true , y_score ,
571
569
average = "samples" ), 0.625 )
572
570
assert_almost_equal (average_precision_score (y_true , y_score ,
@@ -980,26 +978,3 @@ def test_ranking_loss_ties_handling():
980
978
assert_almost_equal (label_ranking_loss ([[1 , 0 , 1 ]], [[0.25 , 0.5 , 0.5 ]]), 1 )
981
979
assert_almost_equal (label_ranking_loss ([[1 , 1 , 0 ]], [[0.25 , 0.5 , 0.5 ]]), 1 )
982
980
983
-
984
- def test_precision_recall_curve_all_negatives ():
985
- """
986
- Test edge case for `precision_recall_curve`
987
- if all the ground truth labels are negative.
988
- Precision values should not be `nan`.
989
- """
990
- y_true = [0 for _ in range (10 )]
991
- probas_pred = [np .random .rand () for _ in range (10 )]
992
- _ , recall , _ = precision_recall_curve (y_true , probas_pred )
993
- assert_not_equal (recall [0 ], np .nan )
994
-
995
-
996
- def test_precision_recall_curve_all_positives ():
997
- """
998
5E5D
- Test edge case for `precision_recall_curve`
999
- if all the ground truth labels are positive.
1000
- """
1001
- y_true = [1 for _ in range (10 )]
1002
- probas_pred = [np .random .rand () for _ in range (10 )]
1003
- precision , _ , _ = precision_recall_curve (y_true , probas_pred )
1004
-
1005
- assert_array_equal (precision , [1.0 for _ in range (len (precision ))])
0 commit comments