8000 MNT Use regression data for `check_sample_weight_invariance` test on … · scikit-learn/scikit-learn@ebc1276 · GitHub
[go: up one dir, main page]

Skip to content

Commit ebc1276

Browse files
authored
MNT Use regression data for check_sample_weight_invariance test on multioutput regression metrics (#30829)
1 parent 2b97ac5 commit ebc1276

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

sklearn/metrics/tests/test_common.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1611,7 +1611,7 @@ def test_multiclass_sample_weight_invariance(name):
16111611
@pytest.mark.parametrize(
16121612
"name",
16131613
sorted(
1614-
(MULTILABELS_METRICS | THRESHOLDED_MULTILABEL_METRICS | MULTIOUTPUT_METRICS)
1614+
(MULTILABELS_METRICS | THRESHOLDED_MULTILABEL_METRICS)
16151615
- METRICS_WITHOUT_SAMPLE_WEIGHT
16161616
),
16171617
)
@@ -1638,6 +1638,19 @@ def test_multilabel_sample_weight_invariance(name):
16381638
check_sample_weight_invariance(name, metric, y_true, y_pred)
16391639

16401640

1641+
@pytest.mark.parametrize(
1642+
"name",
1643+
sorted(MULTIOUTPUT_METRICS - METRICS_WITHOUT_SAMPLE_WEIGHT),
1644+
)
1645+
def test_multioutput_sample_weight_invariance(name):
1646+
random_state = check_random_state(0)
1647+
y_true = random_state.uniform(0, 2, size=(20, 5))
1648+
y_pred = random_state.uniform(0, 2, size=(20, 5))
1649+
1650+
metric = ALL_METRICS[name]
1651+
check_sample_weight_invariance(name, metric, y_true, y_pred)
1652+
1653+
16411654
def test_no_averaging_labels():
16421655
# test labels argument when not using averaging
16431656
# in multi-class and multi-label cases

0 commit comments

Comments
 (0)
0