-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
ENH Use Array API in r2_score
#27904
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
106 commits
Select commit
Hold shift + click to select a range
e0429db
update r2 score to use the array API, and write initial tests
elindgren b9c1720
Merge remote-tracking branch 'upstream/main'
elindgren 5666ce5
Merge branch 'main' into ENH/r2_score_array_api
ogrisel 4580d1c
Merge branch 'main' into ENH/r2_score_array_api
ogrisel a4dd594
Fix some review comments and move stuff to CPU
elindgren adc7680
Add regression tests to the test_common framework
elindgren 85469a9
Update sklearn/metrics/tests/test_regression.py
elindgren b7efaa5
Update sklearn/metrics/tests/test_regression.py
elindgren ac533c2
Remove hardcoded device choice in _weighted_sum
betatim 35be22e
Factor out max float precision determination
betatim 7c53e19
Use convenience function to find highest accuracy float in r2_score
elindgren 230ae46
add tests for _average for Array API
elindgren e4672d1
MNT Ignore ruff errors (#27094)
lesteve 8ba9485
DOC fix docstring for `sklearn.datasets.get_data_home` (#27073)
kachayev 490e0b4
TST Extend tests for `scipy.sparse.*array` in `sklearn/cluster/tests/…
jjerphan a8a820c
MNT Remove DeprecationWarning for scipy.sparse.linalg.cg tol vs rtol …
lesteve 552e421
Merge branch 'main' into ENH/r2_score_array_api
elindgren ff52710
Merge remote-tracking branch 'upstream/main' into ENH/r2_score_array_api
elindgren fe9cc1c
remove temporary file
elindgren 93257ba
WIP: solving dtype and device maze
fcharras 45bbe4e
Fix changelog conflict
fcharras 2145a6b
Tests fixups
fcharras bd4b224
Tests fixups
fcharras 34aceb1
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras 56d5308
Fix dtype parameterization in common metric tests
fcharras 75cb3f3
Tests fixups
fcharras d9fff24
Tests fixups
fcharras d72137c
Adds lru_cache on device inspection function + user _convert_to_numpy…
fcharras 16ab95f
Adequatly define hash of _ArrayAPIWrapper to avoid wrong equality
fcharras 9862a85
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras 143ce54
Remove _weighted_sum and only use _average
fcharras 4e9401b
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras 2b095c4
Linting on unrelated diff, pre-commit broken ? + fixes
fcharras 42f5d8d
Merge branch 'main' into ENH/r2_score_array_api
fcharras ff0b860
re add faster, simpler code branch for _weighted_sum in _classificati…
fcharras efe36f3
re add faster, simpler code branch for _weighted_sum in _classificati…
fcharras abb9ee9
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras 08f5433
fix
fcharras 38f56af
fix tests with torch+cuda
fcharras c09a84b
fix tests with torch+cuda
fcharras 13d9bd6
Merge branch 'main' into ENH/r2_score_array_api
fcharras c32fa92
FIX: always pass xp to _convert_to_numpy calls
ogrisel 1555f8d
FIX also update device_ in case of numpy fallback
ogrisel fc1b9f1
FIX pass xp to _convert_to_numpy instead of copy=True
ogrisel 1bf557d
Rename _weighted_sum to _weighted_sum_1d to make it explicit that tho…
ogrisel c41694a
Improve test coverage for _average function + some review changes
fcharras c71c3ce
fix torch+cuda
fcharras d2cd3ca
Merge branch 'main' into ENH/r2_score_array_api
fcharras 4be2ac0
Fix docstring formatting
fcharras 29260e1
Fix error for arrays on different devices
fcharras e47d53c
Merge branch 'main' into ENH/r2_score_array_api
fcharras ccbc92d
Adapt device inspection function to non hashable device objects
fcharras b09b653
Merge branch 'ENH/r2_score_array_api' of https://github.com/fcharras/…
fcharras 0b5b550
CI Remove unused mkl_no_coverage lock file (#28061)
lesteve 2266348
Fix device inspection function + adapt test to non-hashable device ob…
fcharras db22354
Merge branch 'ENH/r2_score_array_api' of https://github.com/fcharras/…
fcharras fc3b6e9
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras 6ff37fb
Apply suggestion
fcharras 3cda292
Fix device inspection test
fcharras 6179f10
Merge branch 'main' into ENH/r2_score_array_api
ogrisel bcaa3d8
modify changelog
glemaitre 647109c
Merge remote-tracking branch 'origin/main' into pr/fcharras/27904
glemaitre fc51090
Apply non-controversial suggestions from code review
ogrisel df14fca
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras db8a046
Adress review comments. NB:
fcharras 2c12856
fixup
fcharras cd53bd6
Factorize array filtering by type for get_namespace and device helpers
ogrisel 0d1c3bf
Do not upcast partial sums to float64 in r2_score
ogrisel 40dd9d1
Skip strings by default and rename private helper
ogrisel ac07d4c
WIP fixing type promotion logic
ogrisel 9ebc1ff
Merge branch 'main' into ENH/r2_score_array_api
ogrisel 3550daf
Fix use implementation defined default floating dtype
ogrisel 33177dd
Update test and remove non-reachable branch
ogrisel f35ea45
Fix error message when the default array filter leads to an empty lis…
ogrisel a67fe45
Use informative error message in _average while starting with the sam…
ogrisel d583d9e
Factorize floating point type promotion logic
ogrisel bae25f0
Merge branch 'main' into ENH/r2_score_array_api
ogrisel 41f99d2
Fix adapt dtype matching logic to non-array inputs, prior to the call…
ogrisel 87e4c8d
Simplification
ogrisel 1c2ea78
Remove device-specific dtype support as its no longer needed by r2_score
ogrisel c2cbd98
More simplifications
ogrisel 6429401
Improve numerical stability by scaling the weights prior to using the…
ogrisel 6636e4c
Fix test_nan_reductions
ogrisel be5a474
Revert "Improve numerical stability by scaling the weights prior to u…
ogrisel 6a728ac
Fix formatting
ogrisel 1d4c49e
Skip test_average_raises_with_wrong_dtype on cupy
ogrisel 98347c1
Simplify back _isdtype_single
ogrisel aff4840
Grammar.
ogrisel ad0a1fb
Need to conver to float explicitly
ogrisel d596494
Factorize the float conversion into _assemble_r2_explained_variance
ogrisel d6f0101
Move tuple conversion at the beginning of _skip_non_arrays
ogrisel ec84e44
Small fixes in comments and remove duplicated lines.
ogrisel 08405a5
One more get_namespace simplification
ogrisel a09866d
Remove useless import added by vs code...
ogrisel b59a7be
Apply suggestions from code review
ogrisel ef1631b
Rename _skip_non_arrays to _remove_non_arrays & co
ogrisel 388d670
Remove custom __hash__ method that is no longer needed
ogrisel 8042795
Remove redundant calls to xp.astype
ogrisel 92af1a8
Factorize the if xp is None: xp, _ = get_namespace(inputs) pattern
ogrisel 47fed64
Fix handling of xp is not None in get_namespace
ogrisel 3699353
get_namespace in _weighted_sum_1d
ogrisel c2b4b11
Merge _weighted_sum_1d into _average
ogrisel 9c2d9ac
One final 'if xp is None' occurrence
ogrisel 90076d3
DOC be explicit about return types
ogrisel 3cc74e5
Merge branch 'main' into ENH/r2_score_array_api
ogrisel 457531e
Update phrasing in the doc to avoid confusing array container type wi…
ogrisel File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.