@@ -99,7 +99,7 @@ def sample_gaussian(mean, covar, cvtype='diag', n=1):
99
99
100
100
Returns
101
101
-------
102
- obs : array, shape (n, n_features )
102
+ obs : array, shape (n_features, n )
103
103
Randomly generated sample
104
104
"""
105
105
ndim = len (mean )
@@ -403,14 +403,19 @@ def rvs(self, n=1):
403
403
weight_cdf = np .cumsum (weight_pdf )
404
404
405
405
obs = np .empty ((n , self .n_features ))
406
- for x in xrange (n ):
407
- rand = np .random .rand ()
408
- c = (weight_cdf > rand ).argmax ()
409
- if self ._cvtype == 'tied' :
410
- cv = self ._covars
411
- else :
412
- cv = self ._covars [c ]
413
- obs [x ] = sample_gaussian (self ._means [c ], cv , self ._cvtype )
406
+ rand = np .random .rand (n )
407
+ # decide which component to use for each sample
408
+ c = weight_cdf .searchsorted (rand )
409
+ # for each component, generate all needed samples
410
+ for cc in xrange (self ._n_states ):
411
+ ccc = (c == cc ) # occurences of current component in obs
412
+ nccc = ccc .sum () # number of those occurences
413
+ if nccc > 0 :
414
+ if self ._cvtype == 'tied' :
415
+ cv = self ._covars
416
+ else :
417
+ cv = self ._covars [cc ]
418
+ obs [ccc ] = sample_gaussian (self ._means [cc ], cv , self ._cvtype , nccc ).T
414
419
return obs
415
420
416
421
def fit (self , X , n_iter = 10 , min_covar = 1e-3 , thresh = 1e-2 , params = 'wmc' ,
0 commit comments