Add some tests. · scikit-learn/scikit-learn@4965259 · GitHub

Commit 4965259

Add some tests.
Remove @ignore_warnings.

1 parent e55a94f commit 4965259

File tree

4 files changed, +404 -292 lines changed


sklearn/mixture/base.py

Lines changed: 10 additions & 0 deletions
@@ -90,6 +90,16 @@ def _check_initial_parameters(self, X):
         ----------
         X : array-like, shape (n_samples, n_features)
         """
+        if self.n_components < 1:
+            raise ValueError("Invalid value for 'n_components': %d "
+                             "Estimation requires at least one component"
+                             % self.n_components)
+
+        if self.tol < 0.:
+            raise ValueError("Invalid value for 'tol': %.5f "
+                             "Tolerance used by the EM must be non-negative"
+                             % self.tol)
+
         if self.n_init < 1:
             raise ValueError("Invalid value for 'n_init': %d "
                              "Estimation requires at least one run"

sklearn/mixture/gaussian_mixture.py

Lines changed: 2 additions & 2 deletions
@@ -438,7 +438,7 @@ class GaussianMixture(BaseMixture):
         'diag' (each component has its own diagonal covariance matrix),
         'spherical' (each component has its own single variance),
 
-    tol : float, defaults to 1e-6.
+    tol : float, defaults to 1e-3.
         The convergence threshold. EM iterations will stop when the
         log_likelihood average gain is below this threshold.
 
@@ -518,7 +518,7 @@ class GaussianMixture(BaseMixture):
         `n_iter_` will not exist before a call to fit.
     """
 
-    def __init__(self, n_components=1, covariance_type='full', tol=1e-6,
+    def __init__(self, n_components=1, covariance_type='full', tol=1e-3,
                  reg_covar=1e-6, max_iter=100, n_init=1, init_params='kmeans',
                  weights_init=None, means_init=None, covariances_init=None,
                  random_state=None, warm_start=False,
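The default convergence threshold changes from 1e-6 to 1e-3 here. As a sketch, callers who want the previous, stricter behaviour can keep passing tol explicitly:

from sklearn.mixture import GaussianMixture

gmm_default = GaussianMixture(n_components=2)           # tol now defaults to 1e-3
gmm_strict = GaussianMixture(n_components=2, tol=1e-6)  # previous default, stricter stop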

sklearn/mixture/tests/test_gaussian_mixture.py

Lines changed: 96 additions & 30 deletions
@@ -86,52 +86,79 @@ def __init__(self, rng, n_samples=500, n_components=2, n_features=2,
                             for k, w in enumerate(self.weights)])
 
 
-def test_gaussian_mixture_parameters():
+def test_gaussian_mixture_attributes():
     # test bad parameters
     rng = np.random.RandomState(0)
     X = rng.rand(10, 2)
 
-    n_init = 0
-    gmm = GaussianMixture(n_init=n_init)
+    n_components_bad = 0
+    gmm = GaussianMixture(n_components=n_components_bad)
     assert_raise_message(ValueError,
-                         "Invalid value for 'n_init': %d "
-                         "Estimation requires at least one run"
-                         % n_init,
-                         gmm.fit, X)
+                         "Invalid value for 'n_components': %d "
+                         "Estimation requires at least one component"
+                         % n_components_bad, gmm.fit, X)
 
-    max_iter = 0
-    gmm = GaussianMixture(max_iter=max_iter)
+    # covariance_type should be in [spherical, diag, tied, full]
+    covariance_type_bad = 'bad_covariance_type'
+    gmm = GaussianMixture(covariance_type=covariance_type_bad)
     assert_raise_message(ValueError,
-                         "Invalid value for 'max_iter': %d "
-                         "Estimation requires at least one iteration"
-                         % max_iter,
+                         "Invalid value for 'covariance_type': %s "
+                         "'covariance_type' should be in "
+                         "['spherical', 'tied', 'diag', 'full']"
+                         % covariance_type_bad,
                          gmm.fit, X)
 
-    reg_covar = -1
-    gmm = GaussianMixture(reg_covar=reg_covar)
+    tol_bad = -1
+    gmm = GaussianMixture(tol=tol_bad)
+    assert_raise_message(ValueError,
+                         "Invalid value for 'tol': %.5f "
+                         "Tolerance used by the EM must be non-negative"
+                         % tol_bad, gmm.fit, X)
+
+    reg_covar_bad = -1
+    gmm = GaussianMixture(reg_covar=reg_covar_bad)
     assert_raise_message(ValueError,
                          "Invalid value for 'reg_covar': %.5f "
                          "regularization on covariance must be "
-                         "non-negative" % reg_covar,
-                         gmm.fit, X)
+                         "non-negative" % reg_covar_bad, gmm.fit, X)
 
-    # covariance_type should be in [spherical, diag, tied, full]
-    covariance_type = 'bad_covariance_type'
-    gmm = GaussianMixture(covariance_type=covariance_type)
+    max_iter_bad = 0
+    gmm = GaussianMixture(max_iter=max_iter_bad)
     assert_raise_message(ValueError,
-                         "Invalid value for 'covariance_type': %s "
-                         "'covariance_type' should be in "
-                         "['spherical', 'tied', 'diag', 'full']"
-                         % covariance_type,
-                         gmm.fit, X)
+                         "Invalid value for 'max_iter': %d "
+                         "Estimation requires at least one iteration"
+                         % max_iter_bad, gmm.fit, X)
+
+    n_init_bad = 0
+    gmm = GaussianMixture(n_init=n_init_bad)
+    assert_raise_message(ValueError,
+                         "Invalid value for 'n_init': %d "
+                         "Estimation requires at least one run"
+                         % n_init_bad, gmm.fit, X)
 
-    init_params = 'bad_method'
-    gmm = GaussianMixture(init_params=init_params)
+    init_params_bad = 'bad_method'
+    gmm = GaussianMixture(init_params=init_params_bad)
     assert_raise_message(ValueError,
                          "Unimplemented initialization method '%s'"
-                         % init_params,
+                         % init_params_bad,
                          gmm.fit, X)
 
+    # test good parameters
+    n_components, tol, n_init, max_iter, reg_covar = 2, 1e-4, 3, 30, 1e-1
+    covariance_type, init_params = 'full', 'random'
+    gmm = GaussianMixture(n_components=n_components, tol=tol, n_init=n_init,
+                          max_iter=max_iter, reg_covar=reg_covar,
+                          covariance_type=covariance_type,
+                          init_params=init_params).fit(X)
+
+    assert_equal(gmm.n_components, n_components)
+    assert_equal(gmm.covariance_type, covariance_type)
+    assert_equal(gmm.tol, tol)
+    assert_equal(gmm.reg_covar, reg_covar)
+    assert_equal(gmm.max_iter, max_iter)
+    assert_equal(gmm.n_init, n_init)
+    assert_equal(gmm.init_params, init_params)
+
 
 def test_check_X():
     from sklearn.mixture.base import _check_X
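These checks use assert_raise_message from scikit-learn's testing utilities. A rough standalone equivalent of its behaviour (a sketch, not scikit-learn's implementation) is to call the function and verify that the expected text occurs in the raised exception's message:

def check_raise_message(exception_type, message, func, *args, **kwargs):
    # Call func and verify that it raises exception_type and that the
    # exception's string representation contains the expected fragment.
    try:
        func(*args, **kwargs)
    except exception_type as exc:
        assert message in str(exc), "unexpected error message: %s" % exc
    else:
        raise AssertionError("%s was not raised" % exception_type.__name__)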
@@ -447,6 +474,9 @@ def test_gaussian_mixture_estimate_log_prob_resp():
     g.fit(X)
     resp = g.predict_proba(X)
     assert_array_almost_equal(resp.sum(axis=1), np.ones(n_samples))
+    assert_array_equal(g.weights_init, weights)
+    assert_array_equal(g.means_init, means)
+    assert_array_equal(g.covariances_init, covariances)
 
 
 def test_gaussian_mixture_predict_predict_proba():
@@ -560,6 +590,21 @@ def test_gaussian_mixture_fit_convergence_warning():
                          % max_iter, g.fit, X)
 
 
+def test_multiple_init():
+    # Test that multiple inits does not much worse than a single one
+    rng = np.random.RandomState(0)
+    n_samples, n_features, n_components = 50, 5, 2
+    X = rng.randn(n_samples, n_features)
+    for cv_type in COVARIANCE_TYPE:
+        train1 = GaussianMixture(n_components=n_components,
+                                 covariance_type=cv_type,
+                                 random_state=rng).fit(X).score(X)
+        train2 = GaussianMixture(n_components=n_components,
+                                 covariance_type=cv_type,
+                                 random_state=rng, n_init=5).fit(X).score(X)
+        assert_greater_equal(train2, train1)
+
+
 def test_gaussian_mixture_n_parameters():
     # Test that the right number of parameters is estimated
     rng = np.random.RandomState(0)
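With n_init > 1, fit keeps the parameters of the initialisation that reached the best lower bound on the likelihood, which is why the multi-init training score should not be worse than a single run. A minimal sketch of the same idea done by hand (best_of_n_fits is a hypothetical helper, not part of this commit):

import numpy as np
from sklearn.mixture import GaussianMixture

def best_of_n_fits(X, n_components, n_init, seed=0):
    # Fit several singly-initialised models and keep the one scoring best
    # on the training data, mimicking what n_init does internally.
    rng = np.random.RandomState(seed)
    fits = [GaussianMixture(n_components=n_components, n_init=1,
                            random_state=rng).fit(X)
            for _ in range(n_init)]
    return max(fits, key=lambda gmm: gmm.score(X))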
@@ -573,6 +618,22 @@ def test_gaussian_mixture_n_parameters():
         assert_equal(g._n_parameters(), n_params[cv_type])
 
 
+def test_bic_1d_1component():
+    # Test all of the covariance_types return the same BIC score for
+    # 1-dimensional, 1 component fits.
+    rng = np.random.RandomState(0)
+    n_samples, n_dim, n_components = 100, 1, 1
+    X = rng.randn(n_samples, n_dim)
+    bic_full = GaussianMixture(n_components=n_components,
+                               covariance_type='full',
+                               random_state=rng).fit(X).bic(X)
+    for covariance_type in ['tied', 'diag', 'spherical']:
+        bic = GaussianMixture(n_components=n_components,
+                              covariance_type=covariance_type,
+                              random_state=rng).fit(X).bic(X)
+        assert_almost_equal(bic_full, bic)
+
+
 def test_gaussian_mixture_aic_bic():
     # Test the aic and bic criteria
     rng = np.random.RandomState(0)
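The reasoning behind this test: for one-dimensional data and a single component, the 'full', 'tied', 'diag', and 'spherical' parameterisations all collapse to a single scalar variance, so the fitted log-likelihood and parameter count, and therefore the BIC, should coincide. A sketch of the BIC computation, assuming the usual definition (mean per-sample log-likelihood times n_samples for the total):

import numpy as np

def bic_sketch(mean_log_likelihood, n_parameters, n_samples):
    # BIC = -2 * total log-likelihood + n_parameters * ln(n_samples)
    return (-2. * mean_log_likelihood * n_samples
            + n_parameters * np.log(n_samples))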
@@ -644,10 +705,10 @@ def test_warm_start():
     # Assert that by using warm_start we can converge to a good solution
     g = GaussianMixture(n_components=n_components, n_init=1,
                         max_iter=5, reg_covar=0, random_state=random_state,
-                        warm_start=False)
+                        warm_start=False, tol=1e-6)
     h = GaussianMixture(n_components=n_components, n_init=1,
                         max_iter=5, reg_covar=0, random_state=random_state,
-                        warm_start=True)
+                        warm_start=True, tol=1e-6)
 
     with warnings.catch_warnings():
         warnings.simplefilter("ignore", ConvergenceWarning)
@@ -720,10 +781,13 @@ def test_monotonic_likelihood():
         X = rand_data.X[cov_type]
         gmm = GaussianMixture(n_components=n_components,
                               covariance_type=cov_type, reg_covar=0,
-                              warm_start=True, max_iter=1, random_state=rng)
+                              warm_start=True, max_iter=1, random_state=rng,
+                              tol=1e-7)
         current_log_likelihood = -np.infty
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", ConvergenceWarning)
+            # Do one training iteration at a time so we can make sure that the
+            # training log likelihood increases after each iteration.
             for _ in range(300):
                 prev_log_likelihood = current_log_likelihood
                 try:
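For context, the warm-start pattern this test relies on, as a sketch (not part of the commit): with warm_start=True and max_iter=1, each call to fit runs a single EM iteration starting from the previous parameters, so the training log-likelihood can be inspected after every step, and EM guarantees it does not decrease.

import warnings
import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.mixture import GaussianMixture

rng = np.random.RandomState(0)
X = rng.randn(200, 2)
gmm = GaussianMixture(n_components=2, warm_start=True, max_iter=1,
                      reg_covar=0, tol=1e-7, random_state=rng)

log_likelihood = -np.inf
with warnings.catch_warnings():
    # max_iter=1 never "converges", so silence the convergence warnings.
    warnings.simplefilter("ignore", ConvergenceWarning)
    for _ in range(50):
        previous = log_likelihood
        log_likelihood = gmm.fit(X).score(X)
        assert log_likelihood >= previous  # EM is monotone in the likelihood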
@@ -738,6 +802,8 @@ def test_monotonic_likelihood():
 
 
 def test_regularisation():
+    # We train the GaussianMixture on degenerate data by defining two clusters
+    # of a 0 covariance.
     rng = np.random.RandomState(0)
     n_samples, n_features = 10, 5