
Commit d44bf89

garyForeman authored and Sundrique committed
[MRG + 1] Fix perplexity method by adding _unnormalized_transform method, Issue scikit-learn#7954 (scikit-learn#7992)
Also deprecate doc_topic_distr argument in perplexity method
1 parent a7e6bc0 commit d44bf89
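For orientation, here is a minimal sketch of the call pattern this commit fixes. It is not part of the commit; the toy data, the n_topics value, and the print call are illustrative only, and parameter names follow the scikit-learn 0.19-era API used in this diff:

    import numpy as np
    from sklearn.decomposition import LatentDirichletAllocation

    # Toy document-word matrix; values are word counts per document.
    rng = np.random.RandomState(0)
    X = rng.randint(5, size=(20, 10))

    lda = LatentDirichletAllocation(n_topics=3, learning_method='batch',
                                    random_state=0).fit(X)

    # After this commit, perplexity() recomputes the *unnormalized* document
    # topic distribution internally (via _unnormalized_transform), so no
    # distribution is passed in. Before the fix, it called transform(), whose
    # output has been row-normalized since 0.18, which skewed the bound.
    print(lda.perplexity(X))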

File tree

3 files changed: +128 −28 lines changed

doc/whats_new.rst

Lines changed: 11 additions & 0 deletions

@@ -122,6 +122,11 @@ Bug fixes
      when a numpy array is passed in for weights. :issue:`7983` by
      :user:`Vincent Pham <vincentpham1991>`.
 
+   - Fix a bug in :class:`sklearn.decomposition.LatentDirichletAllocation`
+     where the ``perplexity`` method was returning incorrect results because
+     the ``transform`` method returns normalized document topic distributions
+     as of version 0.18. :issue:`7954` by :user:`Gary Foreman <garyForeman>`.
+
    - Fix a bug where :class:`sklearn.ensemble.GradientBoostingClassifier` and
      :class:`sklearn.ensemble.GradientBoostingRegressor` ignored the
      ``min_impurity_split`` parameter.

@@ -135,6 +140,12 @@ API changes summary
      ensemble estimators (deriving from :class:`ensemble.BaseEnsemble`)
      now only have ``self.estimators_`` available after ``fit``.
      :issue:`7464` by `Lars Buitinck`_ and `Loic Esteve`_.
+
+   - Deprecate the ``doc_topic_distr`` argument of the ``perplexity`` method
+     in :class:`sklearn.decomposition.LatentDirichletAllocation` because the
+     user no longer has access to the unnormalized document topic distribution
+     needed for the perplexity calculation. :issue:`7954` by
+     :user:`Gary Foreman <garyForeman>`.
 
 .. _changes_0_18_1:
sklearn/decomposition/online_lda.py

Lines changed: 71 additions & 12 deletions

@@ -505,7 +505,7 @@ def fit(self, X, y=None):
             warnings.warn("The default value for 'learning_method' will be "
                           "changed from 'online' to 'batch' in the release 0.20. "
                           "This warning was introduced in 0.18.",
-                         DeprecationWarning)
+                          DeprecationWarning)
             learning_method = 'online'
 
         batch_size = self.batch_size

@@ -531,8 +531,8 @@ def fit(self, X, y=None):
                     doc_topics_distr, _ = self._e_step(X, cal_sstats=False,
                                                        random_init=False,
                                                        parallel=parallel)
-                    bound = self.perplexity(X, doc_topics_distr,
-                                            sub_sampling=False)
+                    bound = self._perplexity_precomp_distr(X, doc_topics_distr,
+                                                           sub_sampling=False)
                     if self.verbose:
                         print('iteration: %d, perplexity: %.4f'
                               % (i + 1, bound))

@@ -541,10 +541,18 @@ def fit(self, X, y=None):
                         break
                     last_bound = bound
                 self.n_iter_ += 1
+
+            # calculate final perplexity value on train set
+            doc_topics_distr, _ = self._e_step(X, cal_sstats=False,
+                                               random_init=False,
+                                               parallel=parallel)
+            self.bound_ = self._perplexity_precomp_distr(X, doc_topics_distr,
+                                                         sub_sampling=False)
+
         return self
 
-    def transform(self, X):
-        """Transform data X according to the fitted model.
+    def _unnormalized_transform(self, X):
+        """Transform data X according to fitted model.
 
         Parameters
         ----------

@@ -556,7 +564,6 @@ def transform(self, X):
         doc_topic_distr : shape=(n_samples, n_topics)
             Document topic distribution for X.
         """
-
         if not hasattr(self, 'components_'):
             raise NotFittedError("no 'components_' attribute in model."
                                  " Please fit model first.")

@@ -572,7 +579,26 @@ def transform(self, X):
 
         doc_topic_distr, _ = self._e_step(X, cal_sstats=False,
                                           random_init=False)
-        # normalize doc_topic_distr
+
+        return doc_topic_distr
+
+    def transform(self, X):
+        """Transform data X according to the fitted model.
+
+        .. versionchanged:: 0.18
+           *doc_topic_distr* is now normalized
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape=(n_samples, n_features)
+            Document word matrix.
+
+        Returns
+        -------
+        doc_topic_distr : shape=(n_samples, n_topics)
+            Document topic distribution for X.
+        """
+        doc_topic_distr = self._unnormalized_transform(X)
         doc_topic_distr /= doc_topic_distr.sum(axis=1)[:, np.newaxis]
         return doc_topic_distr
 

@@ -665,15 +691,16 @@ def score(self, X, y=None):
         score : float
             Use approximate bound as score.
         """
-
         X = self._check_non_neg_array(X, "LatentDirichletAllocation.score")
 
-        doc_topic_distr = self.transform(X)
+        doc_topic_distr = self._unnormalized_transform(X)
         score = self._approx_bound(X, doc_topic_distr, sub_sampling=False)
         return score
 
-    def perplexity(self, X, doc_topic_distr=None, sub_sampling=False):
-        """Calculate approximate perplexity for data X.
+    def _perplexity_precomp_distr(self, X, doc_topic_distr=None,
+                                  sub_sampling=False):
+        """Calculate approximate perplexity for data X with ability to accept
+        precomputed doc_topic_distr
 
         Perplexity is defined as exp(-1. * log-likelihood per word)
 

@@ -699,7 +726,7 @@ def perplexity(self, X, doc_topic_distr=None, sub_sampling=False):
                                          "LatentDirichletAllocation.perplexity")
 
         if doc_topic_distr is None:
-            doc_topic_distr = self.transform(X)
+            doc_topic_distr = self._unnormalized_transform(X)
         else:
             n_samples, n_topics = doc_topic_distr.shape
             if n_samples != X.shape[0]:

@@ -719,3 +746,35 @@ def perplexity(self, X, doc_topic_distr=None, sub_sampling=False):
         perword_bound = bound / word_cnt
 
         return np.exp(-1.0 * perword_bound)
+
+    def perplexity(self, X, doc_topic_distr='deprecated', sub_sampling=False):
+        """Calculate approximate perplexity for data X.
+
+        Perplexity is defined as exp(-1. * log-likelihood per word)
+
+        .. versionchanged:: 0.19
+           *doc_topic_distr* argument has been deprecated because the user
+           no longer has access to the unnormalized distribution
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, [n_samples, n_features]
+            Document word matrix.
+
+        doc_topic_distr : None or array, shape=(n_samples, n_topics)
+            Document topic distribution.
+            If it is None, it will be generated by applying transform on X.
+
+            .. deprecated:: 0.19
+
+        Returns
+        -------
+        score : float
+            Perplexity score.
+        """
+        if doc_topic_distr != 'deprecated':
+            warnings.warn("Argument 'doc_topic_distr' is deprecated and will "
+                          "be ignored as of 0.19. Support for this argument "
+                          "will be removed in 0.21.", DeprecationWarning)
+
+        return self._perplexity_precomp_distr(X, sub_sampling=sub_sampling)
sklearn/decomposition/tests/test_online_lda.py

Lines changed: 46 additions & 16 deletions

@@ -14,6 +14,7 @@
 from sklearn.utils.testing import assert_greater_equal
 from sklearn.utils.testing import assert_raises_regexp
 from sklearn.utils.testing import if_safe_multiprocessing_with_blas
+from sklearn.utils.testing import assert_warns
 
 from sklearn.exceptions import NotFittedError
 from sklearn.externals.six.moves import xrange

@@ -238,12 +239,12 @@ def test_lda_preplexity_mismatch():
     lda.fit(X)
     # invalid samples
     invalid_n_samples = rng.randint(4, size=(n_samples + 1, n_topics))
-    assert_raises_regexp(ValueError, r'Number of samples', lda.perplexity, X,
-                         invalid_n_samples)
+    assert_raises_regexp(ValueError, r'Number of samples',
+                         lda._perplexity_precomp_distr, X, invalid_n_samples)
     # invalid topic number
     invalid_n_topics = rng.randint(4, size=(n_samples, n_topics + 1))
-    assert_raises_regexp(ValueError, r'Number of topics', lda.perplexity, X,
-                         invalid_n_topics)
+    assert_raises_regexp(ValueError, r'Number of topics',
+                         lda._perplexity_precomp_distr, X, invalid_n_topics)
 
 
 def test_lda_perplexity():

@@ -257,15 +258,15 @@ def test_lda_perplexity():
     lda_2 = LatentDirichletAllocation(n_topics=n_topics, max_iter=10,
                                       learning_method=method,
                                       total_samples=100, random_state=0)
-    distr_1 = lda_1.fit_transform(X)
-    perp_1 = lda_1.perplexity(X, distr_1, sub_sampling=False)
+    lda_1.fit(X)
+    perp_1 = lda_1.perplexity(X, sub_sampling=False)
 
-    distr_2 = lda_2.fit_transform(X)
-    perp_2 = lda_2.perplexity(X, distr_2, sub_sampling=False)
+    lda_2.fit(X)
+    perp_2 = lda_2.perplexity(X, sub_sampling=False)
     assert_greater_equal(perp_1, perp_2)
 
-    perp_1_subsampling = lda_1.perplexity(X, distr_1, sub_sampling=True)
-    perp_2_subsampling = lda_2.perplexity(X, distr_2, sub_sampling=True)
+    perp_1_subsampling = lda_1.perplexity(X, sub_sampling=True)
+    perp_2_subsampling = lda_2.perplexity(X, sub_sampling=True)
     assert_greater_equal(perp_1_subsampling, perp_2_subsampling)
 

@@ -295,27 +296,56 @@ def test_perplexity_input_format():
     lda = LatentDirichletAllocation(n_topics=n_topics, max_iter=1,
                                     learning_method='batch',
                                     total_samples=100, random_state=0)
-    distr = lda.fit_transform(X)
+    lda.fit(X)
     perp_1 = lda.perplexity(X)
-    perp_2 = lda.perplexity(X, distr)
-    perp_3 = lda.perplexity(X.toarray(), distr)
+    perp_2 = lda.perplexity(X.toarray())
     assert_almost_equal(perp_1, perp_2)
-    assert_almost_equal(perp_1, perp_3)
 
 
 def test_lda_score_perplexity():
     # Test the relationship between LDA score and perplexity
     n_topics, X = _build_sparse_mtx()
     lda = LatentDirichletAllocation(n_topics=n_topics, max_iter=10,
                                     random_state=0)
-    distr = lda.fit_transform(X)
-    perplexity_1 = lda.perplexity(X, distr, sub_sampling=False)
+    lda.fit(X)
+    perplexity_1 = lda.perplexity(X, sub_sampling=False)
 
     score = lda.score(X)
     perplexity_2 = np.exp(-1. * (score / np.sum(X.data)))
     assert_almost_equal(perplexity_1, perplexity_2)
 
 
+def test_lda_fit_perplexity():
+    # Test that the perplexity computed during fit is consistent with what is
+    # returned by the perplexity method
+    n_topics, X = _build_sparse_mtx()
+    lda = LatentDirichletAllocation(n_topics=n_topics, max_iter=1,
+                                    learning_method='batch', random_state=0,
+                                    evaluate_every=1)
+    lda.fit(X)
+
+    # Perplexity computed at end of fit method
+    perplexity1 = lda.bound_
+
+    # Result of perplexity method on the train set
+    perplexity2 = lda.perplexity(X)
+
+    assert_almost_equal(perplexity1, perplexity2)
+
+
+def test_doc_topic_distr_deprecation():
+    # Test that the appropriate warning message is displayed when a user
+    # attempts to pass the doc_topic_distr argument to the perplexity method
+    n_topics, X = _build_sparse_mtx()
+    lda = LatentDirichletAllocation(n_topics=n_topics, max_iter=1,
+                                    learning_method='batch',
+                                    total_samples=100, random_state=0)
+    distr1 = lda.fit_transform(X)
+    distr2 = None
+    assert_warns(DeprecationWarning, lda.perplexity, X, distr1)
+    assert_warns(DeprecationWarning, lda.perplexity, X, distr2)
+
+
 def test_lda_empty_docs():
     """Test LDA on empty document (all-zero rows)."""
     Z = np.zeros((5, 4))