@@ -505,7 +505,7 @@ def fit(self, X, y=None):
505
505
warnings .warn ("The default value for 'learning_method' will be "
506
506
"changed from 'online' to 'batch' in the release 0.20. "
507
507
"This warning was introduced in 0.18." ,
508
- DeprecationWarning )
508
+ DeprecationWarning )
509
509
learning_method = 'online'
510
510
511
511
batch_size = self .batch_size
@@ -531,8 +531,8 @@ def fit(self, X, y=None):
531
531
doc_topics_distr , _ = self ._e_step (X , cal_sstats = False ,
532
532
random_init = False ,
533
533
parallel = parallel )
534
- bound = self .perplexity (X , doc_topics_distr ,
535
- sub_sampling = False )
534
+ bound = self ._perplexity_precomp_distr (X , doc_topics_distr ,
535
+ sub_sampling = False )
536
536
if self .verbose :
537
537
print ('iteration: %d, perplexity: %.4f'
538
538
% (i + 1 , bound ))
@@ -541,10 +541,18 @@ def fit(self, X, y=None):
541
541
break
542
542
last_bound = bound
543
543
self .n_iter_ += 1
544
+
545
+ # calculate final perplexity value on train set
546
+ doc_topics_distr , _ = self ._e_step (X , cal_sstats = False ,
547
+ random_init = False ,
548
+ parallel = parallel )
549
+ self .bound_ = self ._perplexity_precomp_distr (X , doc_topics_distr ,
550
+ sub_sampling = False )
551
+
544
552
return self
545
553
546
- def transform (self , X ):
547
- """Transform data X according to the fitted model.
554
+ def _unnormalized_transform (self , X ):
555
+ """Transform data X according to fitted model.
548
556
549
557
Parameters
550
558
----------
@@ -556,7 +564,6 @@ def transform(self, X):
556
564
doc_topic_distr : shape=(n_samples, n_topics)
557
565
Document topic distribution for X.
558
566
"""
559
-
560
567
if not hasattr (self , 'components_' ):
561
568
raise NotFittedError ("no 'components_' attribute in model."
562
569
" Please fit model first." )
@@ -572,7 +579,26 @@ def transform(self, X):
572
579
573
580
doc_topic_distr , _ = self ._e_step (X , cal_sstats = False ,
574
581
random_init = False )
575
- # normalize doc_topic_distr
582
+
583
+ return doc_topic_distr
584
+
585
def transform(self, X):
    """Transform data X according to the fitted model.

    .. versionchanged:: 0.18
       *doc_topic_distr* is now normalized

    Parameters
    ----------
    X : array-like or sparse matrix, shape=(n_samples, n_features)
        Document word matrix.

    Returns
    -------
    doc_topic_distr : shape=(n_samples, n_topics)
        Document topic distribution for X.
    """
    # Unnormalized document-topic matrix from the variational e-step.
    doc_topic_distr = self._unnormalized_transform(X)
    # Scale each row so it sums to one, i.e. a proper distribution
    # over topics for each document.
    row_totals = doc_topic_distr.sum(axis=1, keepdims=True)
    doc_topic_distr /= row_totals
    return doc_topic_distr
578
604
@@ -665,15 +691,16 @@ def score(self, X, y=None):
665
691
score : float
666
692
Use approximate bound as score.
667
693
"""
668
-
669
694
X = self ._check_non_neg_array (X , "LatentDirichletAllocation.score" )
670
695
671
- doc_topic_distr = self .transform (X )
696
+ doc_topic_distr = self ._unnormalized_transform (X )
672
697
score = self ._approx_bound (X , doc_topic_distr , sub_sampling = False )
673
698
return score
674
699
675
- def perplexity (self , X , doc_topic_distr = None , sub_sampling = False ):
676
- """Calculate approximate perplexity for data X.
700
+ def _perplexity_precomp_distr (self , X , doc_topic_distr = None ,
701
+ sub_sampling = False ):
702
+ """Calculate approximate perplexity for data X with ability to accept
703
+ precomputed doc_topic_distr
677
704
678
705
Perplexity is defined as exp(-1. * log-likelihood per word)
679
706
@@ -699,7 +726,7 @@ def perplexity(self, X, doc_topic_distr=None, sub_sampling=False):
699
726
"LatentDirichletAllocation.perplexity" )
700
727
701
728
if doc_topic_distr is None :
702
- doc_topic_distr = self .transform (X )
729
+ doc_topic_distr = self ._unnormalized_transform (X )
703
730
else :
704
731
n_samples , n_topics = doc_topic_distr .shape
705
732
if n_samples != X .shape [0 ]:
@@ -719,3 +746,35 @@ def perplexity(self, X, doc_topic_distr=None, sub_sampling=False):
719
746
perword_bound = bound / word_cnt
720
747
721
748
return np .exp (- 1.0 * perword_bound )
749
+
750
def perplexity(self, X, doc_topic_distr='deprecated', sub_sampling=False):
    """Calculate approximate perplexity for data X.

    Perplexity is defined as exp(-1. * log-likelihood per word)

    .. versionchanged:: 0.19
       *doc_topic_distr* argument has been deprecated because the user
       no longer has access to the unnormalized distribution

    Parameters
    ----------
    X : array-like or sparse matrix, [n_samples, n_features]
        Document word matrix.

    doc_topic_distr : None or array, shape=(n_samples, n_topics)
        Document topic distribution.
        This argument is deprecated and is currently being ignored.

        .. deprecated:: 0.19

    sub_sampling : bool
        Do sub-sampling or not.

    Returns
    -------
    score : float
        Perplexity score.
    """
    # A caller-supplied distribution is typically an ndarray; comparing an
    # ndarray to a string with `!=` does not reliably produce a single
    # bool (numpy may return an array or warn), so guard the sentinel
    # comparison with an isinstance check instead.
    if not (isinstance(doc_topic_distr, str) and
            doc_topic_distr == 'deprecated'):
        warnings.warn("Argument 'doc_topic_distr' is deprecated and will "
                      "be ignored as of 0.19. Support for this argument "
                      "will be removed in 0.21.", DeprecationWarning)

    # Delegate to the private implementation, which recomputes the
    # doc-topic distribution itself.
    return self._perplexity_precomp_distr(X, sub_sampling=sub_sampling)
0 commit comments