8000 use arange sample weight · scikit-learn/scikit-learn@92884bc · GitHub
[go: up one dir, main page]

Skip to content

Commit 92884bc

Browse files
committed
use arange sample weight
1 parent 967212c commit 92884bc

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

sklearn/metrics/tests/test_common.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,9 +1444,10 @@ def test_averaging_multilabel_all_ones(name):
14441444
check_averaging(name, y_true, y_true_binarize, y_pred, y_pred_binarize, y_score)
14451445

14461446

1447-
def check_sample_weight_invariance(name, metric, y1, y2):
1447+
def check_sample_weight_invariance(name, metric, y1, y2, sample_weight=None):
14481448
rng = np.random.RandomState(0)
1449-
sample_weight = rng.randint(1, 10, size=len(y1))
1449+
if sample_weight is None:
1450+
sample_weight = rng.randint(1, 10, size=len(y1))
14501451

14511452
# top_k_accuracy_score always lead to a perfect score for k > 1 in the
14521453
# binary case
@@ -1550,13 +1551,14 @@ def check_sample_weight_invariance(name, metric, y1, y2):
15501551
)
15511552
def test_regression_sample_weight_invariance(name):
15521553
n_samples = 51
1553-
random_state = check_random_state(1)
1554+
random_state = check_random_state(0)
15541555
# regression
15551556
y_true = random_state.random_sample(size=(n_samples,))
15561557
y_pred = random_state.random_sample(size=(n_samples,))
1558+
sample_weight = np.arange(len(y_true))
15571559
metric = ALL_METRICS[name]
15581560

1559-
check_sample_weight_invariance(name, metric, y_true, y_pred)
1561+
check_sample_weight_invariance(name, metric, y_true, y_pred, sample_weight)
15601562

15611563

15621564
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)
0