23
23
from sklearn .utils .testing import assert_raise_message
24
24
from sklearn .utils .testing import assert_true
25
25
from sklearn .utils .testing import ignore_warnings
26
+ from sklearn .utils .testing import _named_check
26
27
27
28
from sklearn .metrics import accuracy_score
28
29
from sklearn .metrics import balanced_accuracy_score
@@ -894,8 +895,8 @@ def test_averaging_multiclass(n_samples=50, n_classes=3):
894
895
y_pred_binarize = lb .transform (y_pred )
895
896
896
897
for name in METRICS_WITH_AVERAGING :
897
- yield (check_averaging , name , y_true , y_true_binarize ,
898
- y_pred , y_pred_binarize , y_score )
898
+ yield (_named_check ( check_averaging , name ), name , y_true ,
899
+ y_true_binarize , y_pred , y_pred_binarize , y_score )
899
900
900
901
901
902
def test_averaging_multilabel (n_classes = 5 , n_samples = 40 ):
@@ -909,8 +910,8 @@ def test_averaging_multilabel(n_classes=5, n_samples=40):
909
910
y_pred_binarize = y_pred
910
911
911
912
for name in METRICS_WITH_AVERAGING + THRESHOLDED_METRICS_WITH_AVERAGING :
912
- yield (check_averaging , name , y_true , y_true_binarize ,
913
- y_pred , y_pred_binarize , y_score )
913
+ yield (_named_check ( check_averaging , name ), name , y_true ,
914
+ y_true_binarize , y_pred , y_pred_binarize , y_score )
914
915
915
916
916
917
def test_averaging_multilabel_all_zeroes ():
@@ -921,8 +922,8 @@ def test_averaging_multilabel_all_zeroes():
921
922
y_pred_binarize = y_pred
922
923
923
924
for name in METRICS_WITH_AVERAGING :
924
- yield (check_averaging , name , y_true , y_true_binarize ,
925
- y_pred , y_pred_binarize , y_score )
925
+ yield (_named_check ( check_averaging , name ), name , y_true ,
926
+ y_true_binarize , y_pred , y_pred_binarize , y_score )
926
927
927
928
# Test _average_binary_score for weight.sum() == 0
928
929
binary_metric = (lambda y_true , y_score , average = "macro" :
@@ -940,8 +941,8 @@ def test_averaging_multilabel_all_ones():
940
941
y_pred_binarize = y_pred
941
942
942
943
for name in METRICS_WITH_AVERAGING :
943
- yield (check_averaging , name , y_true , y_true_binarize ,
944
- y_pred , y_pred_binarize , y_score )
944
+ yield (_named_check ( check_averaging , name ), name , y_true ,
945
+ y_true_binarize , y_pred , y_pred_binarize , y_score )
945
946
946
947
947
948
@ignore_warnings
@@ -1030,7 +1031,8 @@ def test_sample_weight_invariance(n_samples=50):
1030
1031
if name in METRICS_WITHOUT_SAMPLE_WEIGHT :
1031
1032
continue
1032
1033
metric = ALL_METRICS [name ]
1033
- yield check_sample_weight_invariance , name , metric , y_true , y_pred
1034
+ yield _named_check (check_sample_weight_invariance , name ), name ,\
1035
+ metric , y_true , y_pred
1034
1036
1035
1037
# binary
1036
1038
random_state = check_random_state (0 )
@@ -1045,9 +1047,11 @@ def test_sample_weight_invariance(n_samples=50):
1045
1047
continue
1046
1048
metric = ALL_METRICS [name ]
1047
1049
if name in THRESHOLDED_METRICS :
1048
- yield check_sample_weight_invariance , name , metric , y_true , y_score
1050
+ yield _named_check (check_sample_weight_invariance , name ), name ,\
1051
+ metric , y_true , y_score
1049
1052
else :
1050
- yield check_sample_weight_invariance , name , metric , y_true , y_pred
1053
+ yield _named_check (check_sample_weight_invariance , name ), name ,\
1054
+ metric , y_true , y_pred
1051
1055
1052
1056
# multiclass
1053
1057
random_state = check_random_state (0 )
@@ -1062,9 +1066,11 @@ def test_sample_weight_invariance(n_samples=50):
1062
1066
continue
1063
1067
metric = ALL_METRICS [name ]
1064
1068
if name in THRESHOLDED_METRICS :
1065
- yield check_sample_weight_invariance , name , metric , y_true , y_score
1069
+ yield _named_check (check_sample_weight_invariance , name ), name ,\
1070
+ metric , y_true , y_score
1066
1071
else :
1067
- yield check_sample_weight_invariance , name , metric , y_true , y_pred
1072
+ yield _named_check (check_sample_weight_invariance , name ), name ,\
1073
+ metric , y_true , y_pred
1068
1074
1069
1075
# multilabel indicator
1070
1076
_ , ya = make_multilabel_classification (n_features = 1 , n_classes = 20 ,
@@ -1084,11 +1090,11 @@ def test_sample_weight_invariance(n_samples=50):
1084
1090
1085
1091
metric = ALL_METRICS [name ]
1086
1092
if name in THRESHOLDED_METRICS :
1087
- yield (check_sample_weight_invariance , name , metric ,
1088
- y_true , y_score )
1093
+ yield (_named_check ( check_sample_weight_invariance , name ), name ,
1094
+ metric , y_true , y_score )
1089
1095
else :
1090
- yield (check_sample_weight_invariance , name , metric ,
1091
- y_true , y_pred )
1096
+ yield (_named_check ( check_sample_weight_invariance , name ), name ,
1097
+ metric , y_true , y_pred )
1092
1098
1093
1099
1094
1100
@ignore_warnings
0 commit comments