Commit 3f5277e

Merge pull request scikit-learn#4284 from hbredin/dgpmm_convergence

[MRG + 1] Improve DPGMM/VBGMM convergence check

2 parents: 4c389c6 + 4a0e281

File tree

3 files changed: +37, -22 lines

  doc/whats_new.rst
  sklearn/mixture/dpgmm.py
  sklearn/mixture/gmm.py

doc/whats_new.rst

Lines changed: 6 additions & 4 deletions
@@ -190,9 +190,10 @@ Enhancements
    - More robust seeding and improved error messages in :class:`cluster.MeanShift`
      by `Andreas Müller`_.
 
-   - Make the :class:`GMM` stopping criterion less dependent on the number of
-     samples by thresholding the average log-likelihood change instead of its
-     sum over all samples. By `Hervé Bredin`_
+   - Make the stopping criterion for :class:`GMM`, :class:`DPGMM`
+     and :class:`VBGMM` less dependent on the number of samples by
+     thresholding the average log-likelihood change instead of its sum over
+     all samples. By `Hervé Bredin`_
 
 Documentation improvements
 ..........................
@@ -390,7 +391,8 @@ API changes summary
     as the first nearest neighbor.
 
    - `thresh` parameter is deprecated in favor of new `tol` parameter in
-     :class:`GMM`. See `Enhancements` section for details. By `Hervé Bredin`_.
+     :class:`GMM`, :class:`DPGMM` and :class:`VBGMM`. See `Enhancements`
+     section for details. By `Hervé Bredin`_.
 
    - Estimators will treat input with dtype object as numeric when possible.
      By `Andreas Müller`_
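
Why average instead of sum: with the old criterion the summed log-likelihood change scales with the dataset size, so the same per-sample improvement can pass the threshold on a small sample and miss it on a large one. A minimal standalone sketch (hypothetical numbers, not taken from this commit) illustrates the effect:

    per_sample_change = 5e-5   # identical per-sample improvement in both runs
    thresh = 1e-2              # old default, compared to the change in the sum
    tol = 1e-3                 # new default, compared to the change in the mean

    for n_samples in (100, 100000):
        summed_change = per_sample_change * n_samples  # grows with dataset size
        mean_change = per_sample_change                # size-independent
        print("n_samples=%6d  sum rule converged: %5s  mean rule converged: %5s"
              % (n_samples, summed_change < thresh, mean_change < tol))

The sum-based rule reports convergence at n_samples=100 but not at 100000, while the mean-based rule gives the same verdict at both sizes.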

sklearn/mixture/dpgmm.py

Lines changed: 29 additions & 16 deletions
@@ -142,7 +142,7 @@ class DPGMM(GMM):
         higher alpha means more clusters, as the expected number
         of clusters is ``alpha*log(N)``.
 
-    thresh : float, default 1e-2
+    tol : float, default 1e-3
         Convergence threshold.
 
     n_iter : int, default 10
@@ -198,13 +198,13 @@ class DPGMM(GMM):
     """
 
     def __init__(self, n_components=1, covariance_type='diag', alpha=1.0,
-                 random_state=None, thresh=1e-2, verbose=False,
+                 random_state=None, thresh=None, tol=1e-3, verbose=False,
                  min_covar=None, n_iter=10, params='wmc', init_params='wmc'):
         self.alpha = alpha
         self.verbose = verbose
         super(DPGMM, self).__init__(n_components, covariance_type,
-                                    random_state=random_state,
-                                    thresh=thresh, min_covar=min_covar,
+                                    random_state=random_state, thresh=thresh,
+                                    tol=tol, min_covar=min_covar,
                                     n_iter=n_iter, params=params,
                                     init_params=init_params)
 
@@ -503,13 +503,13 @@ def fit(self, X, y=None):
         """
         self.random_state_ = check_random_state(self.random_state)
 
-        ## initialization step
+        # initialization step
         X = check_array(X)
         if X.ndim == 1:
             X = X[:, np.newaxis]
 
-        n_features = X.shape[1]
-        z = np.ones((X.shape[0], self.n_components))
+        n_samples, n_features = X.shape
+        z = np.ones((n_samples, self.n_components))
         z /= self.n_components
 
         self._initial_bound = - 0.5 * n_features * np.log(2 * np.pi)
@@ -550,7 +550,7 @@ def fit(self, X, y=None):
                 self.dof_, self.scale_, self.det_scale_, n_features)
             self.bound_prec_ -= 0.5 * self.dof_ * np.trace(self.scale_)
         elif self.covariance_type == 'full':
-            self.dof_ = (1 + self.n_components + X.shape[0])
+            self.dof_ = (1 + self.n_components + n_samples)
             self.dof_ *= np.ones(self.n_components)
             self.scale_ = [2 * np.identity(n_features)
                            for _ in range(self.n_components)]
@@ -566,18 +566,31 @@ def fit(self, X, y=None):
                              np.trace(self.scale_[k]))
             self.bound_prec_ *= 0.5
 
-        logprob = []
+        # EM algorithms
+        current_log_likelihood = None
         # reset self.converged_ to False
         self.converged_ = False
+
+        # this line should be removed when 'thresh' is removed in v0.18
+        tol = (self.tol if self.thresh is None
+               else self.thresh / float(n_samples))
+
         for i in range(self.n_iter):
+            prev_log_likelihood = current_log_likelihood
             # Expectation step
             curr_logprob, z = self.score_samples(X)
-            logprob.append(curr_logprob.sum() + self._logprior(z))
+
+            current_log_likelihood = (
+                curr_logprob.mean() + self._logprior(z) / n_samples)
 
             # Check for convergence.
-            if i > 0 and abs(logprob[-1] - logprob[-2]) < self.thresh:
-                self.converged_ = True
-                break
+            # (should compare to self.tol when deprecated 'thresh' is
+            # removed in v0.18)
+            if prev_log_likelihood is not None:
+                change = abs(current_log_likelihood - prev_log_likelihood)
+                if change < tol:
+                    self.converged_ = True
+                    break
 
             # Maximization step
             self._do_mstep(X, z, self.params)
@@ -613,7 +626,7 @@ class VBGMM(DPGMM):
         value of alpha the more likely the variational mixture of
         Gaussians model will use all components it can.
 
-    thresh : float, default 1e-2
+    tol : float, default 1e-3
         Convergence threshold.
 
     n_iter : int, default 10
@@ -671,11 +684,11 @@ class VBGMM(DPGMM):
     """
 
    def __init__(self, n_components=1, covariance_type='diag', alpha=1.0,
-                 random_state=None, thresh=1e-2, verbose=False,
+                 random_state=None, thresh=None, tol=1e-3, verbose=False,
                  min_covar=None, n_iter=10, params='wmc', init_params='wmc'):
         super(VBGMM, self).__init__(
             n_components, covariance_type, random_state=random_state,
-            thresh=thresh, verbose=verbose, min_covar=min_covar,
+            thresh=thresh, tol=tol, verbose=verbose, min_covar=min_covar,
             n_iter=n_iter, params=params, init_params=init_params)
         self.alpha = float(alpha) / n_components
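
For context, a hedged usage sketch of the DPGMM API after this change (0.16-era sklearn.mixture; the synthetic data, component count and iteration budget are illustrative assumptions, not taken from the commit):

    import numpy as np
    from sklearn.mixture import DPGMM

    X = np.random.RandomState(0).randn(500, 2)  # synthetic 2-D data

    # New spelling: 'tol' bounds the change in per-sample log-likelihood.
    model = DPGMM(n_components=5, tol=1e-3, n_iter=100)
    model.fit(X)
    print(model.converged_)

    # Old spelling is still accepted until 'thresh' is removed in v0.18;
    # per the diff above it is rescaled internally to thresh / n_samples.
    legacy = DPGMM(n_components=5, thresh=0.5, n_iter=100)
    legacy.fit(X)

The same pair of spellings applies to VBGMM, whose __init__ forwards both thresh and tol to DPGMM.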

sklearn/mixture/gmm.py

Lines changed: 2 additions & 2 deletions
@@ -459,7 +459,7 @@ def fit(self, X, y=None):
         # reset self.converged_ to False
         self.converged_ = False
 
-        # this line should be removed when 'thresh' is deprecated
+        # this line should be removed when 'thresh' is removed in v0.18
         tol = (self.tol if self.thresh is None
                else self.thresh / float(X.shape[0]))
 
@@ -471,7 +471,7 @@ def fit(self, X, y=None):
 
             # Check for convergence.
             # (should compare to self.tol when deprecated 'thresh' is
-            # removed)
+            # removed in v0.18)
             if prev_log_likelihood is not None:
                 change = abs(current_log_likelihood - prev_log_likelihood)
                 if change < tol:
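
One migration note for tuned thresholds: both fit methods rescale a legacy `thresh` by the number of training samples before comparing, so an old setting maps onto the new `tol` as in this small sketch (the sample count is an illustrative assumption):

    n_samples = 1000
    thresh = 1e-2                          # old sum-based default
    equivalent_tol = thresh / float(n_samples)
    print(equivalent_tol)                  # 1e-05, stricter than the new 1e-3

    # To keep the old stopping behaviour exactly, pass tol=thresh / n_samples;
    # leaving tol at its default 1e-3 may stop earlier on large datasets.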
