352
352
"samples_precision_score" , "samples_recall_score" ,
353
353
]
354
354
355
+ MULTILABEL_INDICATOR_METRICS_WITH_SAMPLE_WEIGHT = [
356
+ "average_precision_score" ,
357
+ "weighted_average_precision_score" ,
358
+ "micro_average_precision_score" ,
359
+ "macro_average_precision_score" ,
360
+ "samples_average_precision_score" ,
361
+ ]
362
+
355
363
# Regression metrics that support multioutput and weighted samples
356
364
MULTIOUTPUT_METRICS_WITH_SAMPLE_WEIGHT = [
357
365
"mean_squared_error" ,
@@ -2565,7 +2573,7 @@ def test_averaging_multilabel_all_ones():
2565
2573
@ignore_warnings
2566
2574
def check_sample_weight_invariance (name , metric , y1 , y2 ):
2567
2575
rng = np .random .RandomState (0 )
2568
- sample_weight = rng .randint (10 , size = len (y1 ))
2576
+ sample_weight = rng .randint (1 , 10 , size = len (y1 ))
2569
2577
2570
2578
# check that unit weights gives the same score as no weight
2571
2579
unweighted_score = metric (y1 , y2 , sample_weight = None )
@@ -2591,14 +2599,13 @@ def check_sample_weight_invariance(name, metric, y1, y2):
2591
2599
"not equal (%f != %f) for %s" % (
2592
2600
weighted_score , weighted_score_list , name ))
2593
2601
2594
- if not name .startswith ('samples' ):
2595
- # check that integer weights is the same as repeated samples
2596
- repeat_weighted_score = metric (
2597
- np .repeat (y1 , sample_weight , axis = 0 ),
2598
- np .repeat (y2 , sample_weight , axis = 0 ), sample_weight = None )
2599
- assert_almost_equal (
2600
- weighted_score , repeat_weighted_score ,
2601
- err_msg = "Weighting %s is not equal to repeating samples" % name )
2602
+ # check that integer weights is the same as repeated samples
2603
+ repeat_weighted_score = metric (
2604
+ np .repeat (y1 , sample_weight , axis = 0 ),
2605
+ np .repeat (y2 , sample_weight , axis = 0 ), sample_weight = None )
2606
+ assert_almost_equal (
2607
+ weighted_score , repeat_weighted_score ,
2608
+ err_msg = "Weighting %s is not equal to repeating samples" % name )
2602
2609
2603
2610
if not name .startswith ('unnormalized' ):
2604
2611
# check that the score is invariant under scaling of the weights by a
@@ -2612,33 +2619,34 @@ def check_sample_weight_invariance(name, metric, y1, y2):
2612
2619
2613
2620
2614
2621
def test_sample_weight_invariance ():
2615
- # generate some data
2622
+ # binary
2616
2623
y1 , y2 , _ = make_prediction (binary = True )
2617
-
2618
2624
for name in METRICS_WITH_SAMPLE_WEIGHT :
2619
2625
metric = ALL_METRICS [name ]
2620
2626
yield check_sample_weight_invariance , name , metric , y1 , y2
2621
2627
2622
- # multilabel
2628
+ # multilabel sequence
2623
2629
n_classes = 3
2624
2630
n_samples = 10
2625
- _ , y1_multilabel = make_multilabel_classification (
2631
+ _ , y1 = make_multilabel_classification (
2626
2632
n_features = 1 , n_classes = n_classes ,
2627
2633
random_state = 0 , n_samples = n_samples )
2628
- _ , y2_multilabel = make_multilabel_classification (
2634
+ _ , y2 = make_multilabel_classification (
2629
2635
n_features = 1 , n_classes = n_classes ,
2630
2636
random_state = 1 , n_samples = n_samples )
2631
-
2632
2637
for name in MULTILABEL_METRICS_WITH_SAMPLE_WEIGHT :
2633
2638
metric = ALL_METRICS [name ]
2634
- yield (check_sample_weight_invariance ,
2635
- name , metric , y1_multilabel , y2_multilabel )
2639
+ yield (check_sample_weight_invariance , name , metric , y1 , y2 )
2636
2640
2637
- # multioutput
2638
- y1_multioutput = np .array ([[1 , 0 , 0 , 1 ], [0 , 1 , 1 , 1 ], [1 , 1 , 0 , 1 ]])
2639
- y2_multioutput = np .array ([[0 , 0 , 1 , 1 ], [1 , 0 , 1 , 1 ], [1 , 1 , 0 , 1 ]])
2641
+ # multilabel indicator
2642
+ y1 = np .array ([[1 , 0 , 0 , 1 ], [0 , 1 , 1 , 1 ], [1 , 1 , 0 , 1 ]])
2643
+ y2 = np .array ([[0 , 0 , 1 , 1 ], [1 , 0 , 1 , 1 ], [1 , 1 , 0 , 1 ]])
2644
+ for name in MULTILABEL_INDICATOR_METRICS_WITH_SAMPLE_WEIGHT :
2645
+ metric = ALL_METRICS [name ]
2646
+ yield (check_sample_weight_invariance , name , metric , y1 , y2 )
2640
2647
2648
+ # multioutput
2641
2649
for name in MULTIOUTPUT_METRICS_WITH_SAMPLE_WEIGHT :
2642
2650
metric = ALL_METRICS [name ]
2643
2651
yield (check_sample_weight_invariance ,
2644
- name , metric , y1_multioutput , y2_multioutput )
2652
+ name , metric , y1 , y2 )
0 commit comments