@@ -283,15 +283,17 @@ def test_roc_curve_toydata():
283
283
y_true = [0 , 0 ]
284
284
y_score = [0.25 , 0.75 ]
285
285
# assert UndefinedMetricWarning because of no positive sample in y_true
286
- tpr , fpr , _ = assert_warns (UndefinedMetricWarning , roc_curve , y_true , y_score )
286
+ tpr , fpr , _ = assert_warns (UndefinedMetricWarning , roc_curve ,
287
+ y_true , y_score )
287
288
assert_raises (ValueError , roc_auc_score , y_true , y_score )
288
289
assert_array_almost_equal (tpr , [0. , 0.5 , 1. ])
289
290
assert_array_almost_equal (fpr , [np .nan , np .nan , np .nan ])
290
291
291
292
y_true = [1 , 1 ]
292
293
y_score = [0.25 , 0.75 ]
293
294
# assert UndefinedMetricWarning because of no negative sample in y_true
294
- tpr , fpr , _ = assert_warns (UndefinedMetricWarning , roc_curve , y_true , y_score )
295
+ tpr , fpr , _ = assert_warns (UndefinedMetricWarning , roc_curve ,
296
+ y_true , y_score )
295
297
assert_raises (ValueError , roc_auc_score , y_true , y_score )
296
298
assert_array_almost_equal (tpr , [np .nan , np .nan ])
297
299
assert_array_almost_equal (fpr , [0.5 , 1. ])
@@ -538,8 +540,9 @@ def test_precision_recall_curve_toydata():
538
540
539
541
y_true = [0 , 0 ]
540
542
y_score = [0.25 , 0.75 ]
541
- assert_raises (Exception , precision_recall_curve , y_true , y_score )
542
- assert_raises (Exception , average_precision_score , y_true , y_score )
543
+ p , r , _ = precision_recall_curve (y_true , y_score )
544
+ assert_array_equal (p , np .array ([0.0 , 1.0 ]))
545
+ assert_array_equal (r , np .array ([1.0 , 0.0 ]))
543
546
544
547
y_true = [1 , 1 ]
545
548
y_score = [0.25 , 0.75 ]
@@ -551,21 +554,21 @@ def test_precision_recall_curve_toydata():
551
554
# Multi-label classification task
552
555
y_true = np .array ([[0 , 1 ], [0 , 1 ]])
553
556
y_score = np .array ([[0 , 1 ], [0 , 1 ]])
554
- assert_raises ( Exception , average_precision_score , y_true , y_score ,
555
- average = "macro" )
556
- assert_raises ( Exception , average_precision_score , y_true , y_score ,
557
- average = "weighted" )
557
+ assert_almost_equal ( average_precision_score ( y_true , y_score ,
558
+ average = "macro" ), 0.75 )
559
+ assert_almost_equal ( average_precision_score ( y_true , y_score ,
560
+ average = "weighted" ), 1.0 )
558
561
assert_almost_equal (average_precision_score (y_true , y_score ,
559
562
average = "samples" ), 1. )
560
563
assert_almost_equal (average_precision_score (y_true , y_score ,
561
564
average = "micro" ), 1. )
562
565
563
566
y_true = np .array ([[0 , 1 ], [0 , 1 ]])
564
567
y_score = np .array ([[0 , 1 ], [1 , 0 ]])
565
- assert_raises ( Exception , average_precision_score , y_true , y_score ,
566
- average = "macro" )
567
- assert_raises ( Exception , average_precision_score , y_true , y_score ,
568
- average = "weighted" )
568
+ assert_almost_equal ( average_precision_score ( y_true , y_score ,
569
+ average = "macro" ), 0.75 )
570
+ assert_almost_equal ( average_precision_score ( y_true , y_score ,
571
+ average = "weighted" ), 1.0 )
569
572
assert_almost_equal (average_precision_score (y_true , y_score ,
570
573
average = "samples" ), 0.625 )
571
574
assert_almost_equal (average_precision_score (y_true , y_score ,
0 commit comments