8000 Correct Deprecation of DPGMM and VBGMM. (#7124) · scikit-learn/scikit-learn@05dfe6c · GitHub
[go: up one dir, main page]

Skip to content

Commit 05dfe6c

Browse files
tguillemotagramfort
authored andcommitted
Correct Deprecation of DPGMM and VBGMM. (#7124)
1 parent 5bddfdb commit 05dfe6c

File tree

3 files changed

+29
-19
lines changed

3 files changed

+29
-19
lines changed

sklearn/mixture/dpgmm.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,7 @@ def _bound_state_log_lik(X, initial_bound, precs, means, covariance_type):
116116
return bound
117117

118118

119-
@deprecated("The DPGMM class is not working correctly and it's better "
120-
"to not use it. DPGMM is deprecated in 0.18 and "
121-
"will be removed in 0.20.")
122-
class DPGMM(_GMMBase):
119+
class _DPGMMBase(_GMMBase):
123120
"""Variational Inference for the Infinite Gaussian Mixture Model.
124121
125122
DPGMM stands for Dirichlet Process Gaussian Mixture Model, and it
@@ -211,16 +208,16 @@ class DPGMM(_GMMBase):
211208
algorithm, better for situations where there might be too little
212209
data to get a good estimate of the covariance matrix.
213210
"""
214-
215211
def __init__(self, n_components=1, covariance_type='diag', alpha=1.0,
216212
random_state=None, tol=1e-3, verbose=0, min_covar=None,
217213
n_iter=10, params='wmc', init_params='wmc'):
218214
self.alpha = alpha
219-
super(DPGMM, self).__init__(n_components, covariance_type,
220-
random_state=random_state,
221-
tol=tol, min_covar=min_covar,
222-
n_iter=n_iter, params=params,
223-
init_params=init_params, verbose=verbose)
215+
super(_DPGMMBase, self).__init__(n_components, covariance_type,
216+
random_state=random_state,
217+
tol=tol, min_covar=min_covar,
218+
n_iter=n_iter, params=params,
219+
init_params=init_params,
220+
verbose=verbose)
224221

225222
def _get_precisions(self):
226223
"""Return precisions as a full matrix."""
@@ -619,10 +616,24 @@ def _fit(self, X, y=None):
619616
return z
620617

621618

622-
@deprecated("The VBGMM class is not working correctly and it's better"
623-
" to not use it. VBGMM is deprecated in 0.18 and "
619+
@deprecated("The DPGMM class is not working correctly and it's better "
620+
"to not use it. DPGMM is deprecated in 0.18 and "
621+
"will be removed in 0.20.")
622+
class DPGMM(_DPGMMBase):
623+
def __init__(self, n_components=1, covariance_type='diag', alpha=1.0,
624+
random_state=None, tol=1e-3, verbose=0, min_covar=None,
625+
n_iter=10, params='wmc', init_params='wmc'):
626+
super(DPGMM, self).__init__(
627+
n_components=n_components, covariance_type=covariance_type,
628+
alpha=alpha, random_state=random_state, tol=tol, verbose=verbose,
629+
min_covar=min_covar, n_iter=n_iter, params=params,
630+
init_params=init_params)
631+
632+
633+
@deprecated("The VBGMM class is not working correctly and it's better "
634+
"to not use it. VBGMM is deprecated in 0.18 and "
624635
"will be removed in 0.20.")
625-
class VBGMM(DPGMM):
636+
class VBGMM(_DPGMMBase):
626637
"""Variational Inference for the Gaussian Mixture Model
627638
628639
Variational inference for a Gaussian mixture model probability

sklearn/mixture/tests/test_dpgmm.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,12 +182,11 @@ class TestDPGMMWithFullCovars(unittest.TestCase, DPGMMTester):
182182
setUp = GMMTester._setUp
183183

184184

185-
@ignore_warnings(category=DeprecationWarning)
186185
def test_VBGMM_deprecation():
187-
assert_warns_message(DeprecationWarning, "The VBGMM class is"
188-
" not working correctly and it's better"
189-
" to not use it. VBGMM is deprecated in 0.18"
190-
" and will be removed in 0.20.", VBGMM)
186+
assert_warns_message(
187+
DeprecationWarning,
188+
"The VBGMM class is not working correctly and it's better to not use "
189+
"it. VBGMM is deprecated in 0.18 and will be removed in 0.20.", VBGMM)
191190

192191

193192
class VBGMMTester(GMMTester):

sklearn/mixture/tests/test_gmm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def test_train_1d(self, params='wmc'):
308308
with ignore_warnings(category=DeprecationWarning):
309309
g.fit(X)
310310
trainll = g.score(X)
311-
if isinstance(g, mixture.DPGMM):
311+
if isinstance(g, mixture.dpgmm._DPGMMBase):
312312
self.assertTrue(np.sum(np.abs(trainll / 100)) < 5)
313313
else:
314314
self.assertTrue(np.sum(np.abs(trainll / 100)) < 2)

0 commit comments

Comments
 (0)
0