-
-
Notifications
You must be signed in to change notification settings - Fork 26.1k
FEA Add array API support for GaussianMixture #30777
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
103 commits
Select commit
Hold shift + click to select a range
b04a9f7
wip
lesteve e6ba4e4
wip
lesteve 2226a55
stuck on linalg.cholesky array API support
lesteve b1fdee7
a bit further with xp.cholesky but now linalg.solve_triangular
lesteve 14fb0ba
more array api
StefanieSenger 6010ff7
wip (problem with weights as numpy arrays)
lesteve aa2a383
array api for covariance_type='diag' and init_params='random'
StefanieSenger de4f3a5
add simple test
StefanieSenger 7974931
Add comments about tricky bits
lesteve 08e5f9b
lint
lesteve 0f525ef
one more comment
lesteve 4801e2b
revert unwanted change
lesteve de1343c
fix test_bayesian_mixture
lesteve b05eca0
Compare to numpy result in test
lesteve c35bdd6
Use global_random_seed
lesteve 4516920
retrigger CI
StefanieSenger 61c8b5d
Merge branch 'gmm-array-api' of github.com:lesteve/scikit-learn into …
StefanieSenger e974051
retrigger CI
StefanieSenger 1a7f262
retrigger CI [azure parallel]
StefanieSenger fb40870
A bit further with setting the device more correctly
lesteve f2eba56
Add our own implementation of logsumexp [azure parallel]
lesteve a0f8d25
Fix implementation of logsumexp
lesteve 53e9917
Fix for older numpy versions
lesteve ac66a02
[azure parallel] Add changelog template
lesteve b3c1c8b
Merge branch 'main' into gmm-array-api
ogrisel dfa92d9
Remove "# noqa" inline comment
ogrisel 5f440a9
add test for _logsumexp
StefanieSenger dd59446
slightly improve tests
StefanieSenger 9e93dfa
improve device checking
StefanieSenger 76cf0fa
tweak
lesteve 489c3e3
Pass xp along the call chain
lesteve 6dccb47
tweak
lesteve 3bbb2fc
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
lesteve 30894cd
add NotImplementedError and test
StefanieSenger ae06fe1
add array api support for init_params='random_from_data'
StefanieSenger 3f2d928
Fix?
lesteve 6be6aa2
Add a sumlogexp test without nans or +inf
lesteve 805742b
tweak
lesteve 90bf491
Add test for logsumexp on default device with array API dispatch disa…
lesteve b07b171
Cleaner way to skip when array API dispatch is disabled
lesteve baf6982
[azure parallel]
lesteve 778763f
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
lesteve c7e909a
Merge branch 'main' into gmm-array-api
lesteve 58ad0fe
Merge branch 'main' into gmm-array-api
lesteve 339c16b
add support for weights_init
StefanieSenger cbc8811
fix signature and add assert to test
StefanieSenger 614f7b5
some small things
StefanieSenger 90baf84
Fix BayesianGaussianMixture
lesteve 1e7a385
Add comment
lesteve e4618cf
Remove all remaining code using np and make most tests pass
lesteve 2b80ac9
Fix easy failures
lesteve 3287a50
Fix [azure parallel]
lesteve fb72f79
array api support for covariance type 'full' + test
StefanieSenger 9641997
fix support for covariance_type='spherical'
StefanieSenger 35a4644
add test for GaussianMixture.sample()
StefanieSenger 502d3e6
fix array api support in sample() with covariance_type='full'
StefanieSenger 148381d
fix array api support in sample() with other covariance_types for arr…
StefanieSenger d565cf9
fix torch dtype issue in xp.full
StefanieSenger c836e8d
use numpy for random reneration in sample
StefanieSenger 668c1b0
remove old comment
StefanieSenger 7fef10a
Only use np.errstate for numpy namespace
lesteve c9a355d
Use int64 to be closer to previous code that was doing dtype=int
lesteve a712181
colons instead of elipsis
StefanieSenger 038632f
revert changes in k-means initialisation
StefanieSenger 18b3fe0
add smote test for other methods
StefanieSenger 8f00364
add lacking check_is_fitted to BaseMixture.score
StefanieSenger cc8fa42
Merge branch 'main' into gmm-array-api
StefanieSenger 3aaabf5
re-trigger CI
StefanieSenger c9b2088
Merge branch 'main' into gmm-array-api
lesteve 0084640
Add torch import
lesteve f9b2946
different branch for numpy.linalg; only re-raise numpy error
StefanieSenger 7a38674
Merge branch 'gmm-array-api' of github.com:lesteve/scikit-learn into …
StefanieSenger adc992e
Remove comment
lesteve 0bb750c
Remove script
lesteve 7874231
update TODOs
lesteve 96d8d8c
only use X array namespace at prediction time
lesteve 27a8cd2
Fix predict
lesteve 4c62715
remove TODO
lesteve 303f392
Fix
lesteve c232e39
Better variable name
lesteve a43eeb2
Simplify with math.log
lesteve 3a72ec9
Use math.pi
lesteve 8f4079f
Improve tests + make score return float
lesteve de1e575
List GaussianMixture in the estimators supporting array API
lesteve 3a7dfd1
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
lesteve 910aa1f
Remove temporary array-api-compat work-around
lesteve 23b543d
Merge branch 'main' into gmm-array-api
StefanieSenger 4fe3766
lint
lesteve ce214a6
Revert changes to test_bayesian_mixture.py
lesteve a69cd62
Remove unnecessary check_is_fitted
lesteve 1a0e33b
Add all array constructor params to test
lesteve 1dca29a
[azure parallel] tweak docstring
8000
lesteve b990682
Update sklearn/utils/_array_api.py
OmarManzoor 72cd185
Remove commented out test
lesteve 3af1470
Handle comments
lesteve ecac610
use _call_cholesky
lesteve 341b659
More explicit use of scipy.linalg
lesteve 7ffc5c7
[azure parallel] Increase rtol for float32 tests + some minor cleanups
lesteve 3b95a5f
rename variables
lesteve 45ba1ee
[azure parallel] test more precisely when array constructor arguments…
lesteve 4f89101
[azure parallel] Remove debug
lesteve d2ca209
Test more attributes
lesteve d46840b
Increase tol to make tests pass
lesteve File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
- :class:`sklearn.gaussian_mixture.GaussianMixture` with | ||
`init_params="random"` or `init_params="random_from_data"` and | ||
`warm_start=False` now supports Array API compatible inputs. | ||
By :user:`Stefanie Senger <StefanieSenger>` and :user:`Loïc Estève <lesteve>` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.