Add array API support to `median_absolute_error` by lucyleeow · Pull Request #31406 · scikit-learn/scikit-learn

Add array API support to median_absolute_error #31406


Merged · 28 commits · Jun 3, 2025
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
@@ -149,6 +149,7 @@ Metrics
- :func:`sklearn.metrics.mean_squared_error`
- :func:`sklearn.metrics.mean_squared_log_error`
- :func:`sklearn.metrics.mean_tweedie_deviance`
- :func:`sklearn.metrics.median_absolute_error`
- :func:`sklearn.metrics.multilabel_confusion_matrix`
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
- :func:`sklearn.metrics.pairwise.chi2_kernel`
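For context, scikit-learn only dispatches these metrics on array API inputs when the feature is explicitly enabled. A minimal sketch of turning on dispatch, assuming the optional `array-api-compat` package is installed:

```python
import sklearn

# Enable array API dispatch globally ...
sklearn.set_config(array_api_dispatch=True)

# ... or only within a scoped block.
with sklearn.config_context(array_api_dispatch=True):
    ...  # calls to supported metrics dispatch on the namespace of their inputs
```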
@@ -0,0 +1,2 @@
- :func:`metrics.median_absolute_error` now supports Array API compatible inputs.
By :user:`Lucy Liu <lucyleeow>`.
8 changes: 5 additions & 3 deletions sklearn/metrics/_regression.py
@@ -19,6 +19,7 @@
from ..utils._array_api import (
_average,
_find_matching_floating_dtype,
_median,
get_namespace,
get_namespace_and_device,
size,
@@ -915,14 +916,15 @@ def median_absolute_error(
>>> median_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
0.85
"""
xp, _ = get_namespace(y_true, y_pred, multioutput, sample_weight)
_, y_true, y_pred, sample_weight, multioutput = _check_reg_targets(
y_true, y_pred, sample_weight, multioutput
)
if sample_weight is None:
output_errors = np.median(np.abs(y_pred - y_true), axis=0)
output_errors = _median(xp.abs(y_pred - y_true), axis=0)
else:
output_errors = _weighted_percentile(
np.abs(y_pred - y_true), sample_weight=sample_weight
xp.abs(y_pred - y_true), sample_weight=sample_weight
)
if isinstance(multioutput, str):
if multioutput == "raw_values":
@@ -931,7 +933,7 @@
# pass None as weights to np.average: uniform mean
multioutput = None

return float(np.average(output_errors, weights=multioutput))
return float(_average(output_errors, weights=multioutput))


def _assemble_r2_explained_variance(
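With the namespace-agnostic helpers (`xp.abs`, `_median`, `_average`) in place, `median_absolute_error` can consume inputs from array API compatible libraries. A hedged usage sketch, assuming PyTorch and `array-api-compat` are installed:

```python
import torch

from sklearn import config_context
from sklearn.metrics import median_absolute_error

y_true = torch.asarray([3.0, -0.5, 2.0, 7.0])
y_pred = torch.asarray([2.5, 0.0, 2.0, 8.0])

with config_context(array_api_dispatch=True):
    # The intermediate computations run in the torch namespace; the result is
    # still returned as a plain Python float, as before.
    print(median_absolute_error(y_true, y_pred))  # 0.5
```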
21 changes: 21 additions & 0 deletions sklearn/metrics/tests/test_common.py
@@ -2231,6 +2231,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
median_absolute_error: [
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
d2_tweedie_score: [
check_array_api_regression_metric,
],
@@ -2275,6 +2279,23 @@ def yield_metric_checker_combinations(metric_checkers=array_api_metric_checkers)
)
@pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func):
    # TODO: Remove once array-api-strict > 2.3.1 is released.
    # https://github.com/data-apis/array-api-strict/issues/134 has been fixed but
    # not released yet.
if (
getattr(metric, "__name__", None) == "median_absolute_error"
and array_namespace == "array_api_strict"
):
try:
import array_api_strict
except ImportError:
pass
else:
if device == array_api_strict.Device("device1"):
pytest.xfail(
"`_weighted_percentile` is affected by array_api_strict bug when "
"indexing with tuple of arrays on non-'CPU_DEVICE' devices."
)
check_func(metric, array_namespace, device, dtype_name)


24 changes: 24 additions & 0 deletions sklearn/utils/_array_api.py
@@ -669,6 +669,30 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None):
return sum_ / scale


def _median(x, axis=None, keepdims=False, xp=None):
# XXX: `median` is not included in the array API spec, but is implemented
# in most array libraries, and all that we support (as of May 2025).
# TODO: consider simplifying this code to use scipy instead once the oldest
# supported SciPy version provides `scipy.stats.quantile` with native array API
    # support (likely scipy 1.16 at the time of writing). Proper benchmarking of
# either option with popular array namespaces is required to evaluate the
# impact of this choice.
xp, _, device = get_namespace_and_device(x, xp=xp)

    # `torch.median` takes the lower of the two medians when `x` has an even number
    # of elements, thus we use `torch.quantile(q=0.5)`, which gives the mean of the two
if array_api_compat.is_torch_namespace(xp):
return xp.quantile(x, q=0.5, dim=axis, keepdim=keepdims)

if hasattr(xp, "median"):
return xp.median(x, axis=axis, keepdims=keepdims)

    # Intended mostly for array-api-strict (which has no "median", as per the spec)
# as `_convert_to_numpy` does not necessarily work for all array types.
x_np = _convert_to_numpy(x, xp=xp)
return xp.asarray(numpy.median(x_np, axis=axis, keepdims=keepdims), device=device)


def _xlogy(x, y, xp=None):
# TODO: Remove this once https://github.com/scipy/scipy/issues/21736 is fixed
xp, _, device_ = get_namespace_and_device(x, y, xp=xp)
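The PyTorch special case in `_median` exists because `torch.median` and `numpy.median` disagree for arrays with an even number of elements. A small illustration of the discrepancy (assuming PyTorch is installed; not part of the PR itself):

```python
import numpy as np
import torch

x = [1.0, 2.0, 3.0, 4.0]

print(np.median(np.asarray(x)))                 # 2.5 (mean of the two middle values)
print(torch.median(torch.asarray(x)))           # tensor(2.) (lower of the two medians)
print(torch.quantile(torch.asarray(x), q=0.5))  # tensor(2.5000), matches NumPy
```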
31 changes: 31 additions & 0 deletions sklearn/utils/tests/test_array_api.py
@@ -19,6 +19,7 @@
_is_numpy_namespace,
_isin,
_max_precision_float_dtype,
_median,
_nanmax,
_nanmean,
_nanmin,
@@ -603,3 +604,33 @@ def test_sparse_device(csr_container, dispatch):
assert device(a, numpy.array([1])) is None
assert get_namespace_and_device(a, b)[2] is None
assert get_namespace_and_device(a, numpy.array([1]))[2] is None


@pytest.mark.parametrize(
"namespace, device, dtype_name",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_median(namespace, device, dtype_name, axis):
# Note: depending on the value of `axis`, this test will compare median
# computations on arrays of even (4) or odd (5) numbers of elements, hence
# will test for median computation with and without interpolation to check
# that array API namespaces yield consistent results even when the median is
# not mathematically uniquely defined.
xp = _array_api_for_tests(namespace, device)
rng = numpy.random.RandomState(0)

X_np = rng.uniform(low=0.0, high=1.0, size=(5, 4)).astype(dtype_name)
result_np = numpy.median(X_np, axis=axis)

X_xp = xp.asarray(X_np, device=device)
with config_context(array_api_dispatch=True):
result_xp = _median(X_xp, axis=axis)

if xp.__name__ != "array_api_strict":
        # We convert array-api-strict arrays to numpy arrays as `median` is not
# part of the Array API spec
assert get_namespace(result_xp)[0] == xp
assert result_xp.device == X_xp.device
assert_allclose(result_np, _convert_to_numpy(result_xp, xp=xp))
25 changes: 18 additions & 7 deletions sklearn/utils/validation.py
@@ -18,7 +18,13 @@

from .. import get_config as _get_config
from ..exceptions import DataConversionWarning, NotFittedError, PositiveSpectrumWarning
from ..utils._array_api import _asarray_with_order, _is_numpy_namespace, get_namespace
from ..utils._array_api import (
_asarray_with_order,
_is_numpy_namespace,
_max_precision_float_dtype,
get_namespace,
get_namespace_and_device,
)
from ..utils.deprecation import _deprecate_force_all_finite
from ..utils.fixes import ComplexWarning, _preserve_dia_indices_dtype
from ._isfinite import FiniteStatus, cy_isfinite
@@ -390,7 +396,8 @@ def _num_samples(x):

if not hasattr(x, "__len__") and not hasattr(x, "shape"):
if hasattr(x, "__array__"):
x = np.asarray(x)
xp, _ = get_namespace(x)
x = xp.asarray(x)
else:
raise TypeError(message)

@@ -2167,20 +2174,24 @@ def _check_sample_weight(
sample_weight : ndarray of shape (n_samples,)
Validated sample weight. It is guaranteed to be "C" contiguous.
"""
n_samples = _num_samples(X)
xp, _, device = get_namespace_and_device(sample_weight, X)

xp, _ = get_namespace(X)
n_samples = _num_samples(X)

if dtype is not None and dtype not in [xp.float32, xp.float64]:
dtype = xp.float64
max_float_type = _max_precision_float_dtype(xp, device)
float_dtypes = (
[xp.float32] if max_float_type == xp.float32 else [xp.float64, xp.float32]
)
if dtype is not None and dtype not in float_dtypes:
dtype = max_float_type

if sample_weight is None:
sample_weight = xp.ones(n_samples, dtype=dtype)
elif isinstance(sample_weight, numbers.Number):
sample_weight = xp.full(n_samples, sample_weight, dtype=dtype)
else:
if dtype is None:
dtype = [xp.float64, xp.float32]
dtype = float_dtypes
sample_weight = check_array(
sample_weight,
accept_sparse=False,
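The `_check_sample_weight` change above makes the requested dtype device-aware: instead of always forcing `float64`, it caps the precision at what the input's namespace/device pair supports. A hedged sketch of the underlying query, assuming PyTorch and `array-api-compat` are installed (the device choice is illustrative):

```python
import torch

from sklearn import config_context
from sklearn.utils._array_api import _max_precision_float_dtype, get_namespace_and_device

# Swap in e.g. device="mps" or device="cuda" where available to exercise
# devices whose maximum float precision differs.
X = torch.ones((4, 2))

with config_context(array_api_dispatch=True):
    xp, _, device = get_namespace_and_device(X)
    # float64 on CPU/CUDA, float32 on devices without float64 support (such as
    # MPS), so sample_weight validation no longer hard-codes float64.
    print(_max_precision_float_dtype(xp, device))
```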