@@ -309,15 +309,17 @@ def test_roc_curve_toydata():
309
309
y_true = [0 , 0 ]
310
310
y_score = [0.25 , 0.75 ]
311
311
# assert UndefinedMetricWarning because of no positive sample in y_true
312
- tpr , fpr , _ = assert_warns (UndefinedMetricWarning , roc_curve , y_true , y_score )
312
+ tpr , fpr , _ = assert_warns (UndefinedMetricWarning , roc_curve ,
313
+ y_true , y_score )
313
314
assert_raises (ValueError , roc_auc_score , y_true , y_score )
314
315
assert_array_almost_equal (tpr , [0. , 0.5 , 1. ])
315
316
assert_array_almost_equal (fpr , [np .nan , np .nan , np .nan ])
316
317
317
318
y_true = [1 , 1 ]
318
319
y_score = [0.25 , 0.75 ]
319
320
# assert UndefinedMetricWarning because of no negative sample in y_true
320
- tpr , fpr , _ = assert_warns (UndefinedMetricWarning , roc_curve , y_true , y_score )
321
+ tpr , fpr , _ = assert_warns (UndefinedMetricWarning , roc_curve ,
322
+ y_true , y_score )
321
323
assert_raises (ValueError , roc_auc_score , y_true , y_score )
322
324
assert_array_almost_equal (tpr , [np .nan , np .nan ])
323
325
assert_array_almost_equal (fpr , [0.5 , 1. ])
@@ -565,8 +567,9 @@ def test_precision_recall_curve_toydata():
565
567
566
568
y_true = [0 , 0 ]
567
569
y_score = [0.25 , 0.75 ]
568
- assert_raises (Exception , precision_recall_curve , y_true , y_score )
569
- assert_raises (Exception , average_precision_score , y_true , y_score )
570
+ p , r , _ = precision_recall_curve (y_true , y_score )
571
+ assert_array_equal (p , np .array ([0.0 , 1.0 ]))
572
+ assert_array_equal (r , np .array ([1.0 , 0.0 ]))
570
573
571
574
y_true = [1 , 1 ]
572
575
y_score = [0.25 , 0.75 ]
@@ -578,21 +581,21 @@ def test_precision_recall_curve_toydata():
578
581
# Multi-label classification task
579
582
y_true = np .array ([[0 , 1 ], [0 , 1 ]])
580
583
y_score = np .array ([[0 , 1 ], [0 , 1 ]])
581
- assert_raises ( Exception , average_precision_score , y_true , y_score ,
582
- average = "macro" )
583
- assert_raises ( Exception , average_precision_score , y_true , y_score ,
584
- average = "weighted" )
584
+ assert_almost_equal ( average_precision_score ( y_true , y_score ,
585
+ average = "macro" ), 0.75 )
586
+ assert_almost_equal ( average_precision_score ( y_true , y_score ,
587
+ average = "weighted" ), 1.0 )
585
588
assert_almost_equal (average_precision_score (y_true , y_score ,
586
589
average = "samples" ), 1. )
587
590
assert_almost_equal (average_precision_score (y_true , y_score ,
588
591
average = "micro" ), 1. )
589
592
590
593
y_true = np .array ([[0 , 1 ], [0 , 1 ]])
591 594
y_score = np .array ([[0 , 1 ], [1 , 0 ]])
592
- assert_raises ( Exception , average_precision_score , y_true , y_score ,
593
- average = "macro" )
594
- assert_raises ( Exception , average_precision_score , y_true , y_score ,
595
- average = "weighted" )
595
+ assert_almost_equal ( average_precision_score ( y_true , y_score ,
596
+ average = "macro" ), 0.75 )
597
+ assert_almost_equal ( average_precision_score ( y_true , y_score ,
598
+ average = "weighted" ), 1.0 )
596
599
assert_almost_equal (average_precision_score (y_true , y_score ,
597
600
average = "samples" ), 0.75 )
598
601
assert_almost_equal (average_precision_score (y_true , y_score ,
0 commit comments