FEA Add array API support for GaussianMixture by lesteve · Pull Request #30777 · scikit-learn/scikit-learn
FEA Add array API support for GaussianMixture #30777


Merged
merged 103 commits into scikit-learn:main from lesteve:gmm-array-api
Jun 19, 2025

Commits (103)
b04a9f7
wip
lesteve Jan 22, 2025
e6ba4e4
wip
lesteve Jan 22, 2025
2226a55
stuck on linalg.cholesky array API support
lesteve Feb 5, 2025
b1fdee7
a bit further with xp.cholesky but now linalg.solve_triangular
lesteve Feb 5, 2025
14fb0ba
more array api
StefanieSenger Feb 14, 2025
6010ff7
wip (problem with weights as numpy arrays)
lesteve Feb 19, 2025
aa2a383
array api for covariance_type='diag' and init_params='random'
StefanieSenger Feb 21, 2025
de4f3a5
add simple test
StefanieSenger Feb 21, 2025
7974931
Add comments about tricky bits
lesteve Feb 21, 2025
08e5f9b
lint
lesteve Feb 21, 2025
0f525ef
one more comment
lesteve Feb 21, 2025
4801e2b
revert unwanted change
lesteve Feb 28, 2025
de1343c
fix test_bayesian_mixture
lesteve Feb 28, 2025
b05eca0
Compare to numpy result in test
lesteve Feb 28, 2025
c35bdd6
Use global_random_seed
lesteve Feb 28, 2025
4516920
retrigger CI
StefanieSenger Mar 12, 2025
61c8b5d
Merge branch 'gmm-array-api' of github.com:lesteve/scikit-learn into …
StefanieSenger Mar 12, 2025
e974051
retrigger CI
StefanieSenger Mar 12, 2025
1a7f262
retrigger CI [azure parallel]
StefanieSenger Mar 12, 2025
fb40870
A bit further with setting the device more correctly
lesteve Mar 13, 2025
f2eba56
Add our own implementation of logsumexp [azure parallel]
lesteve Mar 14, 2025
a0f8d25
Fix implementation of logsumexp
lesteve Mar 14, 2025
53e9917
Fix for older numpy versions
lesteve Mar 14, 2025
ac66a02
[azure parallel] Add changelog template
lesteve Mar 15, 2025
b3c1c8b
Merge branch 'main' into gmm-array-api
ogrisel Mar 18, 2025
dfa92d9
Remove "# noqa" inline comment
ogrisel Mar 18, 2025
5f440a9
add test for _logsumexp
StefanieSenger Mar 19, 2025
dd59446
slightly improve tests
StefanieSenger Mar 19, 2025
9e93dfa
improve device checking
StefanieSenger Mar 19, 2025
76cf0fa
tweak
lesteve Mar 21, 2025
489c3e3
Pass xp along the call chain
lesteve Mar 21, 2025
6dccb47
tweak
lesteve Mar 21, 2025
3bbb2fc
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
lesteve Mar 25, 2025
30894cd
add NotImplementedError and test
StefanieSenger Mar 26, 2025
ae06fe1
add array api support for init_params='random_from_data'
StefanieSenger Mar 26, 2025
3f2d928
Fix?
lesteve Mar 26, 2025
6be6aa2
Add a sumlogexp test without nans or +inf
lesteve Mar 27, 2025
805742b
tweak
lesteve Mar 27, 2025
90bf491
Add test for logsumexp on default device with array API dispatch disa…
lesteve Mar 27, 2025
b07b171
Cleaner way to skip when array API dispatch is disabled
lesteve Mar 27, 2025
baf6982
[azure parallel]
lesteve Mar 27, 2025
778763f
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
lesteve Mar 28, 2025
c7e909a
Merge branch 'main' into gmm-array-api
lesteve Apr 1, 2025
58ad0fe
Merge branch 'main' into gmm-array-api
lesteve Apr 1, 2025
339c16b
add support for weights_init
StefanieSenger Apr 2, 2025
cbc8811
fix signature and add assert to test
StefanieSenger Apr 2, 2025
614f7b5
some small things
StefanieSenger Apr 3, 2025
90baf84
Fix BayesianGaussianMixture
lesteve Apr 3, 2025
1e7a385
Add comment
lesteve Apr 3, 2025
e4618cf
Remove all remaining code using np and make most tests pass
lesteve Apr 3, 2025
2b80ac9
Fix easy failures
lesteve Apr 3, 2025
3287a50
Fix [azure parallel]
lesteve Apr 3, 2025
fb72f79
array api support for covariance type 'full' + test
StefanieSenger Apr 4, 2025
9641997
fix support for covariance_type='spherical'
StefanieSenger Apr 7, 2025
35a4644
add test for GaussianMixture.sample()
StefanieSenger Apr 7, 2025
502d3e6
fix array api support in sample() with covariance_type='full'
StefanieSenger Apr 7, 2025
148381d
fix array api support in sample() with other covariance_types for arr…
StefanieSenger Apr 7, 2025
d565cf9
fix torch dtype issue in xp.full
StefanieSenger Apr 7, 2025
c836e8d
use numpy for random generation in sample
StefanieSenger Apr 9, 2025
668c1b0
remove old comment
StefanieSenger Apr 9, 2025
7fef10a
Only use np.errstate for numpy namespace
lesteve Apr 9, 2025
c9a355d
Use int64 to be closer to previous code that was doing dtype=int
lesteve Apr 9, 2025
a712181
colons instead of ellipsis
StefanieSenger May 7, 2025
038632f
revert changes in k-means initialisation
StefanieSenger May 7, 2025
18b3fe0
add smoke test for other methods
StefanieSenger May 7, 2025
8f00364
add lacking check_is_fitted to BaseMixture.score
StefanieSenger May 7, 2025
cc8fa42
Merge branch 'main' into gmm-array-api
StefanieSenger May 7, 2025
3aaabf5
re-trigger CI
StefanieSenger May 9, 2025
c9b2088
Merge branch 'main' into gmm-array-api
lesteve May 9, 2025
0084640
Add torch import
lesteve May 9, 2025
f9b2946
different branch for numpy.linalg; only re-raise numpy error
StefanieSenger May 14, 2025
7a38674
Merge branch 'gmm-array-api' of github.com:lesteve/scikit-learn into …
StefanieSenger May 14, 2025
adc992e
Remove comment
lesteve May 14, 2025
0bb750c
Remove script
lesteve May 15, 2025
7874231
update TODOs
lesteve May 15, 2025
96d8d8c
only use X array namespace at prediction time
lesteve May 15, 2025
27a8cd2
Fix predict
lesteve May 15, 2025
4c62715
remove TODO
lesteve May 15, 2025
303f392
Fix
lesteve May 15, 2025
c232e39
Better variable name
lesteve May 16, 2025
a43eeb2
Simplify with math.log
lesteve May 16, 2025
3a72ec9
Use math.pi
lesteve May 16, 2025
8f4079f
Improve tests + make score return float
lesteve May 16, 2025
de1e575
List GaussianMixture in the estimators supporting array API
lesteve May 16, 2025
3a7dfd1
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
lesteve May 20, 2025
910aa1f
Remove temporary array-api-compat work-around
lesteve May 20, 2025
23b543d
Merge branch 'main' into gmm-array-api
StefanieSenger Jun 4, 2025
4fe3766
lint
lesteve Jun 6, 2025
ce214a6
Revert changes to test_bayesian_mixture.py
lesteve Jun 13, 2025
a69cd62
Remove unnecessary check_is_fitted
lesteve Jun 13, 2025
1a0e33b
Add all array constructor params to test
lesteve Jun 13, 2025
1dca29a
[azure parallel] tweak docstring
lesteve Jun 13, 2025
b990682
Update sklearn/utils/_array_api.py
OmarManzoor Jun 14, 2025
72cd185
Remove commented out test
lesteve Jun 16, 2025
3af1470
Handle comments
lesteve Jun 16, 2025
ecac610
use _call_cholesky
lesteve Jun 16, 2025
341b659
More explicit use of scipy.linalg
lesteve Jun 18, 2025
7ffc5c7
[azure parallel] Increase rtol for float32 tests + some minor cleanups
lesteve Jun 18, 2025
3b95a5f
rename variables
lesteve Jun 18, 2025
45ba1ee
[azure parallel] test more precisely when array constructor arguments…
lesteve Jun 18, 2025
4f89101
[azure parallel] Remove debug
lesteve Jun 18, 2025
d2ca209
Test more attributes
lesteve Jun 19, 2025
d46840b
Increase tol to make tests pass
lesteve Jun 19, 2025
2 changes: 2 additions & 0 deletions doc/modules/array_api.rst
@@ -117,6 +117,8 @@ Estimators
- :class:`preprocessing.MaxAbsScaler`
- :class:`preprocessing.MinMaxScaler`
- :class:`preprocessing.Normalizer`
- :class:`mixture.GaussianMixture` (with `init_params="random"` or
`init_params="random_from_data"` and `warm_start=False`)

Meta-estimators
---------------
4 changes: 4 additions & 0 deletions doc/whats_new/upcoming_changes/array-api/30777.feature.rst
@@ -0,0 +1,4 @@
- :class:`sklearn.mixture.GaussianMixture` with
`init_params="random"` or `init_params="random_from_data"` and
`warm_start=False` now supports Array API compatible inputs.
By :user:`Stefanie Senger <StefanieSenger>` and :user:`Loïc Estève <lesteve>`
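
A quick sketch of what this feature enables. This is not part of the diff; it assumes PyTorch is installed along with scikit-learn's array API requirements, and it relies on the documented `array_api_dispatch` config flag plus the `init_params`/`warm_start` constraints listed in the changelog above:

```python
import torch

import sklearn
from sklearn.mixture import GaussianMixture

sklearn.set_config(array_api_dispatch=True)

X = torch.randn(500, 3)  # any array API compatible input

gmm = GaussianMixture(n_components=2, init_params="random", random_state=0)
gmm.fit(X)

# Fitted attributes and predictions stay in the input's namespace (torch here).
print(type(gmm.means_))
print(gmm.predict(X[:5]))
```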
131 changes: 86 additions & 45 deletions sklearn/mixture/_base.py
@@ -5,17 +5,24 @@

import warnings
from abc import ABCMeta, abstractmethod
from contextlib import nullcontext
from numbers import Integral, Real
from time import time

import numpy as np
from scipy.special import logsumexp

from .. import cluster
from ..base import BaseEstimator, DensityMixin, _fit_context
from ..cluster import kmeans_plusplus
from ..exceptions import ConvergenceWarning
from ..utils import check_random_state
from ..utils._array_api import (
_convert_to_numpy,
_is_numpy_namespace,
_logsumexp,
get_namespace,
get_namespace_and_device,
)
from ..utils._param_validation import Interval, StrOptions
from ..utils.validation import check_is_fitted, validate_data

@@ -31,7 +38,6 @@ def _check_shape(param, param_shape, name):

name : str
"""
param = np.array(param)
if param.shape != param_shape:
raise ValueError(
"The parameter '%s' should have the shape of %s, but got %s"
@@ -86,7 +92,7 @@ def __init__(
self.verbose_interval = verbose_interval

@abstractmethod
def _check_parameters(self, X):
def _check_parameters(self, X, xp=None):
"""Check initial parameters of the derived class.

Parameters
@@ -95,7 +101,7 @@ def _check_parameters(self, X):
"""
pass

def _initialize_parameters(self, X, random_state):
def _initialize_parameters(self, X, random_state, xp=None):
"""Initialize the model parameters.

Parameters
@@ -106,6 +112,7 @@ def _initialize_parameters(self, X, random_state):
A random number generator instance that controls the random seed
used for the method chosen to initialize the parameters.
"""
xp, _, device = get_namespace_and_device(X, xp=xp)
n_samples, _ = X.shape

if self.init_params == "kmeans":
@@ -119,16 +126,25 @@
)
resp[np.arange(n_samples), label] = 1
elif self.init_params == "random":
resp = np.asarray(
random_state.uniform(size=(n_samples, self.n_components)), dtype=X.dtype
resp = xp.asarray(
random_state.uniform(size=(n_samples, self.n_components)),
dtype=X.dtype,
device=device,
)
resp /= resp.sum(axis=1)[:, np.newaxis]
resp /= xp.sum(resp, axis=1)[:, xp.newaxis]
elif self.init_params == "random_from_data":
resp = np.zeros((n_samples, self.n_components), dtype=X.dtype)
resp = xp.zeros(
(n_samples, self.n_components), dtype=X.dtype, device=device
)
indices = random_state.choice(
n_samples, size=self.n_components, replace=False
)
resp[indices, np.arange(self.n_components)] = 1
# TODO: when array API supports __setitem__ with fancy indexing we
# can use the previous code:
# resp[indices, xp.arange(self.n_components)] = 1
# Until then we use a for loop on one dimension.
for col, index in enumerate(indices):
resp[index, col] = 1
elif self.init_params == "k-means++":
resp = np.zeros((n_samples, self.n_components), dtype=X.dtype)
_, indices = kmeans_plusplus(
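
The for loop introduced in `_initialize_parameters` above works around a gap in the array API standard: `__setitem__` with a pair of integer arrays (fancy indexing) is not guaranteed, as the TODO comment notes. A standalone sketch of the equivalence, using NumPy purely for illustration:

```python
import numpy as np

n_samples, n_components = 6, 3
indices = np.asarray([4, 0, 2])  # one chosen sample per component

resp = np.zeros((n_samples, n_components))
# NumPy-only shortcut that the array API standard does not guarantee:
#   resp[indices, np.arange(n_components)] = 1
# Portable version: one-hot assignment, one column at a time.
for col, index in enumerate(indices):
    resp[index, col] = 1
```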
@@ -210,20 +226,21 @@ def fit_predict(self, X, y=None):
labels : array, shape (n_samples,)
Component labels.
"""
X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_min_samples=2)
xp, _ = get_namespace(X)
X = validate_data(self, X, dtype=[xp.float64, xp.float32], ensure_min_samples=2)
if X.shape[0] < self.n_components:
raise ValueError(
"Expected n_samples >= n_components "
f"but got n_components = {self.n_components}, "
f"n_samples = {X.shape[0]}"
)
self._check_parameters(X)
self._check_parameters(X, xp=xp)

# if we enable warm_start, we will have a unique initialisation
do_init = not (self.warm_start and hasattr(self, "converged_"))
n_init = self.n_init if do_init else 1

max_lower_bound = -np.inf
max_lower_bound = -xp.inf
best_lower_bounds = []
self.converged_ = False

@@ -234,9 +251,9 @@
self._print_verbose_msg_init_beg(init)

if do_init:
self._initialize_parameters(X, random_state)
self._initialize_parameters(X, random_state, xp=xp)

lower_bound = -np.inf if do_init else self.lower_bound_
lower_bound = -xp.inf if do_init else self.lower_bound_
current_lower_bounds = []

if self.max_iter == 0:
@@ -247,8 +264,8 @@
for n_iter in range(1, self.max_iter + 1):
prev_lower_bound = lower_bound

log_prob_norm, log_resp = self._e_step(X)
self._m_step(X, log_resp)
log_prob_norm, log_resp = self._e_step(X, xp=xp)
self._m_step(X, log_resp, xp=xp)
lower_bound = self._compute_lower_bound(log_resp, log_prob_norm)
current_lower_bounds.append(lower_bound)

@@ -261,7 +278,7 @@

self._print_verbose_msg_init_end(lower_bound, converged)

if lower_bound > max_lower_bound or max_lower_bound == -np.inf:
if lower_bound > max_lower_bound or max_lower_bound == -xp.inf:
max_lower_bound = lower_bound
best_params = self._get_parameters()
best_n_iter = n_iter
@@ -281,19 +298,19 @@
ConvergenceWarning,
)

self._set_parameters(best_params)
self._set_parameters(best_params, xp=xp)
self.n_iter_ = best_n_iter
self.lower_bound_ = max_lower_bound
self.lower_bounds_ = best_lower_bounds

# Always do a final e-step to guarantee that the labels returned by
# fit_predict(X) are always consistent with fit(X).predict(X)
# for any value of max_iter and tol (and any random_state).
_, log_resp = self._e_step(X)
_, log_resp = self._e_step(X, xp=xp)

return log_resp.argmax(axis=1)
return xp.argmax(log_resp, axis=1)

def _e_step(self, X):
def _e_step(self, X, xp=None):
"""E step.

Parameters
@@ -309,8 +326,9 @@
Logarithm of the posterior probabilities (or responsibilities) of
the point of each sample in X.
"""
log_prob_norm, log_resp = self._estimate_log_prob_resp(X)
return np.mean(log_prob_norm), log_resp
xp, _ = get_namespace(X, xp=xp)
log_prob_norm, log_resp = self._estimate_log_prob_resp(X, xp=xp)
return xp.mean(log_prob_norm), log_resp

@abstractmethod
def _m_step(self, X, log_resp):
@@ -351,7 +369,7 @@ def score_samples(self, X):
check_is_fitted(self)
X = validate_data(self, X, reset=False)

return logsumexp(self._estimate_weighted_log_prob(X), axis=1)
return _logsumexp(self._estimate_weighted_log_prob(X), axis=1)
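
`_logsumexp` replaces `scipy.special.logsumexp`, which only accepts NumPy inputs. A minimal sketch of the max-shift trick behind it, with a hypothetical name (the real helper in `sklearn.utils._array_api` additionally handles NaN/infinity edge cases and older NumPy versions):

```python
def logsumexp_sketch(a, axis, xp):
    # Shift by the max along `axis` so that exp() cannot overflow.
    amax = xp.max(a, axis=axis, keepdims=True)
    out = xp.log(xp.sum(xp.exp(a - amax), axis=axis))
    return out + xp.squeeze(amax, axis=axis)
```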

def score(self, X, y=None):
"""Compute the per-sample average log-likelihood of the given data X.
@@ -370,7 +388,8 @@
log_likelihood : float
Log-likelihood of `X` under the Gaussian mixture model.
"""
return self.score_samples(X).mean()
xp, _ = get_namespace(X)
return float(xp.mean(self.score_samples(X)))

def predict(self, X):
"""Predict the labels for the data samples in X using trained model.
@@ -387,8 +406,9 @@
Component labels.
"""
check_is_fitted(self)
xp, _ = get_namespace(X)
X = validate_data(self, X, reset=False)
return self._estimate_weighted_log_prob(X).argmax(axis=1)
return xp.argmax(self._estimate_weighted_log_prob(X), axis=1)
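
Note the systematic change from array methods to namespace functions here and elsewhere in the diff (`arr.argmax(axis=1)` becomes `xp.argmax(arr, axis=1)`): the array API standard guarantees reductions as namespace functions only, not as methods on array objects. For example, assuming the `array-api-strict` reference package is installed:

```python
import array_api_strict as xp

arr = xp.asarray([[0.1, 0.9], [0.8, 0.2]])
labels = xp.argmax(arr, axis=1)  # portable
row_sums = xp.sum(arr, axis=1)   # portable replacement for arr.sum(axis=1)
```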

def predict_proba(self, X):
"""Evaluate the components' density for each sample.
@@ -406,8 +426,9 @@
"""
check_is_fitted(self)
X = validate_data(self, X, reset=False)
_, log_resp = self._estimate_log_prob_resp(X)
return np.exp(log_resp)
xp, _ = get_namespace(X)
_, log_resp = self._estimate_log_prob_resp(X, xp=xp)
return xp.exp(log_resp)

def sample(self, n_samples=1):
"""Generate random samples from the fitted Gaussian distribution.
@@ -426,6 +447,7 @@
Component labels.
"""
check_is_fitted(self)
xp, _, device_ = get_namespace_and_device(self.means_)

if n_samples < 1:
raise ValueError(
@@ -435,22 +457,30 @@

_, n_features = self.means_.shape
rng = check_random_state(self.random_state)
n_samples_comp = rng.multinomial(n_samples, self.weights_)
n_samples_comp = rng.multinomial(
n_samples, _convert_to_numpy(self.weights_, xp)
)

if self.covariance_type == "full":
X = np.vstack(
[
rng.multivariate_normal(mean, covariance, int(sample))
for (mean, covariance, sample) in zip(
self.means_, self.covariances_, n_samples_comp
_convert_to_numpy(self.means_, xp),
_convert_to_numpy(self.covariances_, xp),
n_samples_comp,
)
]
)
elif self.covariance_type == "tied":
X = np.vstack(
[
rng.multivariate_normal(mean, self.covariances_, int(sample))
for (mean, sample) in zip(self.means_, n_samples_comp)
rng.multivariate_normal(
mean, _convert_to_numpy(self.covariances_, xp), int(sample)
)
for (mean, sample) in zip(
_convert_to_numpy(self.means_, xp), n_samples_comp
)
]
)
else:
Expand All @@ -460,18 +490,23 @@ def sample(self, n_samples=1):
+ rng.standard_normal(size=(sample, n_features))
* np.sqrt(covariance)
for (mean, covariance, sample) in zip(
self.means_, self.covariances_, n_samples_comp
_convert_to_numpy(self.means_, xp),
_convert_to_numpy(self.covariances_, xp),
n_samples_comp,
)
]
)

y = np.concatenate(
[np.full(sample, j, dtype=int) for j, sample in enumerate(n_samples_comp)]
y = xp.concat(
[
xp.full(int(n_samples_comp[i]), i, dtype=xp.int64, device=device_)
for i in range(len(n_samples_comp))
]
)

return (X, y)
return xp.asarray(X, device=device_), y
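
`sample` deliberately round-trips through NumPy: `numpy.random.RandomState` has no array API counterpart, so the fitted parameters are converted with `_convert_to_numpy`, the sampling happens in NumPy, and the result is moved back to the original namespace and device. A condensed sketch of that pattern, with a hypothetical helper name but the same utilities the diff uses:

```python
import numpy as np

from sklearn.utils._array_api import _convert_to_numpy, get_namespace_and_device


def gaussian_noise_like(means, rng, n_samples):
    xp, _, device = get_namespace_and_device(means)
    means_np = _convert_to_numpy(means, xp)  # the sampling itself stays in NumPy
    X = means_np[0] + rng.standard_normal(size=(n_samples, means_np.shape[1]))
    return xp.asarray(X, device=device)  # back to the input namespace and device
```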

def _estimate_weighted_log_prob(self, X):
def _estimate_weighted_log_prob(self, X, xp=None):
"""Estimate the weighted log-probabilities, log P(X | Z) + log weights.

Parameters
@@ -482,10 +517,10 @@ def _estimate_weighted_log_prob(self, X):
-------
weighted_log_prob : array, shape (n_samples, n_component)
"""
return self._estimate_log_prob(X) + self._estimate_log_weights()
return self._estimate_log_prob(X, xp=xp) + self._estimate_log_weights(xp=xp)

@abstractmethod
def _estimate_log_weights(self):
def _estimate_log_weights(self, xp=None):
"""Estimate log-weights in EM algorithm, E[ log pi ] in VB algorithm.

Returns
@@ -495,7 +530,7 @@ def _estimate_log_weights(self):
pass

@abstractmethod
def _estimate_log_prob(self, X):
def _estimate_log_prob(self, X, xp=None):
"""Estimate the log-probabilities log P(X | Z).

Compute the log-probabilities per each component for each sample.
@@ -510,7 +545,7 @@
"""
pass

def _estimate_log_prob_resp(self, X):
def _estimate_log_prob_resp(self, X, xp=None):
"""Estimate log probabilities and responsibilities for each sample.

Compute the log probabilities, weighted log probabilities per
@@ -529,11 +564,17 @@
log_responsibilities : array, shape (n_samples, n_components)
logarithm of the responsibilities
"""
weighted_log_prob = self._estimate_weighted_log_prob(X)
log_prob_norm = logsumexp(weighted_log_prob, axis=1)
with np.errstate(under="ignore"):
xp, _ = get_namespace(X, xp=xp)
weighted_log_prob = self._estimate_weighted_log_prob(X, xp=xp)
log_prob_norm = _logsumexp(weighted_log_prob, axis=1, xp=xp)

# There is no errstate equivalent for warning/error management in array API
context_manager = (
np.errstate(under="ignore") if _is_numpy_namespace(xp) else nullcontext()
)
with context_manager:
# ignore underflow
log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]
log_resp = weighted_log_prob - log_prob_norm[:, xp.newaxis]
return log_prob_norm, log_resp
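
The `nullcontext` switch above is needed because `np.errstate` is NumPy specific and, as the inline comment says, the array API standard offers no equivalent for floating point warning management. A self-contained sketch of the same pattern, with a hypothetical function name:

```python
from contextlib import nullcontext

import numpy as np

from sklearn.utils._array_api import _is_numpy_namespace, get_namespace


def subtract_ignoring_underflow(a, b):
    xp, _ = get_namespace(a, b)
    # Silence underflow where the concept exists (NumPy); no-op elsewhere.
    ctx = np.errstate(under="ignore") if _is_numpy_namespace(xp) else nullcontext()
    with ctx:
        return a - b
```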

def _print_verbose_msg_init_beg(self, n_init):