diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index 65a3073f3691c..40b514835a1bf 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -907,9 +907,9 @@ def median_absolute_error( >>> median_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7]) 0.85 """ - y_type, y_true, y_pred, multioutput = _check_reg_targets( - y_true, y_pred, multioutput - ) + _, y_true, y_pred, multioutput = _check_reg_targets(y_true, y_pred, multioutput) + check_consistent_length(y_true, y_pred, sample_weight) + if sample_weight is None: output_errors = np.median(np.abs(y_pred - y_true), axis=0) else: diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 9e8d0ce116394..b91cb7c9a11e5 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -552,7 +552,6 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): # No Sample weight support METRICS_WITHOUT_SAMPLE_WEIGHT = { - "median_absolute_error", "max_error", "ovo_roc_auc", "weighted_ovo_roc_auc", @@ -1445,9 +1444,10 @@ def test_averaging_multilabel_all_ones(name): check_averaging(name, y_true, y_true_binarize, y_pred, y_pred_binarize, y_score) -def check_sample_weight_invariance(name, metric, y1, y2): +def check_sample_weight_invariance(name, metric, y1, y2, sample_weight=None): rng = np.random.RandomState(0) - sample_weight = rng.randint(1, 10, size=len(y1)) + if sample_weight is None: + sample_weight = rng.randint(1, 10, size=len(y1)) # top_k_accuracy_score always lead to a perfect score for k > 1 in the # binary case @@ -1550,13 +1550,15 @@ def check_sample_weight_invariance(name, metric, y1, y2): ), ) def test_regression_sample_weight_invariance(name): - n_samples = 50 + n_samples = 51 random_state = check_random_state(0) # regression y_true = random_state.random_sample(size=(n_samples,)) y_pred = random_state.random_sample(size=(n_samples,)) + sample_weight = np.arange(len(y_true)) metric = ALL_METRICS[name] - check_sample_weight_invariance(name, metric, y_true, y_pred) + + check_sample_weight_invariance(name, metric, y_true, y_pred, sample_weight) @pytest.mark.parametrize(