TST improve feature_extraction.text coverage · deepatdotnet/scikit-learn@446e675
Commit 446e675

rlmv authored and larsmans committed
TST improve feature_extraction.text coverage
1 parent 096610e · commit 446e675

File tree: 1 file changed (+51, -11 lines)

sklearn/feature_extraction/tests/test_text.py

Lines changed: 51 additions & 11 deletions
@@ -2,7 +2,6 @@
 from sklearn.feature_extraction.text import strip_tags
 from sklearn.feature_extraction.text import strip_accents_unicode
 from sklearn.feature_extraction.text import strip_accents_ascii
-from sklearn.feature_extraction.text import _check_stop_list

 from sklearn.feature_extraction.text import HashingVectorizer
 from sklearn.feature_extraction.text import CountVectorizer
@@ -177,14 +176,6 @@ def test_unicode_decode_error():
     assert_raises(UnicodeDecodeError, ca, text_bytes)


-def test_check_stop_list():
-    assert_equal(_check_stop_list('english'), ENGLISH_STOP_WORDS)
-    assert_raises(ValueError, _check_stop_list, 'bad_str_stop')
-    assert_raises(ValueError, _check_stop_list, u'bad_unicode_stop')
-    stoplist = ['some', 'other', 'words']
-    assert_equal(_check_stop_list(stoplist), stoplist)
-
-
 def test_char_ngram_analyzer():
     cnga = CountVectorizer(analyzer='char', strip_accents='unicode',
                            ngram_range=(3, 6)).build_analyzer()
@@ -255,6 +246,19 @@ def test_countvectorizer_custom_vocabulary_pipeline():
     assert_equal(X.shape[1], len(what_we_like))


+def test_countvectorizer_stop_words():
+    cv = CountVectorizer()
+    cv.set_params(stop_words='english')
+    assert_equal(cv.get_stop_words(), ENGLISH_STOP_WORDS)
+    cv.set_params(stop_words='_bad_str_stop_')
+    assert_raises(ValueError, cv.get_stop_words)
+    cv.set_params(stop_words=u'_bad_unicode_stop_')
+    assert_raises(ValueError, cv.get_stop_words)
+    stoplist = ['some', 'other', 'words']
+    cv.set_params(stop_words=stoplist)
+    assert_equal(cv.get_stop_words(), stoplist)
+
+
 def test_countvectorizer_empty_vocabulary():
     try:
         CountVectorizer(vocabulary=[])
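The replacement test exercises the public API instead of the private _check_stop_list helper: CountVectorizer.get_stop_words() resolves the stop_words parameter each time it is called, so an invalid value is only rejected at that point, not when set_params accepts it. A minimal sketch of that behaviour (the bad value below is a made-up placeholder):

    from sklearn.feature_extraction.text import (CountVectorizer,
                                                 ENGLISH_STOP_WORDS)

    cv = CountVectorizer(stop_words='english')
    assert cv.get_stop_words() == ENGLISH_STOP_WORDS  # the built-in list

    cv.set_params(stop_words='_no_such_language_')    # accepted silently...
    try:
        cv.get_stop_words()                           # ...rejected here
    except ValueError:
        pass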
@@ -400,11 +404,19 @@ def test_vectorizer():
     t2 = TfidfTransformer(norm='l1', use_idf=False)
     tf = t2.fit(counts_train).transform(counts_train).toarray()
     assert_equal(t2.idf_, None)
-
+
     # test idf transform with unlearned idf vector
     t3 = TfidfTransformer(use_idf=True)
     assert_raises(ValueError, t3.transform, counts_train)

+    # test idf transform with incompatible n_features
+    X = [[1, 1, 5],
+         [1, 1, 0]]
+    t3.fit(X)
+    X_incompt = [[1, 3],
+                 [1, 3]]
+    assert_raises(ValueError, t3.transform, X_incompt)
+
     # L1-normalized term frequencies sum to one
     assert_array_almost_equal(np.sum(tf, axis=1), [1.0] * n_train)

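The incompatible-n_features case pins down a simple contract: once fitted, TfidfTransformer has learned one idf weight per column, so transform must receive matrices with the same number of columns as fit did. A minimal sketch of the failure mode, using only the public API:

    from sklearn.feature_extraction.text import TfidfTransformer

    t = TfidfTransformer(use_idf=True)
    t.fit([[1, 1, 5],
           [1, 1, 0]])            # idf_ learned for 3 features
    try:
        t.transform([[1, 3],
                     [1, 3]])     # only 2 features
    except ValueError:
        pass                      # the mismatch is rejected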
@@ -426,6 +438,31 @@ def test_vectorizer():
     v3 = CountVectorizer(vocabulary=None)
     assert_raises(ValueError, v3.transform, train_data)

+    # ascii preprocessor?
+    v3.set_params(strip_accents='ascii', lowercase=False)
+    assert_equal(v3.build_preprocessor(), strip_accents_ascii)
+
+    # error on bad strip_accents param
+    v3.set_params(strip_accents='_gabbledegook_', preprocessor=None)
+    assert_raises(ValueError, v3.build_preprocessor)
+
+    # error with bad analyzer type
+    v3.set_params = '_invalid_analyzer_type_'
+    assert_raises(ValueError, v3.build_analyzer)
+
+
+def test_tfidf_vectorizer_setters():
+    tv = TfidfVectorizer(norm='l2', use_idf=False,
+                         smooth_idf=False, sublinear_tf=False)
+    tv.norm = 'l1'
+    assert_equal(tv._tfidf.norm, 'l1')
+    tv.use_idf = True
+    assert_true(tv._tfidf.use_idf)
+    tv.smooth_idf = True
+    assert_true(tv._tfidf.smooth_idf)
+    tv.sublinear_tf = True
+    assert_true(tv._tfidf.sublinear_tf)
+

 def test_hashing_vectorizer():
     v = HashingVectorizer()
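Two notes on this hunk. First, the bad-analyzer line assigns a string to v3.set_params, rebinding the method rather than calling it; presumably v3.set_params(analyzer='_invalid_analyzer_type_') was intended. The assertion can still pass because build_analyzer calls build_preprocessor first, which rejects the still-invalid strip_accents value set a few lines earlier. Second, the new setters test relies on TfidfVectorizer forwarding norm, use_idf, smooth_idf and sublinear_tf through property setters to the wrapped TfidfTransformer held in the private _tfidf attribute. A minimal sketch of that round trip, valid for this era of the code base (later scikit-learn releases construct the inner transformer lazily during fit, so _tfidf is not stable API):

    from sklearn.feature_extraction.text import TfidfVectorizer

    tv = TfidfVectorizer(use_idf=False)
    tv.use_idf = True           # property setter writes through...
    assert tv._tfidf.use_idf    # ...to the wrapped TfidfTransformer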
@@ -467,8 +504,11 @@ def test_hashing_vectorizer():

 def test_feature_names():
     cv = CountVectorizer(max_df=0.5, min_df=1)
-    X = cv.fit_transform(ALL_FOOD_DOCS)

+    # test for ValueError on unfitted/empty vocabulary
+    assert_raises(ValueError, cv.get_feature_names)
+
+    X = cv.fit_transform(ALL_FOOD_DOCS)
     n_samples, n_features = X.shape
     assert_equal(len(cv.vocabulary_), n_features)

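The new assertion documents that asking a vectorizer for its feature names before any vocabulary exists raises ValueError rather than returning an empty list. A minimal sketch (the document strings are illustrative; get_feature_names was later renamed get_feature_names_out in scikit-learn 1.0):

    from sklearn.feature_extraction.text import CountVectorizer

    cv = CountVectorizer()
    try:
        cv.get_feature_names_out()     # unfitted: no vocabulary_ yet
    except ValueError:
        pass                           # NotFittedError subclasses ValueError

    cv.fit(['salad pizza', 'pizza burger'])
    print(cv.get_feature_names_out())  # the sorted vocabulary terms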