@@ -187,26 +187,41 @@
 # When you add a new metric or functionality, check if a general test
 # is already written.
 
-# Metric undefined with "binary" or "multiclass" input
-METRIC_UNDEFINED_MULTICLASS = [
-    "samples_f0.5_score", "samples_f1_score", "samples_f2_score",
-    "samples_precision_score", "samples_recall_score",
+# Those metrics don't support binary inputs
+METRIC_UNDEFINED_BINARY = [
+    "samples_f0.5_score",
+    "samples_f1_score",
+    "samples_f2_score",
+    "samples_precision_score",
+    "samples_recall_score",
+    "coverage_error",
 
-    # Those metrics don't support multiclass outputs
-    "average_precision_score", "weighted_average_precision_score",
-    "micro_average_precision_score", "macro_average_precision_score",
+    "roc_auc_score",
+    "micro_roc_auc",
+    "weighted_roc_auc",
+    "macro_roc_auc",
+    "samples_roc_auc",
+
+    "average_precision_score",
+    "weighted_average_precision_score",
+    "micro_average_precision_score",
+    "macro_average_precision_score",
     "samples_average_precision_score",
 
+    "label_ranking_loss",
     "label_ranking_average_precision_score",
+]
 
-    "roc_auc_score", "micro_roc_auc", "weighted_roc_auc",
-    "macro_roc_auc", "samples_roc_auc",
-
-    "coverage_error",
+# Those metrics don't support multiclass inputs
+METRIC_UNDEFINED_MULTICLASS = [
     "brier_score_loss",
-    "label_ranking_loss",
+    "matthews_corrcoef_score",
 ]
 
+# Metric undefined with "binary" or "multiclass" input
+METRIC_UNDEFINED_BINARY_MULTICLASS = set(METRIC_UNDEFINED_BINARY).union(
+    set(METRIC_UNDEFINED_MULTICLASS))
+
 # Metrics with an "average" argument
 METRICS_WITH_AVERAGING = [
     "precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score"
@@ -346,7 +361,6 @@
 METRICS_WITHOUT_SAMPLE_WEIGHT = [
     "cohen_kappa_score",
     "confusion_matrix",
-    "matthews_corrcoef_score",
     "median_absolute_error",
 ]
 
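The union above is what the generic tests below key off: a metric listed in either group is skipped whenever a test feeds it binary or multiclass labels. Below is a minimal sketch of that skip pattern; the tiny ALL_METRICS registry and the metric lists are illustrative stand-ins, not the module's actual dictionaries.

```python
# Sketch only: made-up stand-ins for the test module's registries.
from sklearn.metrics import f1_score, coverage_error

ALL_METRICS = {
    "f1_score": f1_score,              # defined for binary labels
    "coverage_error": coverage_error,  # multilabel-only, hence "undefined for binary"
}
METRIC_UNDEFINED_BINARY = ["coverage_error"]
METRIC_UNDEFINED_MULTICLASS = ["brier_score_loss"]

# Same construction as in the diff: one set covering both groups.
METRIC_UNDEFINED_BINARY_MULTICLASS = set(METRIC_UNDEFINED_BINARY).union(
    set(METRIC_UNDEFINED_MULTICLASS))

y_true, y_pred = [0, 1, 1, 0], [0, 1, 0, 0]
for name, metric in ALL_METRICS.items():
    if name in METRIC_UNDEFINED_BINARY_MULTICLASS:
        continue  # coverage_error is skipped on this binary toy problem
    print(name, metric(y_true, y_pred))  # only f1_score runs
```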
@@ -359,10 +373,9 @@ def test_symmetry():
     y_pred = random_state.randint(0, 2, size=(20, ))
 
     # We shouldn't forget any metrics
-    assert_equal(set(SYMMETRIC_METRICS).union(NOT_SYMMETRIC_METRICS,
-                                              THRESHOLDED_METRICS,
-                                              METRIC_UNDEFINED_MULTICLASS),
-                 set(ALL_METRICS))
+    assert_equal(set(SYMMETRIC_METRICS).union(
+        NOT_SYMMETRIC_METRICS, THRESHOLDED_METRICS,
+        METRIC_UNDEFINED_BINARY_MULTICLASS), set(ALL_METRICS))
 
     assert_equal(
         set(SYMMETRIC_METRICS).intersection(set(NOT_SYMMETRIC_METRICS)),
@@ -390,7 +403,7 @@ def test_sample_order_invariance():
     y_true_shuffle, y_pred_shuffle = shuffle(y_true, y_pred, random_state=0)
 
     for name, metric in ALL_METRICS.items():
-        if name in METRIC_UNDEFINED_MULTICLASS:
+        if name in METRIC_UNDEFINED_BINARY_MULTICLASS:
             continue
 
         assert_almost_equal(metric(y_true, y_pred),
@@ -457,7 +470,7 @@ def test_format_invariance_with_1d_vectors():
     y2_row = np.reshape(y2_1d, (1, -1))
 
     for name, metric in ALL_METRICS.items():
-        if name in METRIC_UNDEFINED_MULTICLASS:
+        if name in METRIC_UNDEFINED_BINARY_MULTICLASS:
             continue
 
         measure = metric(y1, y2)
@@ -532,7 +545,7 @@ def test_invariance_string_vs_numbers_labels():
     labels_str = ["eggs", "spam"]
 
     for name, metric in CLASSIFICATION_METRICS.items():
-        if name in METRIC_UNDEFINED_MULTICLASS:
+        if name in METRIC_UNDEFINED_BINARY_MULTICLASS:
             continue
 
         measure_with_number = metric(y1, y2)
@@ -613,7 +626,8 @@ def check_single_sample_multioutput(name):
 
 def test_single_sample():
     for name in ALL_METRICS:
-        if name in METRIC_UNDEFINED_MULTICLASS or name in THRESHOLDED_METRICS:
+        if (name in METRIC_UNDEFINED_BINARY_MULTICLASS or
+                name in THRESHOLDED_METRICS):
             # Those metrics are not always defined with one sample
             # or in multiclass classification
             continue
@@ -915,9 +929,9 @@ def check_sample_weight_invariance(name, metric, y1, y2):
                                  sample_weight=sample_weight.tolist())
     assert_almost_equal(
         weighted_score, weighted_score_list,
-        err_msg="Weighted scores for array and list sample_weight input are "
-                " not equal (%f != %f) for %s" % (
-                weighted_score, weighted_score_list, name))
+        err_msg=("Weighted scores for array and list "
+                 "sample_weight input are not equal (%f != %f) for %s") % (
+                     weighted_score, weighted_score_list, name))
 
     # check that integer weights is the same as repeated samples
     repeat_weighted_score = metric(
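The rewritten err_msg relies on Python's compile-time concatenation of adjacent string literals; the added parentheses make the grouping explicit (and drop the stray double space), while the % formatting still applies to the full message. A quick standalone illustration with placeholder values, not taken from the tests:

```python
# Placeholder values, just to show how the pieces combine.
weighted_score, weighted_score_list, name = 0.5, 0.25, "f1_score"
msg = ("Weighted scores for array and list "
       "sample_weight input are not equal (%f != %f) for %s") % (
           weighted_score, weighted_score_list, name)
print(msg)
# Weighted scores for array and list sample_weight input are not equal
# (0.500000 != 0.250000) for f1_score
```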
@@ -963,14 +977,14 @@ def check_sample_weight_invariance(name, metric, y1, y2):
 def test_sample_weight_invariance(n_samples=50):
     random_state = check_random_state(0)
 
-    # binary output
+    # binary
     random_state = check_random_state(0)
     y_true = random_state.randint(0, 2, size=(n_samples, ))
     y_pred = random_state.randint(0, 2, size=(n_samples, ))
     y_score = random_state.random_sample(size=(n_samples,))
     for name in ALL_METRICS:
         if (name in METRICS_WITHOUT_SAMPLE_WEIGHT or
-                name in METRIC_UNDEFINED_MULTICLASS):
+                name in METRIC_UNDEFINED_BINARY):
             continue
         metric = ALL_METRICS[name]
         if name in THRESHOLDED_METRICS:
@@ -985,7 +999,7 @@ def test_sample_weight_invariance(n_samples=50):
     y_score = random_state.random_sample(size=(n_samples, 5))
     for name in ALL_METRICS:
         if (name in METRICS_WITHOUT_SAMPLE_WEIGHT or
-                name in METRIC_UNDEFINED_MULTICLASS):
+                name in METRIC_UNDEFINED_BINARY_MULTICLASS):
             continue
         metric = ALL_METRICS[name]
         if name in THRESHOLDED_METRICS: