[MRG] Add Deprecation Warning to address making "batch" the default learning method in LDA in 0.19 release · scikit-learn/scikit-learn@2f43bf2 · GitHub

Commit 2f43bf2

doshyt authored and amueller committed
[MRG] Add Deprecation Warning to address making "batch" the default learning method in LDA in 0.19 release (#6999)
* update default LDA method to batch
* update tests to match default LDA learning scheme
* DOC trim information on faster learning scheme with LDA
* Add deprecation warning and revert learning_method back to the default

  After the discussion, it was decided that a proper deprecation cycle for the parameter is required. Therefore, in this PR only the warning is added. The default learning_method will be changed to 'batch' in 0.19.
* Revert back tests to the default LDA
* DOC deprecate in 0.20 instead of 0.19
1 parent c43bafa commit 2f43bf2

File tree

2 files changed: +12 −4 lines changed

2 files changed

+12
-4
lines changed

doc/modules/decomposition.rst

Lines changed: 1 addition & 2 deletions

@@ -803,8 +803,7 @@ between :math:`q(z,\theta,\beta)` and the true posterior
 :class:`LatentDirichletAllocation` implements online variational Bayes algorithm and supports
 both online and batch update method.
 While batch method updates variational variables after each full pass through the data,
-online method updates variational variables from mini-batch data points. Therefore,
-online method usually converges faster than batch method.
+online method updates variational variables from mini-batch data points.
 
 .. note::
 
sklearn/decomposition/online_lda.py

Lines changed: 11 additions & 2 deletions

@@ -14,6 +14,7 @@
 import numpy as np
 import scipy.sparse as sp
 from scipy.special import gammaln
+import warnings
 
 from ..base import BaseEstimator, TransformerMixin
 from ..utils import (check_random_state, check_array,
@@ -159,6 +160,7 @@ class LatentDirichletAllocation(BaseEstimator, TransformerMixin):
         Method used to update `_component`. Only used in `fit` method.
         In general, if the data size is large, the online update will be much
         faster than the batch update.
+        The default learning method is going to be changed to 'batch' in the 0.20 release.
         Valid options::
 
             'batch': Batch variational Bayes method. Use all training data in
@@ -246,7 +248,7 @@ class LatentDirichletAllocation(BaseEstimator, TransformerMixin):
     """
 
     def __init__(self, n_topics=10, doc_topic_prior=None,
-                 topic_word_prior=None, learning_method='online',
+                 topic_word_prior=None, learning_method=None,
                  learning_decay=.7, learning_offset=10., max_iter=10,
                  batch_size=128, evaluate_every=-1, total_samples=1e6,
                  perp_tol=1e-1, mean_change_tol=1e-3, max_doc_update_iter=100,
@@ -283,7 +285,7 @@ def _check_params(self):
             raise ValueError("Invalid 'learning_offset' parameter: %r"
                              % self.learning_offset)
 
-        if self.learning_method not in ("batch", "online"):
+        if self.learning_method not in ("batch", "online", None):
             raise ValueError("Invalid 'learning_method' parameter: %r"
                              % self.learning_method)
 
@@ -499,6 +501,13 @@ def fit(self, X, y=None):
         max_iter = self.max_iter
         evaluate_every = self.evaluate_every
         learning_method = self.learning_method
+        if learning_method == None:
+            warnings.warn("The default value for 'learning_method' will be "
+                          "changed from 'online' to 'batch' in the release 0.20. "
+                          "This warning was introduced in 0.18.",
+                          DeprecationWarning)
+            learning_method = 'online'
+
         batch_size = self.batch_size
 
         # initialize parameters
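The deprecation pattern in the `fit` hunk above — a `None` sentinel default, a warning, then a fall-back to the old value — can be sketched in isolation. `Estimator` below is a hypothetical stand-in, not the real `LatentDirichletAllocation`; note also that the committed diff compares with `== None`, where `is None` (used here) is the more idiomatic spelling:

```python
import warnings


class Estimator:
    """Hypothetical minimal estimator illustrating the deprecation cycle
    used in this commit (not the actual scikit-learn class)."""

    def __init__(self, learning_method=None):
        # None is a sentinel meaning "the user did not set the parameter";
        # only in that case does fit() warn and apply the old default.
        self.learning_method = learning_method

    def fit(self):
        learning_method = self.learning_method
        if learning_method is None:
            warnings.warn("The default value for 'learning_method' will be "
                          "changed from 'online' to 'batch' in 0.20.",
                          DeprecationWarning)
            learning_method = 'online'  # keep the old behaviour for now
        return learning_method
```

The check lives in `fit` rather than `__init__` because scikit-learn's estimator contract forbids altering or validating parameters at construction time; users who pass `learning_method` explicitly never see the warning.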
