8000 More cosmetic changes in GMM. · seckcoder/scikit-learn@1f9f965 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1f9f965

Browse files
author
Fabian Pedregosa
committed
More cosmetic changes in GMM.
1 parent e76eec5 commit 1f9f965

File tree

2 files changed

+19
-12
lines changed

2 files changed

+19
-12
lines changed

scikits/learn/mixture.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,32 +79,35 @@ def lmvnpdf(obs, means, covars, cvtype='diag'):
7979
return lmvnpdf_dict[cvtype](obs, means, covars)
8080

8181

82-
def sample_gaussian(mean, covar, cvtype='diag', n=1):
82+
def sample_gaussian(mean, covar, cvtype='diag', n_samples=1):
8383
"""Generate random samples from a Gaussian distribution.
8484
8585
Parameters
8686
----------
8787
mean : array_like, shape (n_features,)
8888
Mean of the distribution.
89-
covars : array_like
89+
90+
covars : array_like, optional
9091
Covariance of the distribution. The shape depends on `cvtype`:
9192
scalar if 'spherical',
9293
(D) if 'diag',
9394
(D, D) if 'tied', or 'full'
94-
cvtype : string
95+
96+
cvtype : string, optional
9597
Type of the covariance parameters. Must be one of
9698
'spherical', 'tied', 'diag', 'full'. Defaults to 'diag'.
97-
n : int
98-
Number of samples to generate.
99+
100+
n_samples : int, optional
101+
Number of samples to generate. Defaults to 1.
99102
100103
Returns
101104
-------
102105
obs : array, shape (n_features, n)
103106
Randomly generated sample
104107
"""
105108
ndim = len(mean)
106-
rand = np.random.randn(ndim, n)
107-
if n == 1:
109+
rand = np.random.randn(ndim, n_samples)
110+
if n_samples == 1:
108111
rand.shape = (ndim,)
109112

110113
if cvtype == 'spherical':

scikits/learn/tests/test_mixture.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,22 +62,26 @@ def test_sample_gaussian():
6262
mu = np.random.randint(10) * np.random.rand(n_features)
6363
cv = (np.random.rand(n_features) + 1.0) ** 2
6464

65-
samples = mixture.sample_gaussian(mu, cv, cvtype='diag', n=n_samples)
65+
samples = mixture.sample_gaussian(
66+
mu, cv, cvtype='diag', n_samples=n_samples)
6667

6768
assert np.allclose(samples.mean(axis), mu, atol=0.3)
6869
assert np.allclose(samples.var(axis), cv, atol=0.5)
6970

7071
# the same for spherical covariances
7172
cv = (np.random.rand() + 1.0) ** 2
72-
samples = mixture.sample_gaussian(mu, cv, cvtype='spherical', n=n_samples)
73+
samples = mixture.sample_gaussian(
74+
mu, cv, cvtype='spherical', n_samples=n_samples)
7375

7476
assert np.allclose(samples.mean(axis), mu, atol=0.3)
75-
assert np.allclose(samples.var(axis), np.repeat(cv, n_features), atol=0.5)
77+
assert np.allclose(
78+
samples.var(axis), np.repeat(cv, n_features), atol=0.5)
7679

7780
# and for full covariances
7881
A = np.random.randn(n_features, n_features)
7982
cv = np.dot(A.T, A) + np.eye(n_features)
80-
samples = mixture.sample_gaussian(mu, cv, cvtype='full', n=n_samples)
83+
samples = mixture.sample_gaussian(
84+
mu, cv, cvtype='full', n_samples=n_samples)
8185
assert np.allclose(samples.mean(axis), mu, atol=0.3)
8286
assert np.allclose(np.cov(samples), cv, atol=1.)
8387

@@ -219,7 +223,7 @@ def test_train(self, params='wmc'):
219223
g._covars = 20 * self.covars[self.cvtype]
220224

221225
# Create a training set by sampling from the predefined distribution.
222-
train_obs = g.rvs(n=100)
226+
train_obs = g.rvs(n_samples=100)
223227

224228
g.fit(train_obs, n_iter=0, init_params=params)
225229

0 commit comments

Comments
 (0)
0