diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst
index e1a499c97506b..ee049937f5ce0 100644
--- a/doc/modules/array_api.rst
+++ b/doc/modules/array_api.rst
@@ -149,6 +149,7 @@ Metrics
 - :func:`sklearn.metrics.mean_squared_error`
 - :func:`sklearn.metrics.mean_squared_log_error`
 - :func:`sklearn.metrics.mean_tweedie_deviance`
+- :func:`sklearn.metrics.median_absolute_error`
 - :func:`sklearn.metrics.multilabel_confusion_matrix`
 - :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
 - :func:`sklearn.metrics.pairwise.chi2_kernel`
diff --git a/doc/whats_new/upcoming_changes/sklearn.metrics/31406.enhancement.rst b/doc/whats_new/upcoming_changes/sklearn.metrics/31406.enhancement.rst
new file mode 100644
index 0000000000000..4736c67c80132
--- /dev/null
+++ b/doc/whats_new/upcoming_changes/sklearn.metrics/31406.enhancement.rst
@@ -0,0 +1,2 @@
+- :func:`metrics.median_absolute_error` now supports Array API compatible inputs.
+  By :user:`Lucy Liu `.
diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py
index 0731e00ce3a1a..e7435756c52b2 100644
--- a/sklearn/metrics/_regression.py
+++ b/sklearn/metrics/_regression.py
@@ -19,6 +19,7 @@
 from ..utils._array_api import (
     _average,
     _find_matching_floating_dtype,
+    _median,
     get_namespace,
     get_namespace_and_device,
     size,
@@ -915,14 +916,15 @@ def median_absolute_error(
     >>> median_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
     0.85
     """
+    xp, _ = get_namespace(y_true, y_pred, multioutput, sample_weight)
     _, y_true, y_pred, sample_weight, multioutput = _check_reg_targets(
         y_true, y_pred, sample_weight, multioutput
     )
     if sample_weight is None:
-        output_errors = np.median(np.abs(y_pred - y_true), axis=0)
+        output_errors = _median(xp.abs(y_pred - y_true), axis=0)
     else:
         output_errors = _weighted_percentile(
-            np.abs(y_pred - y_true), sample_weight=sample_weight
+            xp.abs(y_pred - y_true), sample_weight=sample_weight
         )
     if isinstance(multioutput, str):
         if multioutput == "raw_values":
@@ -931,7 +933,7 @@ def median_absolute_error(
             # pass None as weights to np.average: uniform mean
             multioutput = None
 
-    return float(np.average(output_errors, weights=multioutput))
+    return float(_average(output_errors, weights=multioutput))
 
 
 def _assemble_r2_explained_variance(
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index bad71e29573b8..238ea821d8340 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -2231,6 +2231,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
         check_array_api_regression_metric,
         check_array_api_regression_metric_multioutput,
     ],
+    median_absolute_error: [
+        check_array_api_regression_metric,
+        check_array_api_regression_metric_multioutput,
+    ],
     d2_tweedie_score: [
         check_array_api_regression_metric,
     ],
@@ -2275,6 +2279,23 @@ def yield_metric_checker_combinations(metric_checkers=array_api_metric_checkers)
 )
@pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
 def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func):
+    # TODO: Remove once array-api-strict > 2.3.1 is released.
+    # https://github.com/data-apis/array-api-strict/issues/134 has been fixed but
+    # the fix is not released yet.
+    if (
+        getattr(metric, "__name__", None) == "median_absolute_error"
+        and array_namespace == "array_api_strict"
+    ):
+        try:
+            import array_api_strict
+        except ImportError:
+            pass
+        else:
+            if device == array_api_strict.Device("device1"):
+                pytest.xfail(
+                    "`_weighted_percentile` is affected by array_api_strict bug when "
+                    "indexing with tuple of arrays on non-'CPU_DEVICE' devices."
+                )
     check_func(metric, array_namespace, device, dtype_name)
diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py
index a9f35516f17b6..e2bee3530f26f 100644
--- a/sklearn/utils/_array_api.py
+++ b/sklearn/utils/_array_api.py
@@ -669,6 +669,30 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None):
     return sum_ / scale
 
 
+def _median(x, axis=None, keepdims=False, xp=None):
+    # XXX: `median` is not included in the array API spec, but is implemented
+    # in most array libraries, and all that we support (as of May 2025).
+    # TODO: consider simplifying this code to use scipy instead once the oldest
+    # supported SciPy version provides `scipy.stats.quantile` with native array API
+    # support (likely scipy 1.16 at the time of writing). Proper benchmarking of
+    # either option with popular array namespaces is required to evaluate the
+    # impact of this choice.
+    xp, _, device = get_namespace_and_device(x, xp=xp)
+
+    # `torch.median` takes the lower of the two medians when `x` has an even
+    # number of elements, thus we use `torch.quantile(q=0.5)`, which gives the
+    # mean of the two.
+    if array_api_compat.is_torch_namespace(xp):
+        return xp.quantile(x, q=0.5, dim=axis, keepdim=keepdims)
+
+    if hasattr(xp, "median"):
+        return xp.median(x, axis=axis, keepdims=keepdims)
+
+    # Intended mostly for array-api-strict (which has no "median", as per the spec),
+    # as `_convert_to_numpy` does not necessarily work for all array types.
+    x_np = _convert_to_numpy(x, xp=xp)
+    return xp.asarray(numpy.median(x_np, axis=axis, keepdims=keepdims), device=device)
+
+
 def _xlogy(x, y, xp=None):
     # TODO: Remove this once https://github.com/scipy/scipy/issues/21736 is fixed
     xp, _, device_ = get_namespace_and_device(x, y, xp=xp)
diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py
index 4dfbfd4d62ea1..4d74b0bf8db43 100644
--- a/sklearn/utils/tests/test_array_api.py
+++ b/sklearn/utils/tests/test_array_api.py
@@ -19,6 +19,7 @@
     _is_numpy_namespace,
     _isin,
     _max_precision_float_dtype,
+    _median,
     _nanmax,
     _nanmean,
     _nanmin,
@@ -603,3 +604,33 @@ def test_sparse_device(csr_container, dispatch):
     assert device(a, numpy.array([1])) is None
     assert get_namespace_and_device(a, b)[2] is None
     assert get_namespace_and_device(a, numpy.array([1]))[2] is None
+
+
+@pytest.mark.parametrize(
+    "namespace, device, dtype_name",
+    yield_namespace_device_dtype_combinations(),
+    ids=_get_namespace_device_dtype_ids,
+)
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_median(namespace, device, dtype_name, axis):
+    # Note: depending on the value of `axis`, this test compares median
+    # computations on arrays with an even (4) or odd (5) number of elements,
+    # hence it tests median computation both with and without interpolation,
+    # to check that array API namespaces yield consistent results even when
+    # the median is not mathematically uniquely defined.
+    xp = _array_api_for_tests(namespace, device)
+    rng = numpy.random.RandomState(0)
+
+    X_np = rng.uniform(low=0.0, high=1.0, size=(5, 4)).astype(dtype_name)
+    result_np = numpy.median(X_np, axis=axis)
+
+    X_xp = xp.asarray(X_np, device=device)
+    with config_context(array_api_dispatch=True):
+        result_xp = _median(X_xp, axis=axis)
+
+    if xp.__name__ != "array_api_strict":
+        # We convert array-api-strict arrays to numpy arrays as `median` is not
+        # part of the Array API spec.
+        assert get_namespace(result_xp)[0] == xp
+        assert result_xp.device == X_xp.device
+
+    assert_allclose(result_np, _convert_to_numpy(result_xp, xp=xp))
diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py
index 86bdd07c41f1c..d766ad16545da 100644
--- a/sklearn/utils/validation.py
+++ b/sklearn/utils/validation.py
@@ -18,7 +18,13 @@
 from .. import get_config as _get_config
 from ..exceptions import DataConversionWarning, NotFittedError, PositiveSpectrumWarning
-from ..utils._array_api import _asarray_with_order, _is_numpy_namespace, get_namespace
+from ..utils._array_api import (
+    _asarray_with_order,
+    _is_numpy_namespace,
+    _max_precision_float_dtype,
+    get_namespace,
+    get_namespace_and_device,
+)
 from ..utils.deprecation import _deprecate_force_all_finite
 from ..utils.fixes import ComplexWarning, _preserve_dia_indices_dtype
 from ._isfinite import FiniteStatus, cy_isfinite
@@ -390,7 +396,8 @@ def _num_samples(x):
 
     if not hasattr(x, "__len__") and not hasattr(x, "shape"):
         if hasattr(x, "__array__"):
-            x = np.asarray(x)
+            xp, _ = get_namespace(x)
+            x = xp.asarray(x)
         else:
             raise TypeError(message)
 
@@ -2167,12 +2174,16 @@ def _check_sample_weight(
     sample_weight : ndarray of shape (n_samples,)
         Validated sample weight. It is guaranteed to be "C" contiguous.
     """
-    n_samples = _num_samples(X)
+    xp, _, device = get_namespace_and_device(sample_weight, X)
 
-    xp, _ = get_namespace(X)
+    n_samples = _num_samples(X)
 
-    if dtype is not None and dtype not in [xp.float32, xp.float64]:
-        dtype = xp.float64
+    max_float_type = _max_precision_float_dtype(xp, device)
+    float_dtypes = (
+        [xp.float32] if max_float_type == xp.float32 else [xp.float64, xp.float32]
+    )
+    if dtype is not None and dtype not in float_dtypes:
+        dtype = max_float_type
 
     if sample_weight is None:
         sample_weight = xp.ones(n_samples, dtype=dtype)
@@ -2180,7 +2191,7 @@
         sample_weight = xp.full(n_samples, sample_weight, dtype=dtype)
     else:
         if dtype is None:
-            dtype = [xp.float64, xp.float32]
+            dtype = float_dtypes
         sample_weight = check_array(
             sample_weight,
             accept_sparse=False,
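
For reviewers, a minimal usage sketch (not part of the diff) of what this change enables: calling median_absolute_error on PyTorch tensors with array API dispatching turned on. It assumes torch and array-api-compat are installed and that scikit-learn includes this patch; the input values are borrowed from the function's docstring example, and the printed 0.5 is illustrative rather than a doctest.

import numpy as np
import torch

from sklearn import config_context
from sklearn.metrics import median_absolute_error

# Same values as the `median_absolute_error` docstring example.
y_true_np = np.asarray([3.0, -0.5, 2.0, 7.0])
y_pred_np = np.asarray([2.5, 0.0, 2.0, 8.0])

# Move the inputs into a non-NumPy namespace (here: CPU torch tensors).
y_true_xp = torch.asarray(y_true_np)
y_pred_xp = torch.asarray(y_pred_np)

with config_context(array_api_dispatch=True):
    # With this patch, the absolute errors, their median (via `_median`) and
    # the final aggregation (via `_average`) stay in the torch namespace
    # instead of being silently converted to NumPy.
    result = median_absolute_error(y_true_xp, y_pred_xp)

# The metric still returns a plain Python float, matching the NumPy path.
print(result)  # 0.5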