-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
Add array API support to median_absolute_error
#31406
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
sklearn/utils/_array_api.py
Outdated
|
||
# Use mean in both odd and even case to coerce data type, | ||
# using out array if needed. | ||
rout = xp.mean(X_sorted[indexer], axis=axis) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Technically the spec states that NaNs are propagated (https://data-apis.org/array-api/latest/API_specification/generated/array_api.mean.html#mean) but there is also a note that says:
Array libraries, such as NumPy, PyTorch, and JAX, currently deviate from this specification in their handling of components which are NaN when computing the arithmetic mean.
I think writing our own Is the change to |
Thanks for responding @betatim ! More changes required, I just wanted to start the conversation about I like the
and I would rather let dask deal with that, then try to implement in The only concern I have about pushing for adding it to the spec is that it is not used much in other libraries outside of scikit-learn. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think adding median in our own utils or contributing to array-api-extra are both fine.
Maybe we should do both: our own/array api extra implementation (while we wait for the standard) and ask for it to be added to the standard? I somehow dismissed the idea of adding it because it takes a long time :-/ The array-api-extra implementation could even just forward the call to the array library if it implements |
I totally forgot about this part. Okay let's move forward with our own implementation and also starting a discussion about adding median to the spec. |
BTW we need to use >>> from sklearn.metrics import median_absolute_error
>>> import numpy as np
>>> median_absolute_error(np.zeros(4), np.arange(4))
1.5
>>> median_absolute_error(np.zeros(4), np.arange(4), sample_weight=np.ones(4))
1.0 But this can be addressed in a dedicated PR: it impacts both numpy and other array namespaces. |
sklearn/metrics/tests/test_common.py
Outdated
if ( | ||
getattr(metric, "__name__", None) == "median_absolute_error" | ||
and array_namespace == "array_api_strict" | ||
): | ||
try: | ||
import array_api_strict | ||
except ImportError: | ||
pass | ||
else: | ||
if device == array_api_strict.Device("device1"): | ||
# See https://github.com/data-apis/array-api-strict/issues/134 | ||
pytest.xfail( | ||
"`_weighted_percentile` is affected by array_api_strict bug when " | ||
"indexing with tuple of arrays on non-'CPU_DEVICE' devices." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not ideal. We have a similar xfail in a _weighted_percentile
test:
scikit-learn/sklearn/utils/tests/test_stats.py
Lines 186 to 196 in 398e8fe
if array_namespace == "array_api_strict": | |
try: | |
import array_api_strict | |
except ImportError: | |
pass | |
else: | |
if device == array_api_strict.Device("device1"): | |
# See https://github.com/data-apis/array-api-strict/issues/134 | |
pytest.xfail( | |
"array_api_strict has bug when indexing with tuple of arrays " | |
"on non-'CPU_DEVICE' devices." |
Note that as we add array API support for more regression metrics, we will need to add them to the xfail, as several others use _weighted_percentile
.
(Note that this bug has been fixed but we'd need a new release of array-api-strict to see it)
@ev-br - you did mention there could be an array-api-strict release soon though?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR @lucyleeow
A few comments
(ref: #31406 (comment)) FYI @ogrisel that will be fixed in #30787, as once |
Thank you @OmarManzoor ! CI is finally green 😅 ! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good now. Thank you for the PR @lucyleeow
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with a few comments. Thanks @lucyleeow!
sklearn/utils/_array_api.py
Outdated
# `median` is not included in the Array API spec, but is implemented in most | ||
# array libraries, and all that we support (as of May 2025). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# `median` is not included in the Array API spec, but is implemented in most | |
# array libraries, and all that we support (as of May 2025). | |
# XXX: `median` is not included in the array API spec, but is implemented | |
# in most array libraries, and all that we support (as of May 2025). | |
# TODO: consider simplifying this code to use scipy instead once the oldest | |
# supported SciPy version provides `scipy.stats.quantile` with native array API | |
# support (likely scipy 1.6 at the time of writing). Proper benchmarking of | |
# either option with popular array namespaces is required to evaluate the | |
# impact of this choice. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting.
I did think about what we should do in the future. Ralf suggested maybe quantile
should go into array api extra (I need to open a RFC issue about this in array api extra).
Despite most array libraries having a median
(including dask), which may be somewhat faster than scipy.stats.quantile
(benchmarking required), would you still be inclined to use scipy.stats.quantile
over median
of the native library?
(I don't think median
will be in the spec, mostly because quantile
is more versatile and torch's version of median
implementation differs from everyone else)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know for sure, we would need some proper evaluation of different options. Maybe I can rephrase the suggestion to be less assertive in the comment.
EDIT: done.
Looks like all the discussion topics were addressed/resolved and the robots are happy -> merging edit: sorry, Olivier's suggestion hadn't actually been committed (I was tricked by the "outdated" marker). Should I open a new PR to add that? edit edit: maybe it was actually done. |
Thanks @betatim ! |
Reference Issues/PRs
Towards #26024
What does this implement/fix? Explain your changes.
Add array API support to
median_absolute_error
. (Currently the only change made was to add an array API supporting_median
function, see below.)Any other comments?
This is the only metric to use
median
, howevermedian
is used in a fair number of estimators. I think the first item to address is whichmedian
should we use.Array API spec currently does not support
median
so these are our options:median
function (that usesnp.median
when namespace is numpy) - included in this PR, maintenance_weighted_percentile
- slowmedian
inclusion in array API. Admittedly,median
is not used much outside of scikit-learn (RFC: array-agnosticquantile
data-apis/array-api#795 (comment)), BUT it seems that most (all?) array libraries have an implementation. I would be in favour of pushing for inclusion, less so because of use, and more so because the implementation ofmedian
is well defined (vs e.g. quantile) and I think other array libraries do have an implementation, including dask. They may be open to this: RFC: array-agn 8000 osticquantile
data-apis/array-api#795 (comment)Here are some benchmarking I did with numpy and cupy arrays. I wanted to increase the size of the arrays tested and include the new scipy quantile (which supports array API but not weights - as a reference, as I think we ultimately want to use this) but I ran out of GPU time in colab 🙃
Also maybe I should have also included torch CPU in the mix?
(Randomly generated 1D array)
_median
_weighted_percentile_
median