FIX correctly initialize precisions_cholesky_ in GaussianMixture (#22… · scikit-learn/scikit-learn@3e460e8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3e460e8

Browse files
authored
FIX correctly initialize precisions_cholesky_ in GaussianMixture (#22058)
1 parent 0882bd3 commit 3e460e8

File tree

3 files changed

+73
-1
lines changed

3 files changed

+73
-1
lines changed

doc/whats_new/v1.1.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,14 @@ Changelog
316316
now validate input parameters in `fit` instead of `__init__`.
317317
:pr:`21880` by :user:`Mrinal Tyagi <MrinalTyagi>`.
318318

319+
:mod:`sklearn.mixture`
320+
......................
321+
322+
- |Fix| Fix a bug to correctly initialize `precisions_cholesky_` in
323+
:class:`mixture.GaussianMixture` when providing `precisions_init` by taking
324+
its square root.
325+
:pr:`22058` by :user:`Guillaume Lemaitre <glemaitre>`.
326+
319327
:mod:`sklearn.pipeline`
320328
.......................
321329

sklearn/mixture/_gaussian_mixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ def _initialize(self, X, resp):
728728
self.precisions_init, lower=True
729729
)
730730
else:
731-
self.precisions_cholesky_ = self.precisions_init
731+
self.precisions_cholesky_ = np.sqrt(self.precisions_init)
732732

733733
def _m_step(self, X, log_resp):
734734
"""M step.

sklearn/mixture/tests/test_gaussian_mixture.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import numpy as np
1212
from scipy import stats, linalg
1313

14+
from sklearn.cluster import KMeans
1415
from sklearn.covariance import EmpiricalCovariance
1516
from sklearn.datasets import make_spd_matrix
1617
from io import StringIO
@@ -21,6 +22,7 @@
2122
_estimate_gaussian_covariances_tied,
2223
_estimate_gaussian_covariances_diag,
2324
_estimate_gaussian_covariances_spherical,
25+
_estimate_gaussian_parameters,
2426
_compute_precision_cholesky,
2527
_compute_log_det_cholesky,
2628
)
@@ -1241,6 +1243,7 @@ def test_gaussian_mixture_setting_best_params():
12411243
random_state=rnd,
12421244
n_components=len(weights_init),
12431245
precisions_init=precisions_init,
1246+
max_iter=1,
12441247
)
12451248
# ensure that no error is thrown during fit
12461249
gmm.fit(X)
@@ -1258,3 +1261,64 @@ def test_gaussian_mixture_setting_best_params():
12581261
"lower_bound_",
12591262
]:
12601263
assert hasattr(gmm, attr)
1264+
1265+
1266+
def test_gaussian_mixture_precisions_init_diag():
1267+
"""Check that we properly initialize `precisions_cholesky_` when we manually
1268+
provide the precision matrix.
1269+
1270+
In this regard, we check the consistency between estimating the precision
1271+
matrix and providing the same precision matrix as initialization. It should
1272+
lead to the same results with the same number of iterations.
1273+
1274+
If the initialization is wrong then the number of iterations will increase.
1275+
1276+
Non-regression test for:
1277+
https://github.com/scikit-learn/scikit-learn/issues/16944
1278+
"""
1279+
# generate a toy dataset
1280+
n_samples = 300
1281+
rng = np.random.RandomState(0)
1282+
shifted_gaussian = rng.randn(n_samples, 2) + np.array([20, 20])
1283+
C = np.array([[0.0, -0.7], [3.5, 0.7]])
1284+
stretched_gaussian = np.dot(rng.randn(n_samples, 2), C)
1285+
X = np.vstack([shifted_gaussian, stretched_gaussian])
1286+
1287+
# common parameters to check the consistency of precision initialization
1288+
n_components, covariance_type, reg_covar, random_state = 2, "diag", 1e-6, 0
1289+
1290+
# execute the manual initialization to compute the precision matrix:
1291+
# - run KMeans to have an initial guess
1292+
# - estimate the covariance
1293+
# - compute the precision matrix from the estimated covariance
1294+
resp = np.zeros((X.shape[0], n_components))
1295+
label = (
1296+
KMeans(n_clusters=n_components, n_init=1, random_state=random_state)
1297+
.fit(X)
1298+
.labels_
1299+
)
1300+
resp[np.arange(X.shape[0]), label] = 1
1301+
_, _, covariance = _estimate_gaussian_parameters(
1302+
X, resp, reg_covar=reg_covar, covariance_type=covariance_type
1303+
)
1304+
precisions_init = 1 / covariance
1305+
1306+
gm_with_init = GaussianMixture(
1307+
n_components=n_components,
1308+
covariance_type=covariance_type,
1309+
reg_covar=reg_covar,
1310+
precisions_init=precisions_init,
1311+
random_state=random_state,
1312+
).fit(X)
1313+
1314+
gm_without_init = GaussianMixture(
1315+
n_components=n_components,
1316+
covariance_type=covariance_type,
1317+
reg_covar=reg_covar,
1318+
random_state=random_state,
1319+
).fit(X)
1320+
1321+
assert gm_without_init.n_iter_ == gm_with_init.n_iter_
1322+
assert_allclose(
1323+
gm_with_init.precisions_cholesky_, gm_without_init.precisions_cholesky_
1324+
)

0 commit comments

Comments
 (0)
0