6
6
from itertools import product
7
7
import warnings
8
8
9
+ import pytest
10
+
9
11
from sklearn import datasets
10
12
from sklearn import svm
11
13
@@ -520,7 +522,8 @@ def test_matthews_corrcoef_multiclass():
520
522
assert_almost_equal (mcc , 0. )
521
523
522
524
523
- def test_matthews_corrcoef_overflow ():
525
+ @pytest .mark .parametrize ('n_points' , [100 , 10000 , 1000000 ])
526
+ def test_matthews_corrcoef_overflow (n_points ):
524
527
# https://github.com/scikit-learn/scikit-learn/issues/9622
525
528
rng = np .random .RandomState (20170906 )
526
529
@@ -543,16 +546,15 @@ def random_ys(n_points): # binary
543
546
y_pred = (x_pred > 0.5 )
544
547
return y_true , y_pred
545
548
546
- for n_points in [100 , 10000 , 1000000 ]:
<
8000
/td>547
- arr = np .repeat ([0. , 1. ], n_points ) # binary
548
- assert_almost_equal (matthews_corrcoef (arr , arr ), 1.0 )
549
- arr = np .repeat ([0. , 1. , 2. ], n_points ) # multiclass
550
- assert_almost_equal (matthews_corrcoef (arr , arr ), 1.0 )
549
+ arr = np .repeat ([0. , 1. ], n_points ) # binary
550
+ assert_almost_equal (matthews_corrcoef (arr , arr ), 1.0 )
551
+ arr = np .repeat ([0. , 1. , 2. ], n_points ) # multiclass
552
+ assert_almost_equal (matthews_corrcoef (arr , arr ), 1.0 )
551
553
552
- y_true , y_pred = random_ys (n_points )
553
- assert_almost_equal (matthews_corrcoef (y_true , y_true ), 1.0 )
554
- assert_almost_equal (matthews_corrcoef (y_true , y_pred ),
555
- mcc_safe (y_true , y_pred ))
554
+ y_true , y_pred = random_ys (n_points )
555
+ assert_almost_equal (matthews_corrcoef (y_true , y_true ), 1.0 )
556
+ assert_almost_equal (matthews_corrcoef (y_true , y_pred ),
557
+ mcc_safe (y_true , y_pred ))
556
558
557
559
558
560
def test_precision_recall_f1_score_multiclass ():
@@ -610,18 +612,19 @@ def test_precision_recall_f1_score_multiclass():
610
612
assert_array_equal (s , [24 , 20 , 31 ])
611
613
612
614
613
- def test_precision_refcall_f1_score_multilabel_unordered_labels ():
615
+ @pytest .mark .parametrize ('average' ,
616
+ ['samples' , 'micro' , 'macro' , 'weighted' , None ])
617
+ def test_precision_refcall_f1_score_multilabel_unordered_labels (average ):
614
618
# test that labels need not be sorted in the multilabel case
615
619
y_true = np .array ([[1 , 1 , 0 , 0 ]])
616
620
y_pred = np .array ([[0 , 0 , 1 , 1 ]])
617
- for average in ['samples' , 'micro' , 'macro' , 'weighted' , None ]:
618
- p , r , f , s = precision_recall_fscore_support (
619
- y_true , y_pred , labels = [3 , 0 , 1 , 2 ], warn_for = [], average = average )
620
- assert_array_equal (p , 0 )
621
- assert_array_equal (r , 0 )
622
- assert_array_equal (f , 0 )
623
- if average is None :
624
- assert_array_equal (s , [0 , 1 , 1 , 0 ])
621
+ p , r , f , s = precision_recall_fscore_support (
622
+ y_true , y_pred , labels = [3 , 0 , 1 , 2 ], warn_for = [], average = average )
623
+ assert_array_equal (p , 0 )
624
+ assert_array_equal (r , 0 )
625
+ assert_array_equal (f , 0 )
626
+ if average is None :
627
+ assert_array_equal (s , [0 , 1 , 1 , 0 ])
625
628
626
629
627
630
def test_precision_recall_f1_score_binary_averaged ():
@@ -1207,7 +1210,9 @@ def test_precision_recall_f1_score_with_an_empty_prediction():
1207
1210
0.333 , 2 )
1208
1211
1209
1212
1210
- def test_precision_recall_f1_no_labels ():
1213
+ @pytest .mark .parametrize ('beta' , [1 ])
1214
+ @pytest .mark .parametrize ('average' , ["macro" , "micro" , "weighted" , "samples" ])
1215
+ def test_precision_recall_f1_no_labels (beta , average ):
1211
1216
y_true = np .zeros ((20 , 3 ))
1212
1217
y_pred = np .zeros_like (y_true )
1213
1218
@@ -1219,33 +1224,31 @@ def test_precision_recall_f1_no_labels():
1219
1224
# |y_i| = [0, 0, 0]
1220
1225
# |y_hat_i| = [0, 0, 0]
1221
1226
1222
- for beta in [1 ]:
1223
- p , r , f , s = assert_warns (UndefinedMetricWarning ,
1224
- precision_recall_fscore_support ,
1225
- y_true , y_pred , average = None , beta = beta )
1226
- assert_array_almost_equal (p , [0 , 0 , 0 ], 2 )
1227
- assert_array_almost_equal (r , [0 , 0 , 0 ], 2 )
1228
- assert_array_almost_equal (f , [0 , 0 , 0 ], 2 )
1229
- assert_array_almost_equal (s , [0 , 0 , 0 ], 2 )
1230
-
1231
- fbeta = assert_warns (UndefinedMetricWarning , fbeta_score ,
1232
- y_true , y_pred , beta = beta , average = None )
1233
- assert_array_almost_equal (fbeta , [0 , 0 , 0 ], 2 )
1234
-
1235
- for average in ["macro" , "micro" , "weighted" , "samples" ]:
1236
- p , r , f , s = assert_warns (UndefinedMetricWarning ,
1237
- precision_recall_fscore_support ,
1238
- y_true , y_pred , average = average ,
1239
- beta = beta )
1240
- assert_almost_equal (p , 0 )
1241
- assert_almost_equal (r , 0 )
1242
- assert_almost_equal (f , 0 )
1243
- assert_equal (s , None )
1244
-
1245
- fbeta = assert_warns (UndefinedMetricWarning , fbeta_score ,
1246
- y_true , y_pred ,
1247
- beta = beta , average = average )
1248
- assert_almost_equal (fbeta , 0 )
1227
+ p , r , f , s = assert_warns (UndefinedMetricWarning ,
1228
+ precision_recall_fscore_support ,
1229
+ y_true , y_pred , average = None , beta = beta )
1230
+ assert_array_almost_equal (p , [0 , 0 , 0 ], 2 )
1231
+ assert_array_almost_equal (r , [0 , 0 , 0 ], 2 )
1232
+ assert_array_almost_equal (f , [0 , 0 , 0 ], 2 )
1233
+ assert_array_almost_equal (s , [0 , 0 , 0 ], 2 )
1234
+
1235
+ fbeta = assert_warns (UndefinedMetricWarning , fbeta_score ,
1236
+ y_true , y_pred , beta = beta , average = None )
1237
+ assert_array_almost_equal (fbeta , [0 , 0 , 0 ], 2 )
1238
+
1239
+ p , r , f , s = assert_warns (UndefinedMetricWarning ,
1240
+ precision_recall_fscore_support ,
1241
+ y_true , y_pred , average = average ,
1242
+ beta = beta )
1243
+ assert_almost_equal (p , 0 )
1244
+ assert_almost_equal (r , 0 )
1245
+ assert_almost_equal (f , 0 )
1246
+ assert_equal (s , None )
1247
+
1248
+ fbeta = assert_warns (UndefinedMetricWarning , fbeta_score ,
1249
+ y_true , y_pred ,
1250
+ beta = beta , average = average )
1251
+ assert_almost_equal (fbeta , 0 )
1249
1252
1250
1253
1251
1254
def test_prf_warnings ():
0 commit comments