From eda9911943b9175011ad36b38a6b857353366a2f Mon Sep 17 00:00:00 2001 From: Vladimir Feinberg Date: Tue, 26 Apr 2016 13:51:37 -0400 Subject: [PATCH] Modified GMM initialization to only use linear memory and time in spherical and diag cases --- sklearn/mixture/gmm.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/sklearn/mixture/gmm.py b/sklearn/mixture/gmm.py index 85aa384a22e98..6735ba75abdb9 100644 --- a/sklearn/mixture/gmm.py +++ b/sklearn/mixture/gmm.py @@ -491,7 +491,12 @@ def _fit(self, X, y=None, do_prediction=False): print('\tWeights have been initialized.') if 'c' in self.init_params or not hasattr(self, 'covars_'): - cv = np.cov(X.T) + self.min_covar * np.eye(X.shape[1]) + # Only compute the full covariance if full or tied + cv = None + if self.covariance_type in ['full', 'tied']: + cv = np.cov(X.T) + self.min_covar * np.eye(X.shape[1]) + else: + cv = np.var(X, axis=0) + self.min_covar if not cv.shape: cv.shape = (1, 1) self.covars_ = \ @@ -759,15 +764,24 @@ def _validate_covars(covars, covariance_type, n_components): def distribute_covar_matrix_to_match_covariance_type( tied_cv, covariance_type, n_components): - """Create all the covariance matrices from a given template.""" + """Create all the covariance matrices from a given template. + + If covariance_type is 'spherical' or 'diag', then tied_cv should be + a vector, representing the diagonal of the initial tied covariance matrix. + + Otherwise, tied_cv should be the entire covariance matrix.""" if covariance_type == 'spherical': - cv = np.tile(tied_cv.mean() * np.ones(tied_cv.shape[1]), + assert len(tied_cv.shape) == 1 + cv = np.tile(tied_cv.mean() * np.ones(len(tied_cv)), (n_components, 1)) elif covariance_type == 'tied': + assert len(tied_cv.shape) == 2 cv = tied_cv elif covariance_type == 'diag': - cv = np.tile(np.diag(tied_cv), (n_components, 1)) + assert len(tied_cv.shape) == 1 + cv = np.tile(tied_cv, (n_components, 1)) elif covariance_type == 'full': + assert len(tied_cv.shape) == 2 cv = np.tile(tied_cv, (n_components, 1, 1)) else: raise ValueError("covariance_type must be one of " +