[go: up one dir, main page]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH Use Array API in r2_score #27904

Merged
merged 106 commits into from
Mar 11, 2024
Merged

Conversation

fcharras
Copy link
Contributor
@fcharras fcharras commented Dec 5, 2023

Reference Issues/PRs

The PR builds on preliminary explorations done by @elindgren in #27102

It tackles one of the items outlined in #26024.

Any other comments?

This PR proposes to fallbacks to cpu+numpy at the very beginning of the r2_score function whenever the array namespace and the device can't handle float64 precision, because explicit castings to float64 are unavoidable and are used in a lot of steps.

It also proposes improved ways to detect device support for dtypes, and uses it to act accordingly in r2_score and _average, but also updates weighted_sum function.

elindgren and others added 20 commits August 18, 2023 14:16
Co-authored-by: Tim Head <betatim@gmail.com>
Co-authored-by: Tim Head <betatim@gmail.com>
Some Array API compatible libraries do not have a device called 'cpu'.
Instead we try and detect the lib+device combination that does not
support float64.
…test_affinity_propagation` (scikit-learn#27095)

Signed-off-by: Julien Jerphanion <git@jjerphan.xyz>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Copy link
github-actions bot commented Dec 5, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 457531e. Link to the linter CI: here

@fcharras fcharras marked this pull request as ready for review December 6, 2023 11:28
@ogrisel
Copy link
Member
ogrisel commented Mar 8, 2024

@betatim @fcharras @adrinjalali I think this is ready for review.

Copy link
Member
@betatim betatim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some comments. Haven't looked at the tests yet but the rest looks nice.

sklearn/metrics/_classification.py Outdated Show resolved Hide resolved
sklearn/utils/_array_api.py Outdated Show resolved Hide resolved
sklearn/metrics/_regression.py Show resolved Hide resolved
sklearn/metrics/_regression.py Show resolved Hide resolved
sklearn/utils/_array_api.py Outdated Show resolved Hide resolved
sklearn/utils/_array_api.py Outdated Show resolved Hide resolved
a = xp.astype(a, output_dtype)

if weights is None:
return (xp.mean if normalize else xp.sum)(a, axis=axis)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😮

Kinda cool that we can do this in Python, but also strong stuff :D

sklearn/utils/_array_api.py Show resolved Hide resolved
sklearn/utils/_array_api.py Outdated Show resolved Hide resolved
Copy link
Member
@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. Other than the nits, this looks quite good to me now.

sklearn/metrics/_classification.py Outdated Show resolved Hide resolved
sklearn/metrics/_classification.py Outdated Show resolved Hide resolved
sklearn/metrics/_regression.py Show resolved Hide resolved
Copy link
Member
@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ogrisel

@ogrisel
Copy link
Member
ogrisel commented Mar 11, 2024

Thanks for the reviews @adrinjalali and @betatim. This PR has been much simplified as a result of those reviews.

doc/modules/array_api.rst Outdated Show resolved Hide resolved
Copy link
Member
@betatim betatim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ogrisel ogrisel enabled auto-merge (squash) March 11, 2024 15:19
@ogrisel ogrisel merged commit 612d93d into scikit-learn:main Mar 11, 2024
30 checks passed
@betatim
Copy link
Member
betatim commented Mar 11, 2024

What a mission this PR was! Thanks everyone who helped, I think it was worth the effort and wait :D

@fcharras
Copy link
Contributor Author
fcharras commented Mar 14, 2024

Thanks everyone for continuing this PR, I now caught up with latest diff and I'm also a happy bunny.

I want to mention 2 differences I think I've spotted between the state where I had left the branch and what has been merged:

  • in numpy an operation on array inputs of mixed dtypes that include int and float will always result in float64. (I think the rationale is that float64 is better because it's less likely to have issues with large ints) The behavior with array api dispatch that has been merged here is so that it will instead result in an output with the default float dtype (so e.g float32 with torch) rather than forcing float64 or keeping int dtype. Which is most likely fine for scikit learn usecases and simpler overall.

  • the error messages in _average where mimicking those of np.average, the merged _average have improved error messages but diverge slightly from np.average in this regard now. Which is ok but I just wanted to mention the reason the error messages had this original shape to begin with.

I'm happy that I got to learn the existence of xp.result_type 👍 .

We did make very conservative choices on this PR initially and in the end that was a source of several iterations, I'll try to be better at thinking at what is actually needed in scikit-learn and the cost in complexity.

As it has been pointed out already, I suspect some of the tools that we introduced then dropped during the PR (like _support_dtype which is still somewhere in the history of commits) will end up being necessary in other PRs ?

Last word, we had initiated documenting the policy when dealing with array api dispatch with no float64 support at #28034 , now I'm a bit unsure if #27904 (comment) had everyone aligned or if it moved to something a bit different in this PR, I'll try to sum up again and update it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

10 participants