From b176075ad418b777dfcf1cf97fad226259a03238 Mon Sep 17 00:00:00 2001
From: Vlad Niculae
Date: Tue, 12 Aug 2014 14:44:20 +0200
Subject: [PATCH 1/2] FIX set vectorizer vocabulary outside of init

---
 sklearn/feature_extraction/tests/test_text.py | 19 +++++--
 sklearn/feature_extraction/text.py            | 54 +++++++++++--------
 2 files changed, 45 insertions(+), 28 deletions(-)

diff --git a/sklearn/feature_extraction/tests/test_text.py b/sklearn/feature_extraction/tests/test_text.py
index c20c849689f02..e326c79d405dd 100644
--- a/sklearn/feature_extraction/tests/test_text.py
+++ b/sklearn/feature_extraction/tests/test_text.py
@@ -18,6 +18,7 @@
 from sklearn.pipeline import Pipeline
 from sklearn.svm import LinearSVC
+from sklearn.base import clone
 
 import numpy as np
 from nose import SkipTest
@@ -283,7 +284,8 @@ def test_countvectorizer_stop_words():
 
 def test_countvectorizer_empty_vocabulary():
     try:
-        CountVectorizer(vocabulary=[])
+        vect = CountVectorizer(vocabulary=[])
+        vect.fit(["foo"])
         assert False, "we shouldn't get here"
     except ValueError as e:
         assert_in("empty vocabulary", str(e).lower())
@@ -440,10 +442,10 @@ def test_vectorizer():
     # (equivalent to term count vectorizer + tfidf transformer)
     train_data = iter(ALL_FOOD_DOCS[:-1])
     tv = TfidfVectorizer(norm='l1')
-    assert_false(tv.fixed_vocabulary)
     tv.max_df = v1.max_df
 
     tfidf2 = tv.fit_transform(train_data).toarray()
+    assert_false(tv.fixed_vocabulary)
     assert_array_almost_equal(tfidf, tfidf2)
 
     # test the direct tfidf vectorizer with new data
@@ -777,7 +779,6 @@ def test_vectorizer_pipeline_cross_validation():
     # label junk food as -1, the others as +1
     target = [-1] * len(JUNK_FOOD_DOCS) + [1] * len(NOTJUNK_FOOD_DOCS)
 
-
     pipeline = Pipeline([('vect', TfidfVectorizer()),
                          ('svc', LinearSVC())])
@@ -824,7 +825,6 @@ def test_tfidf_vectorizer_with_fixed_vocabulary():
     # non regression smoke test for inheritance issues
     vocabulary = ['pizza', 'celeri']
     vect = TfidfVectorizer(vocabulary=vocabulary)
-    assert_true(vect.fixed_vocabulary)
     X_1 = vect.fit_transform(ALL_FOOD_DOCS)
     X_2 = vect.transform(ALL_FOOD_DOCS)
     assert_array_almost_equal(X_1.toarray(), X_2.toarray())
@@ -870,7 +870,8 @@ def test_pickling_transformer():
 
 def test_non_unique_vocab():
     vocab = ['a', 'b', 'c', 'a', 'a']
-    assert_raises(ValueError, CountVectorizer, vocabulary=vocab)
+    vect = CountVectorizer(vocabulary=vocab)
+    assert_raises(ValueError, vect.fit, [])
 
 
 def test_hashingvectorizer_nan_in_docs():
@@ -901,3 +902,11 @@ def test_tfidfvectorizer_export_idf():
     vect = TfidfVectorizer(use_idf=True)
     vect.fit(JUNK_FOOD_DOCS)
     assert_array_almost_equal(vect.idf_, vect._tfidf.idf_)
+
+
+def test_vectorizer_vocab_clone():
+    vect_vocab = TfidfVectorizer(vocabulary=["the"])
+    vect_vocab_clone = clone(vect_vocab)
+    vect_vocab.fit(ALL_FOOD_DOCS)
+    vect_vocab_clone.fit(ALL_FOOD_DOCS)
+    assert_equal(vect_vocab_clone.vocabulary_, vect_vocab.vocabulary_)
diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py
index d5590f0604067..c2f4a1f095f2d 100644
--- a/sklearn/feature_extraction/text.py
+++ b/sklearn/feature_extraction/text.py
@@ -236,6 +236,32 @@ def build_analyzer(self):
         raise ValueError('%s is not a valid tokenization scheme/analyzer' %
                          self.analyzer)
 
+    def _check_vocabulary(self):
+        vocabulary = self.vocabulary
+        if vocabulary is not None:
+            if not isinstance(vocabulary, Mapping):
+                vocab = {}
+                for i, t in enumerate(vocabulary):
+                    if vocab.setdefault(t, i) != i:
+                        msg = "Duplicate term in vocabulary: %r" % t
+                        raise ValueError(msg)
+                vocabulary = vocab
+            else:
+                indices = set(six.itervalues(vocabulary))
+                if len(indices) != len(vocabulary):
+                    raise ValueError("Vocabulary contains repeated indices.")
+                for i in xrange(len(vocabulary)):
+                    if i not in indices:
+                        msg = ("Vocabulary of size %d doesn't contain index "
+                               "%d." % (len(vocabulary), i))
+                        raise ValueError(msg)
+            if not vocabulary:
+                raise ValueError("empty vocabulary passed to fit")
+            self.fixed_vocabulary = True
+            self.vocabulary_ = dict(vocabulary)
+        else:
+            self.fixed_vocabulary = False
+
 
 class HashingVectorizer(BaseEstimator, VectorizerMixin):
     """Convert a collection of text documents to a matrix of token occurrences
@@ -616,29 +642,7 @@ def __init__(self, input='content', encoding='utf-8',
                 "max_features=%r, neither a positive integer nor None"
                 % max_features)
         self.ngram_range = ngram_range
-        if vocabulary is not None:
-            if not isinstance(vocabulary, Mapping):
-                vocab = {}
-                for i, t in enumerate(vocabulary):
-                    if vocab.setdefault(t, i) != i:
-                        msg = "Duplicate term in vocabulary: %r" % t
-                        raise ValueError(msg)
-                vocabulary = vocab
-            else:
-                indices = set(six.itervalues(vocabulary))
-                if len(indices) != len(vocabulary):
-                    raise ValueError("Vocabulary contains repeated indices.")
-                for i in xrange(len(vocabulary)):
-                    if i not in indices:
-                        msg = ("Vocabulary of size %d doesn't contain index "
-                               "%d." % (len(vocabulary), i))
-                        raise ValueError(msg)
-            if not vocabulary:
-                raise ValueError("empty vocabulary passed to fit")
-            self.fixed_vocabulary = True
-            self.vocabulary_ = dict(vocabulary)
-        else:
-            self.fixed_vocabulary = False
+        self.vocabulary = vocabulary
         self.binary = binary
         self.dtype = dtype
@@ -773,6 +777,7 @@ def fit_transform(self, raw_documents, y=None):
         # We intentionally don't call the transform method to make
         # fit_transform overridable without unwanted side effects in
         # TfidfVectorizer.
+        self._check_vocabulary()
         max_df = self.max_df
         min_df = self.min_df
         max_features = self.max_features
@@ -820,6 +825,9 @@ def transform(self, raw_documents):
         X : sparse matrix, [n_samples, n_features]
             Document-term matrix.
         """
+        if not hasattr(self, 'vocabulary_'):
+            self._check_vocabulary()
+
         if not hasattr(self, 'vocabulary_') or len(self.vocabulary_) == 0:
             raise ValueError("Vocabulary wasn't fitted or is empty!")

From 0aa278ef35db597ce64d297296d2dde0e224d477 Mon Sep 17 00:00:00 2001
From: Vlad Niculae
Date: Tue, 12 Aug 2014 22:58:36 +0200
Subject: [PATCH 2/2] Deprecate vectorizer fixed_vocabulary attribute

---
 sklearn/feature_extraction/tests/test_text.py |  6 +++---
 sklearn/feature_extraction/text.py            | 18 +++++++++++++-----
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/sklearn/feature_extraction/tests/test_text.py b/sklearn/feature_extraction/tests/test_text.py
index e326c79d405dd..0b3aa8f5ea7fb 100644
--- a/sklearn/feature_extraction/tests/test_text.py
+++ b/sklearn/feature_extraction/tests/test_text.py
@@ -445,7 +445,7 @@ def test_vectorizer():
     tv.max_df = v1.max_df
 
     tfidf2 = tv.fit_transform(train_data).toarray()
-    assert_false(tv.fixed_vocabulary)
+    assert_false(tv.fixed_vocabulary_)
     assert_array_almost_equal(tfidf, tfidf2)
 
     # test the direct tfidf vectorizer with new data
@@ -769,7 +769,7 @@ def test_vectorizer_pipeline_grid_selection():
     best_vectorizer = grid_search.best_estimator_.named_steps['vect']
     assert_equal(best_vectorizer.ngram_range, (1, 1))
     assert_equal(best_vectorizer.norm, 'l2')
-    assert_false(best_vectorizer.fixed_vocabulary)
+    assert_false(best_vectorizer.fixed_vocabulary_)
 
 
 def test_vectorizer_pipeline_cross_validation():
@@ -828,7 +828,7 @@ def test_tfidf_vectorizer_with_fixed_vocabulary():
     X_1 = vect.fit_transform(ALL_FOOD_DOCS)
     X_2 = vect.transform(ALL_FOOD_DOCS)
     assert_array_almost_equal(X_1.toarray(), X_2.toarray())
-    assert_true(vect.fixed_vocabulary)
+    assert_true(vect.fixed_vocabulary_)
 
 
 def test_pickling_vectorizer():
diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py
index c2f4a1f095f2d..bce99f99eaf59 100644
--- a/sklearn/feature_extraction/text.py
+++ b/sklearn/feature_extraction/text.py
@@ -28,7 +28,8 @@
 from ..preprocessing import normalize
 from .hashing import FeatureHasher
 from .stop_words import ENGLISH_STOP_WORDS
-from sklearn.externals import six
+from ..utils import deprecated
+from ..externals import six
 
 __all__ = ['CountVectorizer',
            'ENGLISH_STOP_WORDS',
@@ -257,10 +258,16 @@ def _check_vocabulary(self):
                         raise ValueError(msg)
             if not vocabulary:
                 raise ValueError("empty vocabulary passed to fit")
-            self.fixed_vocabulary = True
+            self.fixed_vocabulary_ = True
             self.vocabulary_ = dict(vocabulary)
         else:
-            self.fixed_vocabulary = False
+            self.fixed_vocabulary_ = False
+
+    @property
+    @deprecated("The `fixed_vocabulary` attribute is deprecated and will be "
+                "removed in 0.18. Please use `fixed_vocabulary_` instead.")
+    def fixed_vocabulary(self):
+        return self.fixed_vocabulary_
 
 
 class HashingVectorizer(BaseEstimator, VectorizerMixin):
@@ -782,12 +789,13 @@ def fit_transform(self, raw_documents, y=None):
         min_df = self.min_df
         max_features = self.max_features
 
-        vocabulary, X = self._count_vocab(raw_documents, self.fixed_vocabulary)
+        vocabulary, X = self._count_vocab(raw_documents,
+                                          self.fixed_vocabulary_)
 
         if self.binary:
             X.data.fill(1)
 
-        if not self.fixed_vocabulary:
+        if not self.fixed_vocabulary_:
             X = self._sort_features(X, vocabulary)
 
             n_doc = X.shape[0]