Add array API support to `median_absolute_error` by lucyleeow · Pull Request #31406 · scikit-learn/scikit-learn · GitHub

Add array API support to median_absolute_error #31406

Merged
merged 28 commits into from
Jun 3, 2025

Conversation

lucyleeow
Member

Reference Issues/PRs

Towards #26024

What does this implement/fix? Explain your changes.

Add array API support to median_absolute_error. (Currently the only change made was to add an array API supporting _median function, see below.)

Any other comments?

This is the only metric to use median, however median is used in a fair number of estimators. I think the first item to address is which median should we use.

Array API spec currently does not support median so these are our options:

  • Write our own median function (that uses np.median when the namespace is numpy). This is what this PR includes; the downside is the maintenance burden.
  • Use our _weighted_percentile. This is slow.
  • Push for median inclusion in the array API. Admittedly, median is not used much outside of scikit-learn (RFC: array-agnostic quantile data-apis/array-api#795 (comment)), BUT it seems that most (all?) array libraries have an implementation. I would be in favour of pushing for inclusion, less because of usage and more because the implementation of median is well defined (vs. e.g. quantile) and I think other array libraries, including dask, already have an implementation. They may be open to this: RFC: array-agnostic quantile data-apis/array-api#795 (comment)

Here is some benchmarking I did with numpy and cupy arrays. I wanted to increase the size of the arrays tested and include the new scipy quantile (which supports array API but not weights) as a reference, as I think we ultimately want to use it, but I ran out of GPU time in colab 🙃
Also, maybe I should have included torch CPU in the mix?

(Randomly generated 1D array)

| | Numpy (1e7) | CuPy (1e7) |
| --- | --- | --- |
| sklearn _median | 0.182784s | 0.017168s |
| sklearn _weighted_percentile | 2.369427s | 0.088325s |
| Cupy median | n/a | 0.015946s |

github-actions bot commented May 21, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 15b8d23. Link to the linter CI: here


# Use mean in both odd and even case to coerce data type,
# using out array if needed.
rout = xp.mean(X_sorted[indexer], axis=axis)
Member Author

Technically the spec states that NaNs are propagated (https://data-apis.org/array-api/latest/API_specification/generated/array_api.mean.html#mean) but there is also a note that says:

Array libraries, such as NumPy, PyTorch, and JAX, currently deviate from this specification in their handling of components which are NaN when computing the arithmetic mean.
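
To make the quoted snippet concrete, here is a minimal 1D sketch of the sort-then-mean approach, using NumPy as a stand-in for `xp`. The name `sorted_median_1d` is hypothetical and this is not the code in this PR; it assumes a non-empty 1D floating-point array and only illustrates taking the mean of the one or two middle elements so that odd and even lengths share one code path.

```python
import numpy as np


def sorted_median_1d(x, xp=np):
    """Median of a non-empty 1D float array via sort + mean (illustrative only)."""
    x_sorted = xp.sort(x)
    n = x_sorted.shape[0]
    mid = n // 2
    # Odd length: one middle element; even length: the two middle elements.
    start = mid if n % 2 == 1 else mid - 1
    middle = x_sorted[start : mid + 1]
    # Taking the mean in both cases gives a floating-point result.
    return xp.mean(middle)


print(sorted_median_1d(np.asarray([3.0, 1.0, 2.0])))
print(sorted_median_1d(np.asarray([0.0, 1.0, 2.0, 3.0])))
```

Whether NaNs propagate through such a path depends on where the library's sort places them and on how its mean treats them, which is the concern raised in the comment above.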

@lucyleeow
Member Author

@betatim
Member
betatim commented May 21, 2025

I think writing our own median is fine. We could also contribute it to array-api-extra?

Is the change to median_absolute_error still to come or is this really all we need to do?

@lucyleeow
Member Author

Thanks for responding @betatim !

More changes are required; I just wanted to start the conversation about median and thought it would be nice to see what a re-implementation would look like, with context, so I opened this PR instead of an issue.

I like the array-api-extra idea, it opens a conversation re: median at least. I am of the opinion that it would be worthwhile asking for it to be added to the spec as it is implemented already in all the array libraries listed under 'actively considered' here, and e.g., dask's own implementation does stuff with chunking (ref):

This works by automatically chunking the reduced axes to a single chunk if necessary and then calling numpy.median function across the remaining dimensions

and I would rather let dask deal with that than try to implement it in array-api-extra. (Though I guess the median function in array-api-extra could always just pass it to dask's own median function, but that almost seems to defeat the point.)

The only concern I have about pushing for adding it to the spec is that it is not used much in other libraries outside of scikit-learn.

Contributor
@OmarManzoor OmarManzoor left a comment

I think adding median in our own utils or contributing to array-api-extra are both fine.

@betatim
Member
betatim commented May 23, 2025

Maybe we should do both: our own/array api extra implementation (while we wait for the standard) and ask for it to be added to the standard? I somehow dismissed the idea of adding it because it takes a long time :-/

The array-api-extra implementation could even just forward the call to the array library if it implements median
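
A rough sketch of what such forwarding could look like, assuming a namespace object `xp` is already available. The function `median_with_fallback` and the toy `NoMedianNamespace` class are hypothetical, and the NumPy round-trip in the fallback is an assumption that will not hold for every array type or device.

```python
import numpy as np


def median_with_fallback(x, xp, axis=None):
    """Forward to `xp.median` if present, else compute via NumPy (sketch)."""
    if hasattr(xp, "median"):
        return xp.median(x, axis=axis)
    # Fallback for namespaces without `median`: round-trip through NumPy.
    # Assumes `np.asarray(x)` works and that `xp.asarray` accepts a NumPy
    # array, which is not guaranteed for every library or device.
    return xp.asarray(np.median(np.asarray(x), axis=axis))


class NoMedianNamespace:
    """Toy namespace without `median`, used only to exercise the fallback."""

    asarray = staticmethod(np.asarray)


x = np.asarray([0.0, 1.0, 2.0, 3.0])
print(median_with_fallback(x, np))
print(median_with_fallback(x, NoMedianNamespace))
```

A real implementation would likely need per-library special cases, since libraries do not all agree on what `median` returns for even-length input.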

@lucyleeow
Member Author

I somehow dismissed the idea of adding it because it takes a long time :-/

I totally forgot about this part.

Okay let's move forward with our own implementation and also starting a discussion about adding median to the spec.

@ogrisel
Member
ogrisel commented May 28, 2025

BTW we need to use _averaged_weighted_percentile instead of _weighted_percentile when sample_weight is not None. This is to ensure "centered" results when calling it on data with an even number of weighted predictions. This would fix the discrepancy with the unweighted case in situations where they should return the same results:

>>> from sklearn.metrics import median_absolute_error
>>> import numpy as np
>>> median_absolute_error(np.zeros(4), np.arange(4))
1.5
>>> median_absolute_error(np.zeros(4), np.arange(4), sample_weight=np.ones(4))
1.0

But this can be addressed in a dedicated PR: it impacts both numpy and other array namespaces.
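
The 1.0 versus 1.5 discrepancy can be reproduced with a small NumPy sketch. The two helpers below are hypothetical and are not scikit-learn's `_weighted_percentile` or `_averaged_weighted_percentile`; they assume a "lower" weighted median that picks the first sorted value whose cumulative weight reaches half of the total, and an averaged variant that combines it with the same computation on the negated data.

```python
import numpy as np


def lower_weighted_median(values, weights):
    """First sorted value whose cumulative weight reaches half the total."""
    order = np.argsort(values)
    sorted_values = values[order]
    cumulative = np.cumsum(weights[order])
    idx = np.searchsorted(cumulative, 0.5 * cumulative[-1], side="left")
    return float(sorted_values[idx])


def averaged_weighted_median(values, weights):
    """Average of the lower median and the mirrored (upper) median."""
    lower = lower_weighted_median(values, weights)
    upper = -lower_weighted_median(-values, weights)
    return 0.5 * (lower + upper)


abs_errors = np.abs(np.zeros(4) - np.arange(4))  # [0., 1., 2., 3.]
unit_weights = np.ones(4)

print(lower_weighted_median(abs_errors, unit_weights))     # lower convention
print(averaged_weighted_median(abs_errors, unit_weights))  # centered convention
print(np.median(abs_errors))                               # unweighted reference
```

With unit weights and an even number of samples, only the averaged variant agrees with the unweighted median in this sketch.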

Comment on lines 2257 to 2271
if (
    getattr(metric, "__name__", None) == "median_absolute_error"
    and array_namespace == "array_api_strict"
):
    try:
        import array_api_strict
    except ImportError:
        pass
    else:
        if device == array_api_strict.Device("device1"):
            # See https://github.com/data-apis/array-api-strict/issues/134
            pytest.xfail(
                "`_weighted_percentile` is affected by array_api_strict bug when "
                "indexing with tuple of arrays on non-'CPU_DEVICE' devices."
            )
Member Author
@lucyleeow lucyleeow May 28, 2025

This is not ideal. We have a similar xfail in a _weighted_percentile test:

if array_namespace == "array_api_strict":
    try:
        import array_api_strict
    except ImportError:
        pass
    else:
        if device == array_api_strict.Device("device1"):
            # See https://github.com/data-apis/array-api-strict/issues/134
            pytest.xfail(
                "array_api_strict has bug when indexing with tuple of arrays "
                "on non-'CPU_DEVICE' devices."
            )

Note that as we add array API support for more regression metrics, we will need to add them to the xfail, as several others use _weighted_percentile.

(Note that this bug has been fixed but we'd need a new release of array-api-strict to see it)

@ev-br - you did mention there could be an array-api-strict release soon though?

Contributor
@OmarManzoor OmarManzoor left a comment

Thanks for the PR @lucyleeow
A few comments

@lucyleeow
Member Author
lucyleeow commented May 30, 2025

BTW we need to use _averaged_weighted_percentile instead of _weighted_percentile when sample_weight is not None.

(ref: #31406 (comment))

FYI @ogrisel, that will be fixed in #30787: once median_absolute_error is tested correctly, the difference will show up. (Edit: I first thought the test failure was due to a different problem and would need a follow-up PR 😬, but there were two problems, and this one will be fixed in #30787.)

@lucyleeow
Member Author

Thank you @OmarManzoor ! CI is finally green 😅 !

@OmarManzoor OmarManzoor added CUDA CI and removed Needs Decision Requires decision labels May 30, 2025
@github-actions github-actions bot removed the CUDA CI label May 30, 2025
Contributor
@OmarManzoor OmarManzoor left a comment

Looks good now. Thank you for the PR @lucyleeow

Member
@ogrisel ogrisel left a comment

LGTM with a few comments. Thanks @lucyleeow!

Comment on lines 673 to 674
# `median` is not included in the Array API spec, but is implemented in most
# array libraries, and all that we support (as of May 2025).
Member
@ogrisel ogrisel Jun 2, 2025

Suggested change
- # `median` is not included in the Array API spec, but is implemented in most
- # array libraries, and all that we support (as of May 2025).
+ # XXX: `median` is not included in the array API spec, but is implemented
+ # in most array libraries, and all that we support (as of May 2025).
+ # TODO: consider simplifying this code to use scipy instead once the oldest
+ # supported SciPy version provides `scipy.stats.quantile` with native array API
+ # support (likely scipy 1.6 at the time of writing). Proper benchmarking of
+ # either option with popular array namespaces is required to evaluate the
+ # impact of this choice.

Member Author
@lucyleeow lucyleeow Jun 2, 2025

Interesting.
I did think about what we should do in the future. Ralf suggested that quantile could go into array-api-extra (I need to open an RFC issue about this there).

Despite most array libraries having a median (including dask), which may be somewhat faster than scipy.stats.quantile (benchmarking required), would you still be inclined to use scipy.stats.quantile over median of the native library?
(I don't think median will be in the spec, mostly because quantile is more versatile and torch's version of median implementation differs from everyone else)
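
For context on the torch remark: `torch.median` is documented to return the lower of the two middle values for even-length input, whereas NumPy's median averages them. The NumPy-only sketch below contrasts the two conventions; the helper `lower_median` is hypothetical and is written here only to mimic the "lower" convention without requiring torch.

```python
import numpy as np


def lower_median(x):
    """Return the lower of the two middle values for even-length 1D input."""
    x_sorted = np.sort(x)
    return x_sorted[(x_sorted.shape[0] - 1) // 2]


x = np.asarray([0.0, 1.0, 2.0, 3.0])
print(lower_median(x))  # lower middle value
print(np.median(x))     # mean of the two middle values
```

For odd-length input the two conventions coincide, so the difference only matters when the number of elements is even.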

Member
@ogrisel ogrisel Jun 2, 2025

I don't know for sure, we would need some proper evaluation of different options. Maybe I can rephrase the suggestion to be less assertive in the comment.

EDIT: done.

@betatim betatim added the CUDA CI label Jun 3, 2025
@github-actions github-actions bot removed the CUDA CI label Jun 3, 2025
@betatim
Member
betatim commented Jun 3, 2025

Looks like all the discussion topics were addressed/resolved and the robots are happy -> merging

edit: sorry, Olivier's suggestion hadn't actually been committed (I was tricked by the "outdated" marker). Should I open a new PR to add that?

edit edit: maybe it was actually done.

@betatim betatim merged commit 5c21794 into scikit-learn:main Jun 3, 2025
40 checks passed
@lucyleeow lucyleeow deleted the aapi_med_abs_er branch June 4, 2025 00:20
@lucyleeow
Member Author

Thanks @betatim !
