diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 9e8d0ce116394..5f44e7b212105 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1611,7 +1611,7 @@ def test_multiclass_sample_weight_invariance(name):
 @pytest.mark.parametrize(
     "name",
     sorted(
-        (MULTILABELS_METRICS | THRESHOLDED_MULTILABEL_METRICS | MULTIOUTPUT_METRICS)
+        (MULTILABELS_METRICS | THRESHOLDED_MULTILABEL_METRICS)
         - METRICS_WITHOUT_SAMPLE_WEIGHT
     ),
 )
@@ -1638,6 +1638,19 @@ def test_multilabel_sample_weight_invariance(name):
     check_sample_weight_invariance(name, metric, y_true, y_pred)
 
 
+@pytest.mark.parametrize(
+    "name",
+    sorted(MULTIOUTPUT_METRICS - METRICS_WITHOUT_SAMPLE_WEIGHT),
+)
+def test_multioutput_sample_weight_invariance(name):
+    random_state = check_random_state(0)
+    y_true = random_state.uniform(0, 2, size=(20, 5))
+    y_pred = random_state.uniform(0, 2, size=(20, 5))
+
+    metric = ALL_METRICS[name]
+    check_sample_weight_invariance(name, metric, y_true, y_pred)
+
+
 def test_no_averaging_labels():
     # test labels argument when not using averaging
     # in multi-class and multi-label cases