@@ -236,6 +236,32 @@ def build_analyzer(self):
236
236
raise ValueError ('%s is not a valid tokenization scheme/analyzer' %
237
237
self .analyzer )
238
238
239
def _check_vocabulary(self):
    """Validate ``self.vocabulary`` and derive the fixed-vocabulary state.

    Reads ``self.vocabulary``. When it is None, only
    ``self.fixed_vocabulary`` is set (to False). Otherwise the vocabulary
    is validated, ``self.fixed_vocabulary`` is set to True and a
    term -> index ``dict`` copy is stored in ``self.vocabulary_``.

    A non-Mapping vocabulary is treated as an iterable of terms; each term
    receives its position in iteration order as its index. A Mapping
    vocabulary is taken as term -> index, and its indices must be exactly
    ``0 .. len(vocabulary) - 1`` with no repeats.

    Raises
    ------
    ValueError
        If an iterable vocabulary contains a duplicate term, if a Mapping
        vocabulary has repeated or missing indices, or if the vocabulary
        is empty.
    """
    vocabulary = self.vocabulary
    if vocabulary is None:
        self.fixed_vocabulary = False
        return

    if not isinstance(vocabulary, Mapping):
        vocab = {}
        for i, t in enumerate(vocabulary):
            # setdefault returns the previously stored index when the term
            # was already seen, which then differs from the current position.
            if vocab.setdefault(t, i) != i:
                # Wrap the term in a 1-tuple: a bare tuple term would be
                # unpacked by ``%`` and raise TypeError instead of the
                # intended ValueError.
                msg = "Duplicate term in vocabulary: %r" % (t,)
                raise ValueError(msg)
        vocabulary = vocab
    else:
        indices = set(six.itervalues(vocabulary))
        if len(indices) != len(vocabulary):
            raise ValueError("Vocabulary contains repeated indices.")
        # With no repeats, every index in 0..n-1 must be present.
        for i in xrange(len(vocabulary)):
            if i not in indices:
                msg = ("Vocabulary of size %d doesn't contain index "
                       "%d." % (len(vocabulary), i))
                raise ValueError(msg)

    if not vocabulary:
        raise ValueError("empty vocabulary passed to fit")

    self.fixed_vocabulary = True
    # Store a plain dict copy so later changes to the user's object do not
    # affect the vocabulary held on the instance.
    self.vocabulary_ = dict(vocabulary)
264
+
239
265
240
266
class HashingVectorizer (BaseEstimator , VectorizerMixin ):
241
267
"""Convert a collection of text documents to a matrix of token occurrences
@@ -616,29 +642,7 @@ def __init__(self, input='content', encoding='utf-8',
616
642
"max_features=%r, neither a positive integer nor None"
617
643
% max_features )
618
644
self .ngram_range = ngram_range
619
- if vocabulary is not None :
620
- if not isinstance (vocabulary , Mapping ):
621
- vocab = {}
622
- for i , t in enumerate (vocabulary ):
623
- if vocab .setdefault (t , i ) != i :
624
- msg = "Duplicate term in vocabulary: %r" % t
625
- raise ValueError (msg )
626
- vocabulary = vocab
627
- else :
628
- indices = set (six .itervalues (vocabulary ))
629
- if len (indices ) != len (vocabulary ):
630
- raise ValueError ("Vocabulary contains repeated indices." )
631
- for i in xrange (len (vocabulary )):
632
- if i not in indices :
633
- msg = ("Vocabulary of size %d doesn't contain index "
634
- "%d." % (len (vocabulary ), i ))
635
- raise ValueError (msg )
636
- if not vocabulary :
637
- raise ValueError ("empty vocabulary passed to fit" )
638
- self .fixed_vocabulary = True
639
- self .vocabulary_ = dict (vocabulary )
640
- else :
641
- self .fixed_vocabulary = False
645
+ self .vocabulary = vocabulary
642
646
self .binary = binary
643
647
self .dtype = dtype
644
648
@@ -773,6 +777,7 @@ def fit_transform(self, raw_documents, y=None):
773
777
# We intentionally don't call the transform method to make
774
778
# fit_transform overridable without unwanted side effects in
775
779
# TfidfVectorizer.
780
+ self ._check_vocabulary ()
776
781
max_df = self .max_df
777
782
min_df = self .min_df
778
783
max_features = self .max_features
@@ -820,6 +825,9 @@ def transform(self, raw_documents):
820
825
X : sparse matrix, [n_samples, n_features]
821
826
Document-term matrix.
822
827
"""
828
+ if not hasattr (self , 'vocabulary_' ):
829
+ self ._check_vocabulary ()
830
+
823
831
if not hasattr (self , 'vocabulary_' ) or len (self .vocabulary_ ) == 0 :
824
832
raise ValueError ("Vocabulary wasn't fitted or is empty!" )
825
833
0 commit comments