FIX set vectorizer vocabulary outside of init · scikit-learn/scikit-learn@be3ac79 · GitHub
[go: up one dir, main page]

Skip to content

Commit be3ac79

Browse files
committed
FIX set vectorizer vocabulary outside of init
1 parent c604ac3 commit be3ac79

File tree

2 files changed

+45
-29
lines changed

2 files changed

+45
-29
lines changed

sklearn/feature_extraction/tests/test_text.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from sklearn.pipeline import Pipeline
1919
from sklearn.svm import LinearSVC
2020

21+
from sklearn.base import clone
2122

2223
import numpy as np
2324
from nose import SkipTest
@@ -283,7 +284,8 @@ def test_countvectorizer_stop_words():
283284

284285
def test_countvectorizer_empty_vocabulary():
285286
try:
286-
CountVectorizer(vocabulary=[])
287+
vect = CountVectorizer(vocabulary=[])
288+
vect.fit(["foo"])
287289
assert False, "we shouldn't get here"
288290
except ValueError as e:
289291
assert_in("empty vocabulary", str(e).lower())
@@ -379,7 +381,6 @@ def test_vectorizer():
379381

380382
# build a vectorizer v1 with the same vocabulary as the one fitted by v1
381383
v2 = CountVectorizer(vocabulary=v1.vocabulary_)
382-
383384
# compare that the two vectorizer give the same output on the test sample
384385
for v in (v1, v2):
385386
counts_test = v.transform(test_data)
@@ -405,7 +406,6 @@ def test_vectorizer():
405406
assert_equal(counts_test[0, vocabulary["burger"]], 0)
406407
assert_equal(counts_test[0, vocabulary["beer"]], 0)
407408
assert_equal(counts_test[0, vocabulary["pizza"]], 0)
408-
409409
# test tf-idf
410410
t1 = TfidfTransformer(norm='l1')
411411
tfidf = t1.fit(counts_train).transform(counts_train).toarray()
@@ -440,10 +440,10 @@ def test_vectorizer():
440440
# (equivalent to term count vectorizer + tfidf transformer)
441441
train_data = iter(ALL_FOOD_DOCS[:-1])
442442
tv = TfidfVectorizer(norm='l1')
443-
assert_false(tv.fixed_vocabulary)
444443

445444
tv.max_df = v1.max_df
446445
tfidf2 = tv.fit_transform(train_data).toarray()
446+
assert_false(tv.fixed_vocabulary)
447447
assert_array_almost_equal(tfidf, tfidf2)
448448

449449
# test the direct tfidf vectorizer with new data
@@ -824,7 +824,6 @@ def test_tfidf_vectorizer_with_fixed_vocabulary():
824824
# non regression smoke test for inheritance issues
825825
vocabulary = ['pizza', 'celeri']
826826
vect = TfidfVectorizer(vocabulary=vocabulary)
827-
assert_true(vect.fixed_vocabulary)
828827
X_1 = vect.fit_transform(ALL_FOOD_DOCS)
829828
X_2 = vect.transform(ALL_FOOD_DOCS)
830829
assert_array_almost_equal(X_1.toarray(), X_2.toarray())
@@ -870,7 +869,8 @@ def test_pickling_transformer():
870869

871870
def test_non_unique_vocab():
872871
vocab = ['a', 'b', 'c', 'a', 'a']
873-
assert_raises(ValueError, CountVectorizer, vocabulary=vocab)
872+
vect = CountVectorizer(vocabulary=vocab)
873+
assert_raises(ValueError, vect.fit, [])
874874

875875

876876
def test_hashingvectorizer_nan_in_docs():
@@ -901,3 +901,11 @@ def test_tfidfvectorizer_export_idf():
901901
vect = TfidfVectorizer(use_idf=True)
902902
vect.fit(JUNK_FOOD_DOCS)
903903
assert_array_almost_equal(vect.idf_, vect._tfidf.idf_)
904+
905+
906+
def test_vectorizer_vocab_clone():
907+
vect_vocab = TfidfVectorizer(vocabulary=["the"])
908+
vect_vocab_clone = clone(vect_vocab)
909+
vect_vocab.fit(ALL_FOOD_DOCS)
910+
vect_vocab_clone.fit(ALL_FOOD_DOCS)
911+
assert_equal(vect_vocab_clone.vocabulary_, vect_vocab.vocabulary_)

sklearn/feature_extraction/text.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,32 @@ def build_analyzer(self):
236236
raise ValueError('%s is not a valid tokenization scheme/analyzer' %
237237
self.analyzer)
238238

239+
def _check_vocabulary(self):
240+
vocabulary = self.vocabulary
241+
if vocabulary is not None:
242+
if not isinstance(vocabulary, Mapping):
243+
vocab = {}
244+
for i, t in enumerate(vocabulary):
245+
if vocab.setdefault(t, i) != i:
246+
msg = "Duplicate term in vocabulary: %r" % t
247+
raise ValueError(msg)
248+
vocabulary = vocab
249+
else:
250+
indices = set(six.itervalues(vocabulary))
251+
if len(indices) != len(vocabulary):
252+
raise ValueError("Vocabulary contains repeated indices.")
253+
for i in xrange(len(vocabulary)):
254+
if i not in indices:
255+
msg = ("Vocabulary of size %d doesn't contain index "
256+
"%d." % (len(vocabulary), i))
257+
raise ValueError(msg)
258+
if not vocabulary:
259+
raise ValueError("empty vocabulary passed to fit")
260+
self.fixed_vocabulary = True
261+
self.vocabulary_ = dict(vocabulary)
262+
else:
263+
self.fixed_vocabulary = False
264+
239265

240266
class HashingVectorizer(BaseEstimator, VectorizerMixin):
241267
"""Convert a collection of text documents to a matrix of token occurrences
@@ -616,29 +642,7 @@ def __init__(self, input='content', encoding='utf-8',
616642
"max_features=%r, neither a positive integer nor None"
617643
% max_features)
618644
self.ngram_range = ngram_range
619-
if vocabulary is not None:
620-
if not isinstance(vocabulary, Mapping):
621-
vocab = {}
622-
for i, t in enumerate(vocabulary):
623-
if vocab.setdefault(t, i) != i:
624-
msg = "Duplicate term in vocabulary: %r" % t
625-
raise ValueError(msg)
626-
vocabulary = vocab
627-
else:
628-
indices = set(six.itervalues(vocabulary))
629-
if len(indices) != len(vocabulary):
630-
raise ValueError("Vocabulary contains repeated indices.")
631-
for i in xrange(len(vocabulary)):
632-
if i not in indices:
633-
msg = ("Vocabulary of size %d doesn't contain index "
634-
"%d." % (len(vocabulary), i))
635-
raise ValueError(msg)
636-
if not vocabulary:
637-
raise ValueError("empty vocabulary passed to fit")
638-
self.fixed_vocabulary = True
639-
self.vocabulary_ = dict(vocabulary)
640-
else:
641-
self.fixed_vocabulary = False
645+
self.vocabulary = vocabulary
642646
self.binary = binary
643647
self.dtype = dtype
644648

@@ -773,6 +777,7 @@ def fit_transform(self, raw_documents, y=None):
773777
# We intentionally don't call the transform method to make
774778
# fit_transform overridable without unwanted side effects in
775779
# TfidfVectorizer.
780+
self._check_vocabulary()
776781
max_df = self.max_df
777782
min_df = self.min_df
778783
max_features = self.max_features
@@ -820,6 +825,9 @@ def transform(self, raw_documents):
820825
X : sparse matrix, [n_samples, n_features]
821826
Document-term matrix.
822827
"""
828+
if not hasattr(self, 'vocabulary_'):
829+
self._check_vocabulary()
830+
823831
if not hasattr(self, 'vocabulary_') or len(self.vocabulary_) == 0:
824832
raise ValueError("Vocabulary wasn't fitted or is empty!")
825833

0 commit comments

Comments (0)