@@ -1013,13 +1013,26 @@ def check_sample_weight_invariance(name, metric, y1, y2):
1013
1013
1014
1014
def test_sample_weight_invariance (n_samples = 50 ):
1015
1015
random_state = check_random_state (0 )
1016
+ # regression
1017
+ y_true = random_state .random_sample (size = (n_samples ,))
1018
+ y_pred = random_state .random_sample (size = (n_samples ,))
1019
+ for name in ALL_METRICS :
1020
+ if name not in REGRESSION_METRICS :
1021
+ continue
1022
+ if name in METRICS_WITHOUT_SAMPLE_WEIGHT :
1023
+ continue
1024
+ metric = ALL_METRICS [name ]
1025
+ yield _named_check (check_sample_weight_invariance , name ), name ,\
1026
+ metric , y_true , y_pred
1016
1027
1017
1028
# binary
1018
1029
random_state = check_random_state (0 )
1019
1030
y_true = random_state .randint (0 , 2 , size<
97EC
/span>= (n_samples , ))
1020
1031
y_pred = random_state .randint (0 , 2 , size = (n_samples , ))
1021
1032
y_score = random_state .random_sample (size = (n_samples ,))
1022
1033
for name in ALL_METRICS :
1034
+ if name in REGRESSION_METRICS :
1035
+ continue
1023
1036
if (name in METRICS_WITHOUT_SAMPLE_WEIGHT or
1024
1037
name in METRIC_UNDEFINED_BINARY ):
1025
1038
continue
@@ -1037,6 +1050,8 @@ def test_sample_weight_invariance(n_samples=50):
1037
1050
y_pred = random_state .randint (0 , 5 , size = (n_samples , ))
1038
1051
y_score = random_state .random_sample (size = (n_samples , 5 ))
1039
1052
for name in ALL_METRICS :
1053
+ if name in REGRESSION_METRICS :
1054
+ continue
1040
1055
if (name in METRICS_WITHOUT_SAMPLE_WEIGHT or
1041
1056
name in METRIC_UNDEFINED_BINARY_MULTICLASS ):
1042
1057
continue
0 commit comments