8000 Cosmetic fixes in GMM. · seckcoder/scikit-learn@e76eec5 · GitHub
[go: up one dir, main page]

Skip to content

Commit e76eec5

Browse files
author
Fabian Pedregosa
committed
Cosmetic fixes in GMM.
1 parent 7c53b1e commit e76eec5

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

scikits/learn/mixture.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,10 @@ class GMM(BaseEstimator):
134134
135135
Parameters
136136
----------
137-
n_states : int
138-
Number of mixture components.
139-
cvtype : string (read-only)
137+
n_states : int, optional
138+
Number of mixture components. Defaults to 1.
139+
140+
cvtype : string (read-only), optional
140141
String describing the type of covariance parameters to
141142
use. Must be one of 'spherical', 'tied', 'diag', 'full'.
142143
Defaults to 'diag'.
@@ -386,36 +387,39 @@ def predict_proba(self, X):
386387
logprob, posteriors = self.eval(X)
387388
return posteriors
388389

389-
def rvs(self, n=1):
390+
def rvs(self, n_samples=1):
390391
"""Generate random samples from the model.
391392
392393
Parameters
393394
----------
394-
n : int
395-
Number of samples to generate.
395+
n_samples : int, optional
396+
Number of samples to generate. Defaults to 1.
396397
397398
Returns
398399
-------
399-
obs : array_like, shape (n, n_features)
400+
obs : array_like, shape (n_samples, n_features)
400401
List of samples
401402
"""
402403
weight_pdf = self.weights
403404
weight_cdf = np.cumsum(weight_pdf)
404405

405-
obs = np.empty((n, self.n_features))
406-
rand = np.random.rand(n)
406+
obs = np.empty((n_samples, self.n_features))
407+
rand = np.random.rand(n_samples)
407408
# decide which component to use for each sample
408409
comps = weight_cdf.searchsorted(rand)
409410
# for each component, generate all needed samples
410411
for comp in xrange(self._n_states):
411-
comp_in_obs = (comp==comps) # occurrences of current component in obs
412-
num_comp_in_obs = comp_in_obs.sum() # number of those occurrences
412+
# occurrences of current component in obs
413+
comp_in_obs = (comp==comps)
414< 9013 /td>+
# number of those occurrences
415+
num_comp_in_obs = comp_in_obs.sum()
413416
if num_comp_in_obs > 0:
414417
if self._cvtype == 'tied':
415418
cv = self._covars
416419
else:
417420
cv = self._covars[comp]
418-
obs[comp_in_obs] = sample_gaussian(self._means[comp], cv, self._cvtype, num_comp_in_obs).T
421+
obs[comp_in_obs] = sample_gaussian(
422+
self._means[comp], cv, self._cvtype, num_comp_in_obs).T
419423
return obs
420424

421425
def fit(self, X, n_iter=10, min_covar=1e-3, thresh=1e-2, params='wmc',

0 commit comments

Comments
 (0)
0