From 37e21b0051ff8e8cef86e0bf3b1c09954df0cade Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Mon, 19 May 2025 15:36:25 +1000 Subject: [PATCH 01/27] add median --- sklearn/utils/_array_api.py | 33 +++++++++++++++++++++++++++ sklearn/utils/tests/test_array_api.py | 23 +++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index a9f35516f17b6..64c027a314c70 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -669,6 +669,39 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None): return sum_ / scale +def _median(X, axis=None, xp=None): + ( + xp, + _, + ) = get_namespace(X, xp=xp) + + if _is_numpy_namespace(xp): + return numpy.median(X, axis=axis) + + if X.ndim == 0: + return float(X) + + if axis is None: + X = xp.reshape(X, (-1,)) + axis = 0 + + X_sorted = xp.sort(X, axis=axis) + indexer = [slice(None)] * X.ndim + index = X.shape[axis] // 2 + if X.shape[axis] % 2 == 1: + # index with slice to allow mean (below) to work + indexer[axis] = slice(index, index + 1) + else: + indexer[axis] = slice(index - 1, index + 1) + indexer = tuple(indexer) + + # Use mean in both odd and even case to coerce data type, + # using out array if needed. + rout = xp.mean(X_sorted[indexer], axis=axis) + return rout + # Need to add NaN handling + + def _xlogy(x, y, xp=None): # TODO: Remove this once https://github.com/scipy/scipy/issues/21736 is fixed xp, _, device_ = get_namespace_and_device(x, y, xp=xp) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 164e3024a31e7..5aeb4bbe3a33f 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -19,6 +19,7 @@ _is_numpy_namespace, _isin, _max_precision_float_dtype, + _median, _nanmax, _nanmean, _nanmin, @@ -599,3 +600,25 @@ def test_sparse_device(csr_container, dispatch): assert device(a, numpy.array([1])) is None assert get_namespace_and_device(a, b)[2] is None assert get_namespace_and_device(a, numpy.array([1]))[2] is None + + +@pytest.mark.parametrize( + "namespace, device, dtype_name", + yield_namespace_device_dtype_combinations(), + ids=_get_namespace_device_dtype_ids, +) +@pytest.mark.parametrize("axis", [None, 0, 1]) +def test_median(namespace, device, dtype_name, axis): + xp = _array_api_for_tests(namespace, device) + rng = numpy.random.RandomState(0) + + X_np = numpy.array(rng.random_sample((5, 4)), dtype=dtype_name) + result_np = numpy.median(X_np, axis=axis) + + X_xp = xp.asarray(X_np, device=device) + with config_context(array_api_dispatch=True): + result_xp = _median(X_xp, axis=axis) + + assert get_namespace(result_xp)[0].__name__ == xp.__name__ + assert result_xp.device == X_xp.device + assert_allclose(result_np, _convert_to_numpy(result_xp, xp=xp)) From f99397b950dafe6e0c4a5663e4a4a705ee6f809a Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 21 May 2025 10:31:16 +1000 Subject: [PATCH 02/27] amend comment --- sklearn/utils/_array_api.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 64c027a314c70..9e640d2167ccb 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -670,10 +670,7 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None): def _median(X, axis=None, xp=None): - ( - xp, - _, - ) = get_namespace(X, xp=xp) + xp, _ = get_namespace(X, xp=xp) if _is_numpy_namespace(xp): return numpy.median(X, axis=axis) @@ -699,7 +696,7 @@ def _median(X, axis=None, xp=None): # using out array if needed. rout = xp.mean(X_sorted[indexer], axis=axis) return rout - # Need to add NaN handling + # `xp.mean` not guaranteed to return nan if nan in input, def _xlogy(x, y, xp=None): From 3940610223ab66789b26b401b283bb4a479f7c03 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Mon, 26 May 2025 15:48:02 +1000 Subject: [PATCH 03/27] add support --- doc/modules/array_api.rst | 1 + sklearn/metrics/_regression.py | 12 +++++----- sklearn/metrics/tests/test_common.py | 4 ++++ sklearn/utils/_array_api.py | 36 +++++++++++++++++----------- 4 files changed, 33 insertions(+), 20 deletions(-) diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst index e1a499c97506b..ee049937f5ce0 100644 --- a/doc/modules/array_api.rst +++ b/doc/modules/array_api.rst @@ -149,6 +149,7 @@ Metrics - :func:`sklearn.metrics.mean_squared_error` - :func:`sklearn.metrics.mean_squared_log_error` - :func:`sklearn.metrics.mean_tweedie_deviance` +- :func:`sklearn.metrics.median_absolute_error` - :func:`sklearn.metrics.multilabel_confusion_matrix` - :func:`sklearn.metrics.pairwise.additive_chi2_kernel` - :func:`sklearn.metrics.pairwise.chi2_kernel` diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index 4c46346d63d92..5fe5dbd144f31 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -19,6 +19,7 @@ from ..utils._array_api import ( _average, _find_matching_floating_dtype, + _median, get_namespace, get_namespace_and_device, size, @@ -912,15 +913,14 @@ def median_absolute_error( >>> median_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7]) 0.85 """ - y_type, y_true, y_pred, multioutput = _check_reg_targets( - y_true, y_pred, multioutput - ) + xp, _ = get_namespace(y_true, y_pred, multioutput, sample_weight) + _, y_true, y_pred, multioutput = _check_reg_targets(y_true, y_pred, multioutput) if sample_weight is None: - output_errors = np.median(np.abs(y_pred - y_true), axis=0) + output_errors = _median(xp.abs(y_pred - y_true), axis=0) else: sample_weight = _check_sample_weight(sample_weight, y_pred) output_errors = _weighted_percentile( - np.abs(y_pred - y_true), sample_weight=sample_weight + xp.abs(y_pred - y_true), sample_weight=sample_weight ) if isinstance(multioutput, str): if multioutput == "raw_values": @@ -929,7 +929,7 @@ def median_absolute_error( # pass None as weights to np.average: uniform mean multioutput = None - return float(np.average(output_errors, weights=multioutput)) + return float(_average(output_errors, weights=multioutput)) def _assemble_r2_explained_variance( diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 00e47f04b5b57..06718573c787a 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -2205,6 +2205,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name) check_array_api_regression_metric, check_array_api_regression_metric_multioutput, ], + median_absolute_error: [ + check_array_api_regression_metric, + check_array_api_regression_metric_multioutput, + ], d2_tweedie_score: [ check_array_api_regression_metric, ], diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 9e640d2167ccb..2ab643f45527f 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -669,23 +669,32 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None): return sum_ / scale -def _median(X, axis=None, xp=None): - xp, _ = get_namespace(X, xp=xp) - - if _is_numpy_namespace(xp): - return numpy.median(X, axis=axis) - - if X.ndim == 0: - return float(X) +def _median(x, axis=None, keepdims=False, xp=None): + # `median` is not included in the Array API spec, but is implemented in most + # array libraries (and all that we test). + xp, _ = get_namespace(x, xp=xp) + if hasattr(xp, "median"): + kwargs = {"axis": axis, "keepdims": keepdims} + if _is_xp_namespace(xp, "torch"): + # torch has no `None` option for `axis` + if axis is None: + x = xp.reshape(x, (-1,)) + # torch named their parameter `keepdim` + kwargs.pop("keepdims") + kwargs["keepdim"] = keepdims + return xp.median(x, **kwargs) + + if x.ndim == 0: + return float(x) if axis is None: - X = xp.reshape(X, (-1,)) + x = xp.reshape(x, (-1,)) axis = 0 - X_sorted = xp.sort(X, axis=axis) - indexer = [slice(None)] * X.ndim - index = X.shape[axis] // 2 - if X.shape[axis] % 2 == 1: + X_sorted = xp.sort(x, axis=axis) + indexer = [slice(None)] * x.ndim + index = x.shape[axis] // 2 + if x.shape[axis] % 2 == 1: # index with slice to allow mean (below) to work indexer[axis] = slice(index, index + 1) else: @@ -696,7 +705,6 @@ def _median(X, axis=None, xp=None): # using out array if needed. rout = xp.mean(X_sorted[indexer], axis=axis) return rout - # `xp.mean` not guaranteed to return nan if nan in input, def _xlogy(x, y, xp=None): From b277ac0cee706d28bf60a3e09daabcd989f71f5f Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Mon, 26 May 2025 15:51:30 +1000 Subject: [PATCH 04/27] add whats new --- .../upcoming_changes/sklearn.metrics/31406.enhancement.rst | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 doc/whats_new/upcoming_changes/sklearn.metrics/31406.enhancement.rst diff --git a/doc/whats_new/upcoming_changes/sklearn.metrics/31406.enhancement.rst b/doc/whats_new/upcoming_changes/sklearn.metrics/31406.enhancement.rst new file mode 100644 index 0000000000000..4736c67c80132 --- /dev/null +++ b/doc/whats_new/upcoming_changes/sklearn.metrics/31406.enhancement.rst @@ -0,0 +1,2 @@ +- :func:`metrics.median_absolute_error` now supports Array API compatible inputs. + By :user:`Lucy Liu `. From dba4b7f006723aadbbeafb7ab60b09740cd0fa8b Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Mon, 26 May 2025 22:42:44 +1000 Subject: [PATCH 05/27] use quantile for torch --- sklearn/utils/_array_api.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 2ab643f45527f..f1595959bde2c 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -674,15 +674,11 @@ def _median(x, axis=None, keepdims=False, xp=None): # array libraries (and all that we test). xp, _ = get_namespace(x, xp=xp) if hasattr(xp, "median"): - kwargs = {"axis": axis, "keepdims": keepdims} + # When `x` has even number of elements, `torch.median` takes the lower of the + # two medians, thus we use `torch.quantile(q=0.5)`, which gives mean of the two if _is_xp_namespace(xp, "torch"): - # torch has no `None` option for `axis` - if axis is None: - x = xp.reshape(x, (-1,)) - # torch named their parameter `keepdim` - kwargs.pop("keepdims") - kwargs["keepdim"] = keepdims - return xp.median(x, **kwargs) + return xp.quantile(x, q=0.5, dim=axis, keepdim=keepdims) + return xp.median(x, axis=axis, keepdims=keepdims) if x.ndim == 0: return float(x) From 335a91d01dad0c2947563cbf8da300943b9acaa3 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Mon, 26 May 2025 23:57:26 +1000 Subject: [PATCH 06/27] add support for helpers --- sklearn/metrics/tests/test_common.py | 213 +++++++++++++-------------- sklearn/utils/validation.py | 27 +++- 2 files changed, 120 insertions(+), 120 deletions(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 06718573c787a..490029edf49de 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -56,19 +56,6 @@ zero_one_loss, ) from sklearn.metrics._base import _average_binary_score -from sklearn.metrics.pairwise import ( - additive_chi2_kernel, - chi2_kernel, - cosine_distances, - cosine_similarity, - euclidean_distances, - linear_kernel, - paired_cosine_distances, - paired_euclidean_distances, - polynomial_kernel, - rbf_kernel, - sigmoid_kernel, -) from sklearn.preprocessing import LabelBinarizer from sklearn.utils import shuffle from sklearn.utils._array_api import ( @@ -2133,110 +2120,110 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name) array_api_metric_checkers = { - accuracy_score: [ - check_array_api_binary_classification_metric, - check_array_api_multiclass_classification_metric, - check_array_api_multilabel_classification_metric, - ], - f1_score: [ - check_array_api_binary_classification_metric, - check_array_api_multiclass_classification_metric, - check_array_api_multilabel_classification_metric, - ], - fbeta_score: [ - check_array_api_multiclass_classification_metric, - check_array_api_multilabel_classification_metric, - ], - jaccard_score: [ - check_array_api_binary_classification_metric, - check_array_api_multiclass_classification_metric, - check_array_api_multilabel_classification_metric, - ], - multilabel_confusion_matrix: [ - check_array_api_binary_classification_metric, - check_array_api_multiclass_classification_metric, - check_array_api_multilabel_classification_metric, - ], - precision_score: [ - check_array_api_binary_classification_metric, - check_array_api_multiclass_classification_metric, - check_array_api_multilabel_classification_metric, - ], - recall_score: [ - check_array_api_binary_classification_metric, - check_array_api_multiclass_classification_metric, - check_array_api_multilabel_classification_metric, - ], - zero_one_loss: [ - check_array_api_binary_classification_metric, - check_array_api_multiclass_classification_metric, - check_array_api_multilabel_classification_metric, - ], - hamming_loss: [ - check_array_api_binary_classification_metric, - check_array_api_multiclass_classification_metric, - check_array_api_multilabel_classification_metric, - ], - mean_tweedie_deviance: [check_array_api_regression_metric], - partial(mean_tweedie_deviance, power=-0.5): [check_array_api_regression_metric], - partial(mean_tweedie_deviance, power=1.5): [check_array_api_regression_metric], - r2_score: [ - check_array_api_regression_metric, - check_array_api_regression_metric_multioutput, - ], - cosine_similarity: [check_array_api_metric_pairwise], - explained_variance_score: [ - check_array_api_regression_metric, - check_array_api_regression_metric_multioutput, - ], - mean_absolute_error: [ - check_array_api_regression_metric, - check_array_api_regression_metric_multioutput, - ], - mean_pinball_loss: [ - check_array_api_regression_metric, - check_array_api_regression_metric_multioutput, - ], - mean_squared_error: [ - check_array_api_regression_metric, - check_array_api_regression_metric_multioutput, - ], - mean_squared_log_error: [ - check_array_api_regression_metric, - check_array_api_regression_metric_multioutput, - ], + # accuracy_score: [ + # check_array_api_binary_classification_metric, + # check_array_api_multiclass_classification_metric, + # check_array_api_multilabel_classification_metric, + # ], + # f1_score: [ + # check_array_api_binary_classification_metric, + # check_array_api_multiclass_classification_metric, + # check_array_api_multilabel_classification_metric, + # ], + # fbeta_score: [ + # check_array_api_multiclass_classification_metric, + # check_array_api_multilabel_classification_metric, + # ], + # jaccard_score: [ + # check_array_api_binary_classification_metric, + # check_array_api_multiclass_classification_metric, + # check_array_api_multilabel_classification_metric, + # ], + # multilabel_confusion_matrix: [ + # check_array_api_binary_classification_metric, + # check_array_api_multiclass_classification_metric, + # check_array_api_multilabel_classification_metric, + # ], + # precision_score: [ + # check_array_api_binary_classification_metric, + # check_array_api_multiclass_classification_metric, + # check_array_api_multilabel_classification_metric, + # ], + # recall_score: [ + # check_array_api_binary_classification_metric, + # check_array_api_multiclass_classification_metric, + # check_array_api_multilabel_classification_metric, + # ], + # zero_one_loss: [ + # check_array_api_binary_classification_metric, + # check_array_api_multiclass_classification_metric, + # check_array_api_multilabel_classification_metric, + # ], + # hamming_loss: [ + # check_array_api_binary_classification_metric, + # check_array_api_multiclass_classification_metric, + # check_array_api_multilabel_classification_metric, + # ], + # mean_tweedie_deviance: [check_array_api_regression_metric], + # partial(mean_tweedie_deviance, power=-0.5): [check_array_api_regression_metric], + # partial(mean_tweedie_deviance, power=1.5): [check_array_api_regression_metric], + # r2_score: [ + # check_array_api_regression_metric, + # check_array_api_regression_metric_multioutput, + # ], + # cosine_similarity: [check_array_api_metric_pairwise], + # explained_variance_score: [ + # check_array_api_regression_metric, + # check_array_api_regression_metric_multioutput, + # ], + # mean_absolute_error: [ + # check_array_api_regression_metric, + # check_array_api_regression_metric_multioutput, + # ], + # mean_pinball_loss: [ + # check_array_api_regression_metric, + # check_array_api_regression_metric_multioutput, + # ], + # mean_squared_error: [ + # check_array_api_regression_metric, + # check_array_api_regression_metric_multioutput, + # ], + # mean_squared_log_error: [ + # check_array_api_regression_metric, + # check_array_api_regression_metric_multioutput, + # ], median_absolute_error: [ check_array_api_regression_metric, check_array_api_regression_metric_multioutput, ], - d2_tweedie_score: [ - check_array_api_regression_metric, - ], - paired_cosine_distances: [check_array_api_metric_pairwise], - mean_poisson_deviance: [check_array_api_regression_metric], - additive_chi2_kernel: [check_array_api_metric_pairwise], - mean_gamma_deviance: [check_array_api_regression_metric], - max_error: [check_array_api_regression_metric], - mean_absolute_percentage_error: [ - check_array_api_regression_metric, - check_array_api_regression_metric_multioutput, - ], - chi2_kernel: [check_array_api_metric_pairwise], - paired_euclidean_distances: [check_array_api_metric_pairwise], - cosine_distances: [check_array_api_metric_pairwise], - euclidean_distances: [check_array_api_metric_pairwise], - linear_kernel: [check_array_api_metric_pairwise], - polynomial_kernel: [check_array_api_metric_pairwise], - rbf_kernel: [check_array_api_metric_pairwise], - root_mean_squared_error: [ - check_array_api_regression_metric, - check_array_api_regression_metric_multioutput, - ], - root_mean_squared_log_error: [ - check_array_api_regression_metric, - check_array_api_regression_metric_multioutput, - ], - sigmoid_kernel: [check_array_api_metric_pairwise], + # d2_tweedie_score: [ + # check_array_api_regression_metric, + # ], + # paired_cosine_distances: [check_array_api_metric_pairwise], + # mean_poisson_deviance: [check_array_api_regression_metric], + # additive_chi2_kernel: [check_array_api_metric_pairwise], + # mean_gamma_deviance: [check_array_api_regression_metric], + # max_error: [check_array_api_regression_metric], + # mean_absolute_percentage_error: [ + # check_array_api_regression_metric, + # check_array_api_regression_metric_multioutput, + # ], + # chi2_kernel: [check_array_api_metric_pairwise], + # paired_euclidean_distances: [check_array_api_metric_pairwise], + # cosine_distances: [check_array_api_metric_pairwise], + # euclidean_distances: [check_array_api_metric_pairwise], + # linear_kernel: [check_array_api_metric_pairwise], + # polynomial_kernel: [check_array_api_metric_pairwise], + # rbf_kernel: [check_array_api_metric_pairwise], + # root_mean_squared_error: [ + # check_array_api_regression_metric, + # check_array_api_regression_metric_multioutput, + # ], + # root_mean_squared_log_error: [ + # check_array_api_regression_metric, + # check_array_api_regression_metric_multioutput, + # ], + # sigmoid_kernel: [check_array_api_metric_pairwise], } diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 324827323168a..19b6b1967d06d 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -18,7 +18,13 @@ from .. import get_config as _get_config from ..exceptions import DataConversionWarning, NotFittedError, PositiveSpectrumWarning -from ..utils._array_api import _asarray_with_order, _is_numpy_namespace, get_namespace +from ..utils._array_api import ( + _asarray_with_order, + _is_numpy_namespace, + _max_precision_float_dtype, + get_namespace, + get_namespace_and_device, +) from ..utils.deprecation import _deprecate_force_all_finite from ..utils.fixes import ComplexWarning, _preserve_dia_indices_dtype from ._isfinite import FiniteStatus, cy_isfinite @@ -388,9 +394,10 @@ def _num_samples(x): if _use_interchange_protocol(x): return x.__dataframe__().num_rows() + xp, _ = get_namespace(x) if not hasattr(x, "__len__") and not hasattr(x, "shape"): if hasattr(x, "__array__"): - x = np.asarray(x) + x = xp.asarray(x) else: raise TypeError(message) @@ -2167,18 +2174,24 @@ def _check_sample_weight( sample_weight : ndarray of shape (n_samples,) Validated sample weight. It is guaranteed to be "C" contiguous. """ + xp, _, device = get_namespace_and_device(sample_weight, X) + n_samples = _num_samples(X) - if dtype is not None and dtype not in [np.float32, np.float64]: - dtype = np.float64 + max_float_type = _max_precision_float_dtype(xp, device) + float_dtypes = ( + [xp.float32] if max_float_type == xp.float32 else [xp.float64, xp.float32] + ) + if dtype is not None and dtype not in float_dtypes: + dtype = max_float_type if sample_weight is None: - sample_weight = np.ones(n_samples, dtype=dtype) + sample_weight = xp.ones(n_samples, dtype=dtype) elif isinstance(sample_weight, numbers.Number): - sample_weight = np.full(n_samples, sample_weight, dtype=dtype) + sample_weight = xp.full(n_samples, sample_weight, dtype=dtype) else: if dtype is None: - dtype = [np.float64, np.float32] + dtype = float_dtypes sample_weight = check_array( sample_weight, accept_sparse=False, From e708595c8c6cb4527b05ee1a3c18c94b1c41ac69 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 28 May 2025 21:47:31 +1000 Subject: [PATCH 07/27] fix --- sklearn/utils/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 19b6b1967d06d..d766ad16545da 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -394,9 +394,9 @@ def _num_samples(x): if _use_interchange_protocol(x): return x.__dataframe__().num_rows() - xp, _ = get_namespace(x) if not hasattr(x, "__len__") and not hasattr(x, "shape"): if hasattr(x, "__array__"): + xp, _ = get_namespace(x) x = xp.asarray(x) else: raise TypeError(message) From 274f1c63917f21d19b3e44c65314388027c8f57c Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 28 May 2025 21:50:26 +1000 Subject: [PATCH 08/27] uncomment --- sklearn/metrics/tests/test_common.py | 213 ++++++++++++++------------- 1 file changed, 113 insertions(+), 100 deletions(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 490029edf49de..06718573c787a 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -56,6 +56,19 @@ zero_one_loss, ) from sklearn.metrics._base import _average_binary_score +from sklearn.metrics.pairwise import ( + additive_chi2_kernel, + chi2_kernel, + cosine_distances, + cosine_similarity, + euclidean_distances, + linear_kernel, + paired_cosine_distances, + paired_euclidean_distances, + polynomial_kernel, + rbf_kernel, + sigmoid_kernel, +) from sklearn.preprocessing import LabelBinarizer from sklearn.utils import shuffle from sklearn.utils._array_api import ( @@ -2120,110 +2133,110 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name) array_api_metric_checkers = { - # accuracy_score: [ - # check_array_api_binary_classification_metric, - # check_array_api_multiclass_classification_metric, - # check_array_api_multilabel_classification_metric, - # ], - # f1_score: [ - # check_array_api_binary_classification_metric, - # check_array_api_multiclass_classification_metric, - # check_array_api_multilabel_classification_metric, - # ], - # fbeta_score: [ - # check_array_api_multiclass_classification_metric, - # check_array_api_multilabel_classification_metric, - # ], - # jaccard_score: [ - # check_array_api_binary_classification_metric, - # check_array_api_multiclass_classification_metric, - # check_array_api_multilabel_classification_metric, - # ], - # multilabel_confusion_matrix: [ - # check_array_api_binary_classification_metric, - # check_array_api_multiclass_classification_metric, - # check_array_api_multilabel_classification_metric, - # ], - # precision_score: [ - # check_array_api_binary_classification_metric, - # check_array_api_multiclass_classification_metric, - # check_array_api_multilabel_classification_metric, - # ], - # recall_score: [ - # check_array_api_binary_classification_metric, - # check_array_api_multiclass_classification_metric, - # check_array_api_multilabel_classification_metric, - # ], - # zero_one_loss: [ - # check_array_api_binary_classification_metric, - # check_array_api_multiclass_classification_metric, - # check_array_api_multilabel_classification_metric, - # ], - # hamming_loss: [ - # check_array_api_binary_classification_metric, - # check_array_api_multiclass_classification_metric, - # check_array_api_multilabel_classification_metric, - # ], - # mean_tweedie_deviance: [check_array_api_regression_metric], - # partial(mean_tweedie_deviance, power=-0.5): [check_array_api_regression_metric], - # partial(mean_tweedie_deviance, power=1.5): [check_array_api_regression_metric], - # r2_score: [ - # check_array_api_regression_metric, - # check_array_api_regression_metric_multioutput, - # ], - # cosine_similarity: [check_array_api_metric_pairwise], - # explained_variance_score: [ - # check_array_api_regression_metric, - # check_array_api_regression_metric_multioutput, - # ], - # mean_absolute_error: [ - # check_array_api_regression_metric, - # check_array_api_regression_metric_multioutput, - # ], - # mean_pinball_loss: [ - # check_array_api_regression_metric, - # check_array_api_regression_metric_multioutput, - # ], - # mean_squared_error: [ - # check_array_api_regression_metric, - # check_array_api_regression_metric_multioutput, - # ], - # mean_squared_log_error: [ - # check_array_api_regression_metric, - # check_array_api_regression_metric_multioutput, - # ], + accuracy_score: [ + check_array_api_binary_classification_metric, + check_array_api_multiclass_classification_metric, + check_array_api_multilabel_classification_metric, + ], + f1_score: [ + check_array_api_binary_classification_metric, + check_array_api_multiclass_classification_metric, + check_array_api_multilabel_classification_metric, + ], + fbeta_score: [ + check_array_api_multiclass_classification_metric, + check_array_api_multilabel_classification_metric, + ], + jaccard_score: [ + check_array_api_binary_classification_metric, + check_array_api_multiclass_classification_metric, + check_array_api_multilabel_classification_metric, + ], + multilabel_confusion_matrix: [ + check_array_api_binary_classification_metric, + check_array_api_multiclass_classification_metric, + check_array_api_multilabel_classification_metric, + ], + precision_score: [ + check_array_api_binary_classification_metric, + check_array_api_multiclass_classification_metric, + check_array_api_multilabel_classification_metric, + ], + recall_score: [ + check_array_api_binary_classification_metric, + check_array_api_multiclass_classification_metric, + check_array_api_multilabel_classification_metric, + ], + zero_one_loss: [ + check_array_api_binary_classification_metric, + check_array_api_multiclass_classification_metric, + check_array_api_multilabel_classification_metric, + ], + hamming_loss: [ + check_array_api_binary_classification_metric, + check_array_api_multiclass_classification_metric, + check_array_api_multilabel_classification_metric, + ], + mean_tweedie_deviance: [check_array_api_regression_metric], + partial(mean_tweedie_deviance, power=-0.5): [check_array_api_regression_metric], + partial(mean_tweedie_deviance, power=1.5): [check_array_api_regression_metric], + r2_score: [ + check_array_api_regression_metric, + check_array_api_regression_metric_multioutput, + ], + cosine_similarity: [check_array_api_metric_pairwise], + explained_variance_score: [ + check_array_api_regression_metric, + check_array_api_regression_metric_multioutput, + ], + mean_absolute_error: [ + check_array_api_regression_metric, + check_array_api_regression_metric_multioutput, + ], + mean_pinball_loss: [ + check_array_api_regression_metric, + check_array_api_regression_metric_multioutput, + ], + mean_squared_error: [ + check_array_api_regression_metric, + check_array_api_regression_metric_multioutput, + ], + mean_squared_log_error: [ + check_array_api_regression_metric, + check_array_api_regression_metric_multioutput, + ], median_absolute_error: [ check_array_api_regression_metric, check_array_api_regression_metric_multioutput, ], - # d2_tweedie_score: [ - # check_array_api_regression_metric, - # ], - # paired_cosine_distances: [check_array_api_metric_pairwise], - # mean_poisson_deviance: [check_array_api_regression_metric], - # additive_chi2_kernel: [check_array_api_metric_pairwise], - # mean_gamma_deviance: [check_array_api_regression_metric], - # max_error: [check_array_api_regression_metric], - # mean_absolute_percentage_error: [ - # check_array_api_regression_metric, - # check_array_api_regression_metric_multioutput, - # ], - # chi2_kernel: [check_array_api_metric_pairwise], - # paired_euclidean_distances: [check_array_api_metric_pairwise], - # cosine_distances: [check_array_api_metric_pairwise], - # euclidean_distances: [check_array_api_metric_pairwise], - # linear_kernel: [check_array_api_metric_pairwise], - # polynomial_kernel: [check_array_api_metric_pairwise], - # rbf_kernel: [check_array_api_metric_pairwise], - # root_mean_squared_error: [ - # check_array_api_regression_metric, - # check_array_api_regression_metric_multioutput, - # ], - # root_mean_squared_log_error: [ - # check_array_api_regression_metric, - # check_array_api_regression_metric_multioutput, - # ], - # sigmoid_kernel: [check_array_api_metric_pairwise], + d2_tweedie_score: [ + check_array_api_regression_metric, + ], + paired_cosine_distances: [check_array_api_metric_pairwise], + mean_poisson_deviance: [check_array_api_regression_metric], + additive_chi2_kernel: [check_array_api_metric_pairwise], + mean_gamma_deviance: [check_array_api_regression_metric], + max_error: [check_array_api_regression_metric], + mean_absolute_percentage_error: [ + check_array_api_regression_metric, + check_array_api_regression_metric_multioutput, + ], + chi2_kernel: [check_array_api_metric_pairwise], + paired_euclidean_distances: [check_array_api_metric_pairwise], + cosine_distances: [check_array_api_metric_pairwise], + euclidean_distances: [check_array_api_metric_pairwise], + linear_kernel: [check_array_api_metric_pairwise], + polynomial_kernel: [check_array_api_metric_pairwise], + rbf_kernel: [check_array_api_metric_pairwise], + root_mean_squared_error: [ + check_array_api_regression_metric, + check_array_api_regression_metric_multioutput, + ], + root_mean_squared_log_error: [ + check_array_api_regression_metric, + check_array_api_regression_metric_multioutput, + ], + sigmoid_kernel: [check_array_api_metric_pairwise], } From 4c38a53a950a7065ab0a895d64030fbc7ceccec0 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 28 May 2025 22:06:46 +1000 Subject: [PATCH 09/27] add xfail --- sklearn/metrics/tests/test_common.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 06718573c787a..5b8f48b0914e3 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -2254,6 +2254,21 @@ def yield_metric_checker_combinations(metric_checkers=array_api_metric_checkers) @pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations()) def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func): check_func(metric, array_namespace, device, dtype_name) + if ( + getattr(metric, "__name__", None) == "median_absolute_error" + and array_namespace == "array_api_strict" + ): + try: + import array_api_strict + except ImportError: + pass + else: + if device == array_api_strict.Device("device1"): + # See https://github.com/data-apis/array-api-strict/issues/134 + pytest.xfail( + "`_weighted_percentile` is affected by array_api_strict bug when " + "indexing with tuple of arrays on non-'CPU_DEVICE' devices." + ) @pytest.mark.parametrize("df_lib_name", ["pandas", "polars"]) From d125a25c24665e1ac7d95859958540573763b966 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 29 May 2025 11:13:25 +1000 Subject: [PATCH 10/27] add no implemented error --- sklearn/utils/_array_api.py | 30 +++++------------------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index f1595959bde2c..ada3d3fa457f6 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -670,37 +670,17 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None): def _median(x, axis=None, keepdims=False, xp=None): - # `median` is not included in the Array API spec, but is implemented in most - # array libraries (and all that we test). xp, _ = get_namespace(x, xp=xp) if hasattr(xp, "median"): - # When `x` has even number of elements, `torch.median` takes the lower of the - # two medians, thus we use `torch.quantile(q=0.5)`, which gives mean of the two + # `torch.median` takes the lower of the two medians when `x` has even number + # of elements, thus we use `torch.quantile(q=0.5)`, which gives mean of the two if _is_xp_namespace(xp, "torch"): return xp.quantile(x, q=0.5, dim=axis, keepdim=keepdims) return xp.median(x, axis=axis, keepdims=keepdims) - if x.ndim == 0: - return float(x) - - if axis is None: - x = xp.reshape(x, (-1,)) - axis = 0 - - X_sorted = xp.sort(x, axis=axis) - indexer = [slice(None)] * x.ndim - index = x.shape[axis] // 2 - if x.shape[axis] % 2 == 1: - # index with slice to allow mean (below) to work - indexer[axis] = slice(index, index + 1) - else: - indexer[axis] = slice(index - 1, index + 1) - indexer = tuple(indexer) - - # Use mean in both odd and even case to coerce data type, - # using out array if needed. - rout = xp.mean(X_sorted[indexer], axis=axis) - return rout + # `median` is not included in the Array API spec, but is implemented in most + # array libraries, and all that we support (as of May 2025). + raise NotImplementedError(f"The array namespace {xp.__name__} is not supported.") def _xlogy(x, y, xp=None): From 3c4ba9c0938b15673139c8e075c0a1b09ab36ef3 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 29 May 2025 11:14:13 +1000 Subject: [PATCH 11/27] fix xfail --- sklearn/metrics/tests/test_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 5b8f48b0914e3..375cbb69b7c8e 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -2253,7 +2253,6 @@ def yield_metric_checker_combinations(metric_checkers=array_api_metric_checkers) ) @pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations()) def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func): - check_func(metric, array_namespace, device, dtype_name) if ( getattr(metric, "__name__", None) == "median_absolute_error" and array_namespace == "array_api_strict" @@ -2269,6 +2268,7 @@ def test_array_api_compliance(metric, array_namespace, device, dtype_name, check "`_weighted_percentile` is affected by array_api_strict bug when " "indexing with tuple of arrays on non-'CPU_DEVICE' devices." ) + check_func(metric, array_namespace, device, dtype_name) @pytest.mark.parametrize("df_lib_name", ["pandas", "polars"]) From 27e049b9934e313dc2f6c3add6b41c33247c24b3 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 29 May 2025 11:15:55 +1000 Subject: [PATCH 12/27] add comment --- sklearn/metrics/tests/test_common.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 375cbb69b7c8e..efc7fd4c07ccd 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -2253,6 +2253,9 @@ def yield_metric_checker_combinations(metric_checkers=array_api_metric_checkers) ) @pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations()) def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func): + # TODO: Remove once array-api-strict > 2.3.1 + # https://github.com/data-apis/array-api-strict/issues/134 has been fixed but + # not released yet. if ( getattr(metric, "__name__", None) == "median_absolute_error" and array_namespace == "array_api_strict" @@ -2263,7 +2266,6 @@ def test_array_api_compliance(metric, array_namespace, device, dtype_name, check pass else: if device == array_api_strict.Device("device1"): - # See https://github.com/data-apis/array-api-strict/issues/134 pytest.xfail( "`_weighted_percentile` is affected by array_api_strict bug when " "indexing with tuple of arrays on non-'CPU_DEVICE' devices." From bb81083d267df3e8a4ad5d4a57946770690a5ba9 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 29 May 2025 12:12:16 +1000 Subject: [PATCH 13/27] add back --- sklearn/utils/_array_api.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index ada3d3fa457f6..82a088d8de197 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -680,7 +680,28 @@ def _median(x, axis=None, keepdims=False, xp=None): # `median` is not included in the Array API spec, but is implemented in most # array libraries, and all that we support (as of May 2025). - raise NotImplementedError(f"The array namespace {xp.__name__} is not supported.") + # This implementation is required for array-api-strict. + if x.ndim == 0: + return float(x) + + if axis is None: + x = xp.reshape(x, (-1,)) + axis = 0 + + X_sorted = xp.sort(x, axis=axis) + indexer = [slice(None)] * x.ndim + index = x.shape[axis] // 2 + if x.shape[axis] % 2 == 1: + # index with slice to allow mean (below) to work + indexer[axis] = slice(index, index + 1) + else: + indexer[axis] = slice(index - 1, index + 1) + indexer = tuple(indexer) + + # Use mean in both odd and even case to coerce data type, + # using out array if needed. + rout = xp.mean(X_sorted[indexer], axis=axis) + return rout def _xlogy(x, y, xp=None): From d69e07a13c3900fddc372c2f7d0db078b069e61c Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 29 May 2025 18:57:14 +1000 Subject: [PATCH 14/27] use numpy for strict --- sklearn/utils/_array_api.py | 27 +++++---------------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 82a088d8de197..5bd28bc569b40 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -678,30 +678,13 @@ def _median(x, axis=None, keepdims=False, xp=None): return xp.quantile(x, q=0.5, dim=axis, keepdim=keepdims) return xp.median(x, axis=axis, keepdims=keepdims) + if _is_xp_namespace(xp, "array-api-strict"): + x_np = xp.asarray(x) + return numpy.median(x_np, axis=axis, keepdims=keepdims) + # `median` is not included in the Array API spec, but is implemented in most # array libraries, and all that we support (as of May 2025). - # This implementation is required for array-api-strict. - if x.ndim == 0: - return float(x) - - if axis is None: - x = xp.reshape(x, (-1,)) - axis = 0 - - X_sorted = xp.sort(x, axis=axis) - indexer = [slice(None)] * x.ndim - index = x.shape[axis] // 2 - if x.shape[axis] % 2 == 1: - # index with slice to allow mean (below) to work - indexer[axis] = slice(index, index + 1) - else: - indexer[axis] = slice(index - 1, index + 1) - indexer = tuple(indexer) - - # Use mean in both odd and even case to coerce data type, - # using out array if needed. - rout = xp.mean(X_sorted[indexer], axis=axis) - return rout + raise NotImplementedError(f"The array namespace {xp.__name__} is not supported.") def _xlogy(x, y, xp=None): From 8c3b86f79c3e57b2279b0d6c214e5e4dbfd22cf2 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 29 May 2025 19:55:37 +1000 Subject: [PATCH 15/27] fix numpy as array --- sklearn/utils/_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 5bd28bc569b40..4f1b774061d3a 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -679,7 +679,7 @@ def _median(x, axis=None, keepdims=False, xp=None): return xp.median(x, axis=axis, keepdims=keepdims) if _is_xp_namespace(xp, "array-api-strict"): - x_np = xp.asarray(x) + x_np = numpy.asarray(x) return numpy.median(x_np, axis=axis, keepdims=keepdims) # `median` is not included in the Array API spec, but is implemented in most From b71074a182cf1f82700eb6db8f9fa59daeaf92f0 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 29 May 2025 20:14:57 +1000 Subject: [PATCH 16/27] use xpx is torch namespace --- sklearn/utils/_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 4f1b774061d3a..29a4d40abfd90 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -674,7 +674,7 @@ def _median(x, axis=None, keepdims=False, xp=None): if hasattr(xp, "median"): # `torch.median` takes the lower of the two medians when `x` has even number # of elements, thus we use `torch.quantile(q=0.5)`, which gives mean of the two - if _is_xp_namespace(xp, "torch"): + if xpx.is_torch_namespace(xp): return xp.quantile(x, q=0.5, dim=axis, keepdim=keepdims) return xp.median(x, axis=axis, keepdims=keepdims) From 74b55ceeb4a865cec7c67693331303327f249166 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 29 May 2025 20:41:30 +1000 Subject: [PATCH 17/27] fix --- sklearn/utils/_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 29a4d40abfd90..04a67452542d5 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -674,7 +674,7 @@ def _median(x, axis=None, keepdims=False, xp=None): if hasattr(xp, "median"): # `torch.median` takes the lower of the two medians when `x` has even number # of elements, thus we use `torch.quantile(q=0.5)`, which gives mean of the two - if xpx.is_torch_namespace(xp): + if array_api_compat.is_torch_namespace(xp): return xp.quantile(x, q=0.5, dim=axis, keepdim=keepdims) return xp.median(x, axis=axis, keepdims=keepdims) From 94472f6937fc105b32fd5e75ac368b65e0282966 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 29 May 2025 20:41:48 +1000 Subject: [PATCH 18/27] review --- sklearn/utils/_array_api.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 04a67452542d5..0e2731632f7a6 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -670,6 +670,8 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None): def _median(x, axis=None, keepdims=False, xp=None): + # `median` is not included in the Array API spec, but is implemented in most + # array libraries, and all that we support (as of May 2025). xp, _ = get_namespace(x, xp=xp) if hasattr(xp, "median"): # `torch.median` takes the lower of the two medians when `x` has even number @@ -682,10 +684,6 @@ def _median(x, axis=None, keepdims=False, xp=None): x_np = numpy.asarray(x) return numpy.median(x_np, axis=axis, keepdims=keepdims) - # `median` is not included in the Array API spec, but is implemented in most - # array libraries, and all that we support (as of May 2025). - raise NotImplementedError(f"The array namespace {xp.__name__} is not supported.") - def _xlogy(x, y, xp=None): # TODO: Remove this once https://github.com/scipy/scipy/issues/21736 is fixed From a7fe344413f52ee22333fa6b3f9d4218d146a258 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 29 May 2025 21:17:30 +1000 Subject: [PATCH 19/27] fix test --- sklearn/utils/tests/test_array_api.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 2b1c9490d595b..5792b611d5558 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -623,6 +623,9 @@ def test_median(namespace, device, dtype_name, axis): with config_context(array_api_dispatch=True): result_xp = _median(X_xp, axis=axis) + if xp.__name__ != "array-api-strict": + # We covert array-api-strict arrays to numpy arrays as `median` is not + # part of the Array API spec assert get_namespace(result_xp)[0].__name__ == xp.__name__ assert result_xp.device == X_xp.device - assert_allclose(result_np, _convert_to_numpy(result_xp, xp=xp)) + assert_allclose(result_np, _convert_to_numpy(result_xp, xp=xp)) From 166bd8d9af2f10b8a72039ce009ffb307a1e0d1d Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 29 May 2025 21:35:29 +1000 Subject: [PATCH 20/27] fix --- sklearn/utils/_array_api.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 0e2731632f7a6..ee8d7e4decd41 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -673,17 +673,18 @@ def _median(x, axis=None, keepdims=False, xp=None): # `median` is not included in the Array API spec, but is implemented in most # array libraries, and all that we support (as of May 2025). xp, _ = get_namespace(x, xp=xp) - if hasattr(xp, "median"): - # `torch.median` takes the lower of the two medians when `x` has even number - # of elements, thus we use `torch.quantile(q=0.5)`, which gives mean of the two - if array_api_compat.is_torch_namespace(xp): - return xp.quantile(x, q=0.5, dim=axis, keepdim=keepdims) - return xp.median(x, axis=axis, keepdims=keepdims) if _is_xp_namespace(xp, "array-api-strict"): x_np = numpy.asarray(x) return numpy.median(x_np, axis=axis, keepdims=keepdims) + # `torch.median` takes the lower of the two medians when `x` has even number + # of elements, thus we use `torch.quantile(q=0.5)`, which gives mean of the two + if array_api_compat.is_torch_namespace(xp): + return xp.quantile(x, q=0.5, dim=axis, keepdim=keepdims) + + return xp.median(x, axis=axis, keepdims=keepdims) + def _xlogy(x, y, xp=None): # TODO: Remove this once https://github.com/scipy/scipy/issues/21736 is fixed From 3e80603266fb1a7c8bde3267bbc0502b973046e8 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 29 May 2025 22:28:25 +1000 Subject: [PATCH 21/27] fix test --- sklearn/utils/_array_api.py | 4 ++-- sklearn/utils/tests/test_array_api.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index ee8d7e4decd41..e341f1d6b496b 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -674,8 +674,8 @@ def _median(x, axis=None, keepdims=False, xp=None): # array libraries, and all that we support (as of May 2025). xp, _ = get_namespace(x, xp=xp) - if _is_xp_namespace(xp, "array-api-strict"): - x_np = numpy.asarray(x) + if _is_xp_namespace(xp, "array_api_strict"): + x_np = numpy.asarray(xp.asarray(x, device=xp.Device("CPU_DEVICE"))) return numpy.median(x_np, axis=axis, keepdims=keepdims) # `torch.median` takes the lower of the two medians when `x` has even number diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 5792b611d5558..3d636d250c324 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -623,9 +623,9 @@ def test_median(namespace, device, dtype_name, axis): with config_context(array_api_dispatch=True): result_xp = _median(X_xp, axis=axis) - if xp.__name__ != "array-api-strict": - # We covert array-api-strict arrays to numpy arrays as `median` is not - # part of the Array API spec - assert get_namespace(result_xp)[0].__name__ == xp.__name__ - assert result_xp.device == X_xp.device + if xp.__name__ != "array_api_strict": + # We covert array-api-strict arrays to numpy arrays as `median` is not + # part of the Array API spec + assert get_namespace(result_xp)[0].__name__ == xp.__name__ + assert result_xp.device == X_xp.device assert_allclose(result_np, _convert_to_numpy(result_xp, xp=xp)) From c55d44fd26a7f7bd2ff60cb6ad8abbe1b2050275 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 29 May 2025 22:31:19 +1000 Subject: [PATCH 22/27] namespace check --- sklearn/utils/tests/test_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 3d636d250c324..feb0d556a4d7e 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -626,6 +626,6 @@ def test_median(namespace, device, dtype_name, axis): if xp.__name__ != "array_api_strict": # We covert array-api-strict arrays to numpy arrays as `median` is not # part of the Array API spec - assert get_namespace(result_xp)[0].__name__ == xp.__name__ + assert get_namespace(result_xp)[0] == xp assert result_xp.device == X_xp.device assert_allclose(result_np, _convert_to_numpy(result_xp, xp=xp)) From 232d1b0333dae9852b39ef38bfc66d5925eaf2f3 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 30 May 2025 11:09:17 +1000 Subject: [PATCH 23/27] fix median --- sklearn/utils/_array_api.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index e341f1d6b496b..3abdff9f4038f 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -674,16 +674,18 @@ def _median(x, axis=None, keepdims=False, xp=None): # array libraries, and all that we support (as of May 2025). xp, _ = get_namespace(x, xp=xp) - if _is_xp_namespace(xp, "array_api_strict"): - x_np = numpy.asarray(xp.asarray(x, device=xp.Device("CPU_DEVICE"))) - return numpy.median(x_np, axis=axis, keepdims=keepdims) - # `torch.median` takes the lower of the two medians when `x` has even number # of elements, thus we use `torch.quantile(q=0.5)`, which gives mean of the two if array_api_compat.is_torch_namespace(xp): return xp.quantile(x, q=0.5, dim=axis, keepdim=keepdims) - return xp.median(x, axis=axis, keepdims=keepdims) + if hasattr(xp, "median"): + return xp.median(x, axis=axis, keepdims=keepdims) + + # Intended mostly for array-api-strict, which as no "median", as per the spec, + # as `_convert_to_numpy` does not necessarily work for all array types. + x_np = _convert_to_numpy(x) + return numpy.median(x_np, axis=axis, keepdims=keepdims) def _xlogy(x, y, xp=None): From 150c14c5cf6a831af69ec1fd2bfc55082a15b0a7 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 30 May 2025 11:11:21 +1000 Subject: [PATCH 24/27] fix med output --- sklearn/utils/_array_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 3abdff9f4038f..6b48a26a85269 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -672,7 +672,7 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None): def _median(x, axis=None, keepdims=False, xp=None): # `median` is not included in the Array API spec, but is implemented in most # array libraries, and all that we support (as of May 2025). - xp, _ = get_namespace(x, xp=xp) + xp, _, device = get_namespace_and_device(x, xp=xp) # `torch.median` takes the lower of the two medians when `x` has even number # of elements, thus we use `torch.quantile(q=0.5)`, which gives mean of the two @@ -685,7 +685,7 @@ def _median(x, axis=None, keepdims=False, xp=None): # Intended mostly for array-api-strict, which as no "median", as per the spec, # as `_convert_to_numpy` does not necessarily work for all array types. x_np = _convert_to_numpy(x) - return numpy.median(x_np, axis=axis, keepdims=keepdims) + return xp.asarray(numpy.median(x_np, axis=axis, keepdims=keepdims), device=device) def _xlogy(x, y, xp=None): From b1dcc531da32a42fce5c6db96e8d2ce7d97f49f6 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Fri, 30 May 2025 09:59:21 +0500 Subject: [PATCH 25/27] Update sklearn/utils/_array_api.py --- sklearn/utils/_array_api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 6b48a26a85269..46844dd9b826b 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -682,9 +682,9 @@ def _median(x, axis=None, keepdims=False, xp=None): if hasattr(xp, "median"): return xp.median(x, axis=axis, keepdims=keepdims) - # Intended mostly for array-api-strict, which as no "median", as per the spec, - # as `_convert_to_numpy` does not necessarily work for all array types. - x_np = _convert_to_numpy(x) + # Intended mostly for array-api-strict, which has no "median", as per the spec, + # as `_convert_to_numpy` does not generically work for all array types. + x_np = _convert_to_numpy(x, xp=xp) return xp.asarray(numpy.median(x_np, axis=axis, keepdims=keepdims), device=device) From de7ccbbfecd37e54689f461e60b64ef723ce5b98 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Mon, 2 Jun 2025 21:24:34 +1000 Subject: [PATCH 26/27] review --- sklearn/utils/_array_api.py | 11 +++++++---- sklearn/utils/tests/test_array_api.py | 7 ++++++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 46844dd9b826b..fd70cce80f4ac 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -670,8 +670,11 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None): def _median(x, axis=None, keepdims=False, xp=None): - # `median` is not included in the Array API spec, but is implemented in most - # array libraries, and all that we support (as of May 2025). + # XXX: `median` is not included in the array API spec, but is implemented + # in most array libraries, and all that we support (as of May 2025). + # TODO: simplify this code to use scipy instead once the oldest supported + # SciPy version provides `scipy.stats.quantile` with native array API + # support (likely scipy 1.6 at the time of writing). xp, _, device = get_namespace_and_device(x, xp=xp) # `torch.median` takes the lower of the two medians when `x` has even number @@ -682,8 +685,8 @@ def _median(x, axis=None, keepdims=False, xp=None): if hasattr(xp, "median"): return xp.median(x, axis=axis, keepdims=keepdims) - # Intended mostly for array-api-strict, which has no "median", as per the spec, - # as `_convert_to_numpy` does not generically work for all array types. + # Intended mostly for array-api-strict (which as no "median", as per the spec) + # as `_convert_to_numpy` does not necessarily work for all array types. x_np = _convert_to_numpy(x, xp=xp) return xp.asarray(numpy.median(x_np, axis=axis, keepdims=keepdims), device=device) diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index feb0d556a4d7e..4d74b0bf8db43 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -613,10 +613,15 @@ def test_sparse_device(csr_container, dispatch): ) @pytest.mark.parametrize("axis", [None, 0, 1]) def test_median(namespace, device, dtype_name, axis): + # Note: depending on the value of `axis`, this test will compare median + # computations on arrays of even (4) or odd (5) numbers of elements, hence + # will test for median computation with and without interpolation to check + # that array API namespaces yield consistent results even when the median is + # not mathematically uniquely defined. xp = _array_api_for_tests(namespace, device) rng = numpy.random.RandomState(0) - X_np = numpy.array(rng.random_sample((5, 4)), dtype=dtype_name) + X_np = rng.uniform(low=0.0, high=1.0, size=(5, 4)).astype(dtype_name) result_np = numpy.median(X_np, axis=axis) X_xp = xp.asarray(X_np, device=device) From 15b8d231a27b87217d83fb10fdcaa578d77eff47 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Mon, 2 Jun 2025 22:28:40 +1000 Subject: [PATCH 27/27] review --- sklearn/utils/_array_api.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index fd70cce80f4ac..e2bee3530f26f 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -672,9 +672,11 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None): def _median(x, axis=None, keepdims=False, xp=None): # XXX: `median` is not included in the array API spec, but is implemented # in most array libraries, and all that we support (as of May 2025). - # TODO: simplify this code to use scipy instead once the oldest supported - # SciPy version provides `scipy.stats.quantile` with native array API - # support (likely scipy 1.6 at the time of writing). + # TODO: consider simplifying this code to use scipy instead once the oldest + # supported SciPy version provides `scipy.stats.quantile` with native array API + # support (likely scipy 1.6 at the time of writing). Proper benchmarking of + # either option with popular array namespaces is required to evaluate the + # impact of this choice. xp, _, device = get_namespace_and_device(x, xp=xp) # `torch.median` takes the lower of the two medians when `x` has even number