8000 MNT Fix assert raises in sklearn/feature_extraction/tests/ (#14694) · scikit-learn/scikit-learn@a05c8d8 · GitHub
[go: up one dir, main page]

Skip to content

Commit a05c8d8

Browse files
sameshlqinhanmin2014
authored andcommitted
MNT Fix assert raises in sklearn/feature_extraction/tests/ (#14694)
1 parent a2968c2 commit a05c8d8

File tree

3 files changed

+53
-29
lines changed

3 files changed

+53
-29
lines changed

sklearn/feature_extraction/tests/test_feature_hasher.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11

22
import numpy as np
33
from numpy.testing import assert_array_equal
4+
import pytest
45

56
from sklearn.feature_extraction import FeatureHasher
6-
from sklearn.utils.testing import (assert_raises, ignore_warnings,
7+
from sklearn.utils.testing import (ignore_warnings,
78
fails_if_pypy)
89

910
pytestmark = fails_if_pypy
@@ -86,22 +87,30 @@ def test_hash_empty_input():
8687

8788

8889
def test_hasher_invalid_input():
89-
assert_raises(ValueError, FeatureHasher, input_type="gobbledygook")
90-
assert_raises(ValueError, FeatureHasher, n_features=-1)
91-
assert_raises(ValueError, FeatureHasher, n_features=0)
92-
assert_raises(TypeError, FeatureHasher, n_features='ham')
90+
with pytest.raises(ValueError):
91+
FeatureHasher(input_type="gobbledygook")
92+
with pytest.raises(ValueError):
93+
FeatureHasher(n_features=-1)
94+
with pytest.raises(ValueError):
95+
FeatureHasher(n_features=0)
96+
with pytest.raises(TypeError):
97+
FeatureHasher(n_features='ham')
9398

9499
h = FeatureHasher(n_features=np.uint16(2 ** 6))
95-
assert_raises(ValueError, h.transform, [])
96-
assert_raises(Exception, h.transform, [[5.5]])
97-
assert_raises(Exception, h.transform, [[None]])
100+
with pyt 10000 est.raises(ValueError):
101+
h.transform([])
102+
with pytest.raises(Exception):
103+
h.transform([[5.5]])
104+
with pytest.raises(Exception):
105+
h.transform([[None]])
98106

99107

100108
def test_hasher_set_params():
101109
# Test delayed input validation in fit (useful for grid search).
102110
hasher = FeatureHasher()
103111
hasher.set_params(n_features=np.inf)
104-
assert_raises(TypeError, hasher.fit)
112+
with pytest.raises(TypeError):
113+
hasher.fit()
105114

106115

107116
def test_hasher_zeros():

sklearn/feature_extraction/tests/test_image.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
import scipy as sp
77
from scipy import ndimage
88
from scipy.sparse.csgraph import connected_components
9+
import pytest
910

1011
from sklearn.feature_extraction.image import (
1112
img_to_graph, grid_to_graph, extract_patches_2d,
1213
reconstruct_from_patches_2d, PatchExtractor, extract_patches)
13-
from sklearn.utils.testing import assert_raises, ignore_warnings
14+
from sklearn.utils.testing import ignore_warnings
1415

1516

1617
def test_img_to_graph():
@@ -172,10 +173,10 @@ def test_extract_patches_max_patches():
172173
patches = extract_patches_2d(face, (p_h, p_w), max_patches=0.5)
173174
assert patches.shape == (expected_n_patches, p_h, p_w)
174175

175-
assert_raises(ValueError, extract_patches_2d, face, (p_h, p_w),
176-
max_patches=2.0)
177-
assert_raises(ValueError, extract_patches_2d, face, (p_h, p_w),
178-
max_patches=-1.0)
176+
with pytest.raises(ValueError):
177+
extract_patches_2d(face, (p_h, p_w), max_patches=2.0)
178+
with pytest.raises(ValueError):
179+
extract_patches_2d(face, (p_h, p_w), max_patches=-1.0)
179180

180181

181182
def test_extract_patch_same_size_image():
@@ -328,5 +329,7 @@ def test_extract_patches_square():
328329
def test_width_patch():
329330
# width and height of the patch should be less than the image
330331
x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
331-
assert_raises(ValueError, extract_patches_2d, x, (4, 1))
332-
assert_raises(ValueError, extract_patches_2d, x, (1, 4))
332+
with pytest.raises(ValueError):
333+
extract_patches_2d(x, (4, 1))
334+
with pytest.raises(ValueError):
335+
extract_patches_2d(x, (1, 4))

sklearn/feature_extraction/tests/test_text.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from sklearn.utils.testing import (assert_almost_equal,
3434
assert_warns_message, assert_raise_message,
3535
clean_warning_registry,
36-
SkipTest, assert_raises, assert_no_warnings,
36+
SkipTest, assert_no_warnings,
3737
fails_if_pypy, assert_allclose_dense_sparse,
3838
skip_if_32bit)
3939
from collections import defaultdict
@@ -178,11 +178,13 @@ def test_unicode_decode_error():
178178
# Then let the Analyzer try to decode it as ascii. It should fail,
179179
# because we have given it an incorrect encoding.
180180
wa = CountVectorizer(ngram_range=(1, 2), encoding='ascii').build_analyzer()
181-
assert_raises(UnicodeDecodeError, wa, text_bytes)
181+
with pytest.raises(UnicodeDecodeError):
182+
wa(text_bytes)
182183

183184
ca = CountVectorizer(analyzer='char', ngram_range=(3, 6),
184185
encoding='ascii').build_analyzer()
185-
assert_raises(UnicodeDecodeError, ca, text_bytes)
186+
with pytest.raises(UnicodeDecodeError):
187+
ca(text_bytes)
186188

187189

188190
def test_char_ngram_analyzer():
@@ -299,9 +301,11 @@ def test_countvectorizer_stop_words():
299301
cv.set_params(stop_words='english')
300302
assert cv.get_stop_words() == ENGLISH_STOP_WORDS
301303
cv.set_params(stop_words='_bad_str_stop_')
302-
assert_raises(ValueError, cv.get_stop_words)
304+
with pytest.raises(ValueError):
305+
cv.get_stop_words()
303306
cv.set_params(stop_words='_bad_unicode_stop_')
304-
assert_raises(ValueError, cv.get_stop_words)
307+
with pytest.raises(ValueError):
308+
cv.get_stop_words()
305309
stoplist = ['some', 'other', 'words']
306310
cv.set_params(stop_words=stoplist)
307311
assert cv.get_stop_words() == set(stoplist)
@@ -451,15 +455,17 @@ def test_vectorizer():
451455

452456
# test idf transform with unlearned idf vector
453457
t3 = TfidfTransformer(use_idf=True)
454-
assert_raises(ValueError, t3.transform, counts_train)
458+
with pytest.raises(ValueError):
459+
t3.transform(counts_train)
455460

456461
# test idf transform with incompatible n_features
457462
X = [[1, 1, 5],
458463
[1, 1, 0]]
459464
t3.fit(X)
460465
X_incompt = [[1, 3],
461466
[1, 3]]
462-
assert_raises(ValueError, t3.transform, X_incompt)
467+
with pytest.raises(ValueError):
468+
t3.transform(X_incompt)
463469

464470
# L1-normalized term frequencies sum to one
465471
assert_array_almost_equal(np.sum(tf, axis=1), [1.0] * n_train)
@@ -480,7 +486,8 @@ def test_vectorizer():
480486

481487
# test transform on unfitted vectorizer with empty vocabulary
482488
v3 = CountVectorizer(vocabulary=None)
483-
assert_raises(ValueError, v3.transform, train_data)
489+
with pytest.raises(ValueError):
490+
v3.transform(train_data)
484491

485492
# ascii preprocessor?
486493
v3.set_params(strip_accents='ascii', lowercase=False)
@@ -493,11 +500,13 @@ def test_vectorizer():
493500

494501
# error on bad strip_accents param
495502
v3.set_params(strip_accents='_gabbledegook_', preprocessor=None)
496-
assert_raises(ValueError, v3.build_preprocessor)
503+
with pytest.raises(ValueError):
504+
v3.build_preprocessor()
497505

498506
# error with bad analyzer type
499507
v3.set_params = '_invalid_analyzer_type_'
500-
assert_raises(ValueError, v3.build_analyzer)
508+
with pytest.raises(ValueError):
509+
v3.build_analyzer()
501510

502511

503512
def test_tfidf_vectorizer_setters():
@@ -568,7 +577,8 @@ def test_feature_names():
568577
cv = CountVectorizer(max_df=0.5)
569578

570579
# test for Value error on unfitted/empty vocabulary
571-
assert_raises(ValueError, cv.get_feature_names)
580+
with pytest.raises(ValueError):
581+
cv.get_feature_names()
572582
assert not cv.fixed_vocabulary_
573583

574584
# test for vocabulary learned from data
@@ -1014,13 +1024,15 @@ def test_tfidfvectorizer_invalid_idf_attr():
10141024
copy = TfidfVectorizer(vocabulary=vect.vocabulary_, use_idf=True)
10151025
expected_idf_len = len(vect.idf_)
10161026
invalid_idf = [1.0] * (expected_idf_len + 1)
1017-
assert_raises(ValueError, setattr, copy, 'idf_', invalid_idf)
1027+
with pytest.raises(ValueError):
1028+
setattr(copy, 'idf_', invalid_idf)
10181029

10191030

10201031
def test_non_unique_vocab():
10211032
vocab = ['a', 'b', 'c', 'a', 'a']
10221033
vect = CountVectorizer(vocabulary=vocab)
1023-
assert_raises(ValueError, vect.fit, [])
1034+
with pytest.raises(ValueError):
1035+
vect.fit([])
10241036

10251037

10261038
@fails_if_pypy

0 commit comments

Comments
 (0)
0