[MRG+2] Bugfix: Clone-safe vectorizers with custom vocabulary by vene · Pull Request #3552 · scikit-learn/scikit-learn

Closed · wants to merge 2 commits
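The fix stores the user-supplied `vocabulary` parameter untouched in `__init__` and defers validation to a new `_check_vocabulary()` helper that runs at fit/transform time, so `clone()` (which rebuilds an estimator from its constructor parameters) produces a working copy. A minimal sketch of the behavior the new `test_vectorizer_vocab_clone` test exercises; the example documents below are made up:

```python
from sklearn.base import clone
from sklearn.feature_extraction.text import TfidfVectorizer

# Hypothetical documents standing in for ALL_FOOD_DOCS in the test suite.
docs = ["the pizza is good", "the celeri is not"]

# Vectorizer with a user-supplied (fixed) vocabulary.
vect = TfidfVectorizer(vocabulary=["the", "pizza", "celeri"])

# clone() re-creates the estimator from get_params(); since validation now
# happens at fit/transform time rather than in __init__, the clone keeps the
# raw `vocabulary` parameter and fits exactly like the original.
vect_clone = clone(vect)

vect.fit(docs)
vect_clone.fit(docs)

assert vect.vocabulary_ == vect_clone.vocabulary_
assert vect.fixed_vocabulary_ and vect_clone.fixed_vocabulary_
```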
23 changes: 16 additions & 7 deletions sklearn/feature_extraction/tests/test_text.py
@@ -18,6 +18,7 @@
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

from sklearn.base import clone

import numpy as np
from nose import SkipTest
@@ -283,7 +284,8 @@ def test_countvectorizer_stop_words():

def test_countvectorizer_empty_vocabulary():
try:
CountVectorizer(vocabulary=[])
vect = CountVectorizer(vocabulary=[])
vect.fit(["foo"])
assert False, "we shouldn't get here"
except ValueError as e:
assert_in("empty vocabulary", str(e).lower())
@@ -440,10 +442,10 @@ def test_vectorizer():
# (equivalent to term count vectorizer + tfidf transformer)
train_data = iter(ALL_FOOD_DOCS[:-1])
tv = TfidfVectorizer(norm='l1')
assert_false(tv.fixed_vocabulary)

tv.max_df = v1.max_df
tfidf2 = tv.fit_transform(train_data).toarray()
assert_false(tv.fixed_vocabulary_)
assert_array_almost_equal(tfidf, tfidf2)

# test the direct tfidf vectorizer with new data
@@ -767,7 +769,7 @@ def test_vectorizer_pipeline_grid_selection():
best_vectorizer = grid_search.best_estimator_.named_steps['vect']
assert_equal(best_vectorizer.ngram_range, (1, 1))
assert_equal(best_vectorizer.norm, 'l2')
assert_false(best_vectorizer.fixed_vocabulary)
assert_false(best_vectorizer.fixed_vocabulary_)


def test_vectorizer_pipeline_cross_validation():
@@ -777,7 +779,6 @@ def test_vectorizer_pipeline_cross_validation():
# label junk food as -1, the others as +1
target = [-1] * len(JUNK_FOOD_DOCS) + [1] * len(NOTJUNK_FOOD_DOCS)


pipeline = Pipeline([('vect', TfidfVectorizer()),
('svc', LinearSVC())])

@@ -824,11 +825,10 @@ def test_tfidf_vectorizer_with_fixed_vocabulary():
# non regression smoke test for inheritance issues
vocabulary = ['pizza', 'celeri']
vect = TfidfVectorizer(vocabulary=vocabulary)
assert_true(vect.fixed_vocabulary)
X_1 = vect.fit_transform(ALL_FOOD_DOCS)
X_2 = vect.transform(ALL_FOOD_DOCS)
assert_array_almost_equal(X_1.toarray(), X_2.toarray())
assert_true(vect.fixed_vocabulary)
assert_true(vect.fixed_vocabulary_)


def test_pickling_vectorizer():
@@ -870,7 +870,8 @@ def test_pickling_transformer():

def test_non_unique_vocab():
vocab = ['a', 'b', 'c', 'a', 'a']
assert_raises(ValueError, CountVectorizer, vocabulary=vocab)
vect = CountVectorizer(vocabulary=vocab)
assert_raises(ValueError, vect.fit, [])


def test_hashingvectorizer_nan_in_docs():
@@ -901,3 +902,11 @@ def test_tfidfvectorizer_export_idf():
vect = TfidfVectorizer(use_idf=True)
vect.fit(JUNK_FOOD_DOCS)
assert_array_almost_equal(vect.idf_, vect._tfidf.idf_)


def test_vectorizer_vocab_clone():
vect_vocab = TfidfVectorizer(vocabulary=["the"])
vect_vocab_clone = clone(vect_vocab)
vect_vocab.fit(ALL_FOOD_DOCS)
vect_vocab_clone.fit(ALL_FOOD_DOCS)
assert_equal(vect_vocab_clone.vocabulary_, vect_vocab.vocabulary_)
68 changes: 42 additions & 26 deletions sklearn/feature_extraction/text.py
@@ -28,7 +28,8 @@
from ..preprocessing import normalize
from .hashing import FeatureHasher
from .stop_words import ENGLISH_STOP_WORDS
from sklearn.externals import six
from ..utils import deprecated
from ..externals import six

__all__ = ['CountVectorizer',
'ENGLISH_STOP_WORDS',
@@ -236,6 +237,38 @@ def build_analyzer(self):
raise ValueError('%s is not a valid tokenization scheme/analyzer' %
self.analyzer)

def _check_vocabulary(self):
vocabulary = self.vocabulary
if vocabulary is not None:
if not isinstance(vocabulary, Mapping):
vocab = {}
for i, t in enumerate(vocabulary):
if vocab.setdefault(t, i) != i:
msg = "Duplicate term in vocabulary: %r" % t
raise ValueError(msg)
vocabulary = vocab
else:
indices = set(six.itervalues(vocabulary))
if len(indices) != len(vocabulary):
raise ValueError("Vocabulary contains repeated indices.")
for i in xrange(len(vocabulary)):
if i not in indices:
msg = ("Vocabulary of size %d doesn't contain index "
"%d." % (len(vocabulary), i))
raise ValueError(msg)
if not vocabulary:
raise ValueError("empty vocabulary passed to fit")
self.fixed_vocabulary_ = True
self.vocabulary_ = dict(vocabulary)
else:
self.fixed_vocabulary_ = False

@property
@deprecated("The `fixed_vocabulary` attribute is deprecated and will be "
"removed in 0.18. Please use `fixed_vocabulary_` instead.")
def fixed_vocabulary(self):
return self.fixed_vocabulary_


class HashingVectorizer(BaseEstimator, VectorizerMixin):
"""Convert a collection of text documents to a matrix of token occurrences
@@ -616,29 +649,7 @@ def __init__(self, input='content', encoding='utf-8',
"max_features=%r, neither a positive integer nor None"
% max_features)
self.ngram_range = ngram_range
if vocabulary is not None:
if not isinstance(vocabulary, Mapping):
vocab = {}
for i, t in enumerate(vocabulary):
if vocab.setdefault(t, i) != i:
msg = "Duplicate term in vocabulary: %r" % t
raise ValueError(msg)
vocabulary = vocab
else:
indices = set(six.itervalues(vocabulary))
if len(indices) != len(vocabulary):
raise ValueError("Vocabulary contains repeated indices.")
for i in xrange(len(vocabulary)):
if i not in indices:
msg = ("Vocabulary of size %d doesn't contain index "
"%d." % (len(vocabulary), i))
raise ValueError(msg)
if not vocabulary:
raise ValueError("empty vocabulary passed to fit")
self.fixed_vocabulary = True
self.vocabulary_ = dict(vocabulary)
else:
self.fixed_vocabulary = False
self.vocabulary = vocabulary
self.binary = binary
self.dtype = dtype

@@ -773,16 +784,18 @@ def fit_transform(self, raw_documents, y=None):
# We intentionally don't call the transform method to make
# fit_transform overridable without unwanted side effects in
# TfidfVectorizer.
self._check_vocabulary()
max_df = self.max_df
min_df = self.min_df
max_features = self.max_features

vocabulary, X = self._count_vocab(raw_documents, self.fixed_vocabulary)
vocabulary, X = self._count_vocab(raw_documents,
self.fixed_vocabulary_)

if self.binary:
X.data.fill(1)

if not self.fixed_vocabulary:
if not self.fixed_vocabulary_:
X = self._sort_features(X, vocabulary)

n_doc = X.shape[0]
@@ -820,6 +833,9 @@ def transform(self, raw_documents):
X : sparse matrix, [n_samples, n_features]
Document-term matrix.
"""
if not hasattr(self, 'vocabulary_'):
self._check_vocabulary()
Inline review comment from the author (vene, Member) on the added lazy check:
This is needed to make an existing test case pass. The test case calls transform without calling fit. I think it's cleaner and more consistent not to allow such a call.


if not hasattr(self, 'vocabulary_') or len(self.vocabulary_) == 0:
raise ValueError("Vocabulary wasn't fitted or is empty!")
