@@ -1444,9 +1444,10 @@ def test_averaging_multilabel_all_ones(name):
1444
1444
check_averaging(name, y_true, y_true_binarize, y_pred, y_pred_binarize, y_score)
1445
1445
1446
1446
1447
- def check_sample_weight_invariance(name, metric, y1, y2):
1447
+ def check_sample_weight_invariance(name, metric, y1, y2, sample_weight=None ):
1448
1448
rng = np.random.RandomState(0)
1449
- sample_weight = rng.randint(1, 10, size=len(y1))
1449
+ if sample_weight is None:
1450
+ sample_weight = rng.randint(1, 10, size=len(y1))
1450
1451
1451
1452
# top_k_accuracy_score always lead to a perfect score for k > 1 in the
1452
1453
# binary case
@@ -1550,13 +1551,14 @@ def check_sample_weight_invariance(name, metric, y1, y2):
1550
1551
)
1551
1552
def test_regression_sample_weight_invariance(name):
1552
1553
n_samples = 51
1553
- random_state = check_random_state(1 )
1554
+ random_state = check_random_state(0 )
1554
1555
# regression
1555
1556
y_true = random_state.random_sample(size=(n_samples,))
1556
1557
y_pred = random_state.random_sample(size=(n_samples,))
1558
+ sample_weight = np.arange(len(y_true))
1557
1559
metric = ALL_METRICS[name]
1558
1560
1559
- check_sample_weight_invariance(name, metric, y_true, y_pred)
1561
+ check_sample_weight_invariance(name, metric, y_true, y_pred, sample_weight )
1560
1562
1561
1563
1562
1564
@pytest.mark.parametrize(
0 commit comments