Improve performance of GMM sampling · seckcoder/scikit-learn@bbb6033 · GitHub
[go: up one dir, main page]

Skip to content

Commit bbb6033

Browse files
author
Fabian Pedregosa
committed
Improve performance of GMM sampling
Patch by f0k.
1 parent a089f38 commit bbb6033

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

scikits/learn/mixture.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def sample_gaussian(mean, covar, cvtype='diag', n=1):
9999
100100
Returns
101101
-------
102-
obs : array, shape (n, n_features)
102+
obs : array, shape (n_features, n)
103103
Randomly generated sample
104104
"""
105105
ndim = len(mean)
@@ -403,14 +403,19 @@ def rvs(self, n=1):
403403
weight_cdf = np.cumsum(weight_pdf)
404404

405405
obs = np.empty((n, self.n_features))
406-
for x in xrange(n):
407-
rand = np.random.rand()
408-
c = (weight_cdf > rand).argmax()
409-
if self._cvtype == 'tied':
410-
cv = self._covars
411-
else:
412-
cv = self._covars[c]
413-
obs[x] = sample_gaussian(self._means[c], cv, self._cvtype)
406+
rand = np.random.rand(n)
407+
# decide which component to use for each sample
408+
c = weight_cdf.searchsorted(rand)
409+
# for each component, generate all needed samples
410+
for cc in xrange(self._n_states):
411+
ccc = (c==cc) # occurrences of current component in obs
412+
nccc = ccc.sum() # number of those occurrences
413+
if nccc > 0:
414+
if self._cvtype == 'tied':
415+
cv = self._covars
416+
else:
417+
cv = self._covars[cc]
418+
obs[ccc] = sample_gaussian(self._means[cc], cv, self._cvtype, nccc).T
414419
return obs
415420

416421
def fit(self, X, n_iter=10, min_covar=1e-3, thresh=1e-2, params='wmc',

0 commit comments

Comments
 (0)
0