 import numpy as np
 from scipy import stats, linalg

+from sklearn.cluster import KMeans
 from sklearn.covariance import EmpiricalCovariance
 from sklearn.datasets import make_spd_matrix
 from io import StringIO

     _estimate_gaussian_covariances_tied,
     _estimate_gaussian_covariances_diag,
     _estimate_gaussian_covariances_spherical,
+    _estimate_gaussian_parameters,
     _compute_precision_cholesky,
     _compute_log_det_cholesky,
 )
@@ -1241,6 +1243,7 @@ def test_gaussian_mixture_setting_best_params():
         random_state=rnd,
         n_components=len(weights_init),
         precisions_init=precisions_init,
+        max_iter=1,
     )
     # ensure that no error is thrown during fit
     gmm.fit(X)
@@ -1258,3 +1261,64 @@ def test_gaussian_mixture_setting_best_params():
         "lower_bound_",
     ]:
         assert hasattr(gmm, attr)
+
+
+def test_gaussian_mixture_precisions_init_diag():
+    """Check that we properly initialize `precisions_cholesky_` when we manually
+    provide the precision matrix.
+
+    We check the consistency between estimating the precision matrix and
+    providing the same precision matrix as initialization. Both should lead to
+    the same results with the same number of iterations.
+
+    If the initialization is wrong, the number of iterations will increase.
+
+    Non-regression test for:
+    https://github.com/scikit-learn/scikit-learn/issues/16944
+    """
+    # generate a toy dataset
+    n_samples = 300
+    rng = np.random.RandomState(0)
+    shifted_gaussian = rng.randn(n_samples, 2) + np.array([20, 20])
+    C = np.array([[0.0, -0.7], [3.5, 0.7]])
+    stretched_gaussian = np.dot(rng.randn(n_samples, 2), C)
+    X = np.vstack([shifted_gaussian, stretched_gaussian])
+
+    # common parameters to check the consistency of precision initialization
+    n_components, covariance_type, reg_covar, random_state = 2, "diag", 1e-6, 0
+
+    # execute the manual initialization to compute the precision matrix:
+    # - run KMeans to have an initial guess
+    # - estimate the covariance
+    # - compute the precision matrix from the estimated covariance
+    resp = np.zeros((X.shape[0], n_components))
+    label = (
+        KMeans(n_clusters=n_components, n_init=1, random_state=random_state)
+        .fit(X)
+        .labels_
+    )
+    resp[np.arange(X.shape[0]), label] = 1
+    _, _, covariance = _estimate_gaussian_parameters(
+        X, resp, reg_covar=reg_covar, covariance_type=covariance_type
+    )
+    precisions_init = 1 / covariance
+
+    gm_with_init = GaussianMixture(
+        n_components=n_components,
+        covariance_type=covariance_type,
+        reg_covar=reg_covar,
+        precisions_init=precisions_init,
+        random_state=random_state,
+    ).fit(X)
+
+    gm_without_init = GaussianMixture(
+        n_components=n_components,
+        covariance_type=covariance_type,
+        reg_covar=reg_covar,
+        random_state=random_state,
+    ).fit(X)
+
+    assert gm_without_init.n_iter_ == gm_with_init.n_iter_
+    assert_allclose(
+        gm_with_init.precisions_cholesky_, gm_without_init.precisions_cholesky_
+    )
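
For context, the public entry point exercised by this test is the `precisions_init` parameter of `GaussianMixture`. Below is a minimal usage sketch, not part of the PR: the dataset and the precision values are made up for illustration. For `covariance_type="diag"`, `precisions_init` is an array of shape (n_components, n_features) holding the per-feature precisions (inverse variances) of each component, and the fitted model exposes `n_iter_` and `precisions_cholesky_` as checked in the test above.

import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.RandomState(0)
# two well-separated 2-D blobs (illustrative data only)
X = np.vstack([rng.randn(100, 2), rng.randn(100, 2) + [5, 5]])

# hand-picked inverse variances: 2 components x 2 features
precisions_init = np.array([[1.0, 1.0], [1.0, 1.0]])

gm = GaussianMixture(
    n_components=2,
    covariance_type="diag",
    precisions_init=precisions_init,
    random_state=0,
).fit(X)
print(gm.n_iter_)
print(gm.precisions_cholesky_)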