@@ -105,10 +105,10 @@ def sample_gaussian(mean, covar, cvtype='diag', n_samples=1):
105
105
obs : array, shape (n_features, n)
106
106
Randomly generated sample
107
107
"""
108
- ndim = len (mean )
109
- rand = np .random .randn (ndim , n_samples )
108
+ n_dim = len (mean )
109
+ rand = np .random .randn (n_dim , n_samples )
110
110
if n_samples == 1 :
111
- rand .shape = (ndim ,)
111
+ rand .shape = (n_dim ,)
112
112
113
113
if cvtype == 'spherical' :
114
114
rand *= np .sqrt (covar )
@@ -526,11 +526,11 @@ def _do_mstep(self, X, posteriors, params, min_covar=0):
526
526
527
527
528
528
def _lmvnpdfdiag (obs , means = 0.0 , covars = 1.0 ):
529
- nobs , ndim = obs .shape
529
+ nobs , n_dim = obs .shape
530
530
# (x-y).T A (x-y) = x.T A x - 2x.T A y + y.T A y
531
531
#lpr = -0.5 * (np.tile((np.sum((means**2) / covars, 1)
532
532
# + np.sum(np.log(covars), 1))[np.newaxis,:], (nobs,1))
533
- lpr = - 0.5 * (ndim * np .log (2 * np .pi ) + np .sum (np .log (covars ), 1 )
533
+ lpr = - 0.5 * (n_dim * np .log (2 * np .pi ) + np .sum (np .log (covars ), 1 )
534
534
+ np .sum ((means ** 2 ) / covars , 1 )
535
535
- 2 * np .dot (obs , (means / covars ).T )
536
536
+ np .dot (obs ** 2 , (1.0 / covars ).T ))
@@ -546,10 +546,10 @@ def _lmvnpdfspherical(obs, means=0.0, covars=1.0):
546
546
547
547
def _lmvnpdftied (obs , means , covars ):
548
548
from scipy import linalg
549
- nobs , ndim = obs .shape
549
+ nobs , n_dim = obs .shape
550
550
# (x-y).T A (x-y) = x.T A x - 2x.T A y + y.T A y
551
551
icv = linalg .pinv (covars )
552
- lpr = - 0.5 * (ndim * np .log (2 * np .pi ) + np .log (linalg .det (covars ))
552
+ lpr = - 0.5 * (n_dim * np .log (2 * np .pi ) + np .log (linalg .det (covars ))
553
553
+ np .sum (obs * np .dot (obs , icv ), 1 )[:,np .newaxis ]
554
554
- 2 * np .dot (np .dot (obs , icv ), means .T )
555
555
+ np .sum (means * np .dot (means , icv ), 1 ))
@@ -568,42 +568,42 @@ def _lmvnpdffull(obs, means, covars):
568
568
else :
569
569
# slower, but works
570
570
solve_triangular = linalg .solve
571
- nobs , ndim = obs .shape
571
+ nobs , n_dim = obs .shape
572
572
nmix = len (means )
573
573
log_prob = np .empty ((nobs ,nmix ))
574
574
for c , (mu , cv ) in enumerate (itertools .izip (means , covars )):
575
575
cv_chol = linalg .cholesky (cv , lower = True )
576
576
cv_log_det = 2 * np .sum (np .log (np .diagonal (cv_chol )))
577
577
cv_sol = solve_triangular (cv_chol , (obs - mu ).T , lower = True ).T
578
578
log_prob [:, c ] = - .5 * (np .sum (cv_sol ** 2 , axis = 1 ) + \
579
- ndim * np .log (2 * np .pi ) + cv_log_det )
579
+ n_dim * np .log (2 * np .pi ) + cv_log_det )
580
580
581
581
return log_prob
582
582
583
583
584
- def _validate_covars (covars , cvtype , nmix , ndim ):
584
+ def _validate_covars (covars , cvtype , nmix , n_dim ):
585
585
from scipy import linalg
586
586
if cvtype == 'spherical' :
587
587
if len (covars ) != nmix :
588
588
raise ValueError ("'spherical' covars must have length nmix" )
589
589
elif np .any (covars <= 0 ):
590
590
raise ValueError ("'spherical' covars must be non-negative" )
591
591
elif cvtype == 'tied' :
592
- if covars .shape != (ndim , ndim ):
593
- raise ValueError ("'tied' covars must have shape (ndim, ndim )" )
592
+ if covars .shape != (n_dim , n_dim ):
593
+ raise ValueError ("'tied' covars must have shape (n_dim, n_dim )" )
594
594
elif (not np .allclose (covars , covars .T )
595
595
or np .any (linalg .eigvalsh (covars ) <= 0 )):
596
596
raise ValueError ("'tied' covars must be symmetric, "
597
597
"positive-definite" )
598
598
elif cvtype == 'diag' :
599
- if covars .shape != (nmix , ndim ):
600
- raise ValueError ("'diag' covars must have shape (nmix, ndim )" )
599
+ if covars .shape != (nmix , n_dim ):
600
+ raise ValueError ("'diag' covars must have shape (nmix, n_dim )" )
601
601
elif np .any (covars <= 0 ):
602
602
raise ValueError ("'diag' covars must be non-negative" )
603
603
elif cvtype == 'full' :
604
- if covars .shape != (nmix , ndim , ndim ):
604
+ if covars .shape != (nmix , n_dim , n_dim ):
605
605
raise ValueError ("'full' covars must have shape "
606
- "(nmix, ndim, ndim )" )
606
+ "(nmix, n_dim, n_dim )" )
607
607
for n ,cv in enumerate (covars ):
608
608
if (not np .allclose (cv , cv .T )
609
609
or np .any (linalg .eigvalsh (cv ) <= 0 )):
0 commit comments