FIX CountVectorizer: check upper case in vocab only in fit (#21251) · scikit-learn/scikit-learn@ff00693 · GitHub
[go: up one dir, main page]

Skip to content

Commit ff00693

Browse files
jeremiedbb, glemaitre
authored and committed
FIX CountVectorizer: check upper case in vocab only in fit (#21251)
1 parent 7d74e93 commit ff00693

File tree

3 files changed

+28
-13
lines changed

3 files changed

+28
-13
lines changed

doc/whats_new/v1.0.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ Fixed models
4343
between sparse and dense input. :pr:`21195`
4444
by :user:`Jérémie du Boisberranger <jeremiedbb>`.
4545

46+
:mod:`sklearn.feature_extraction`
47+
.................................
48+
49+
- |Efficiency| Fixed an efficiency regression introduced in version 1.0.0 in the
50+
`transform` method of :class:`feature_extraction.text.CountVectorizer` which no
51+
longer checks for uppercase characters in the provided vocabulary. :pr:`21251`
52+
by :user:`Jérémie du Boisberranger <jeremiedbb>`.
53+
4654
:mod:`sklearn.linear_model`
4755
...........................
4856

sklearn/feature_extraction/tests/test_text.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,9 @@ def test_countvectorizer_custom_token_pattern_with_several_group():
436436

437437

438438
def test_countvectorizer_uppercase_in_vocab():
439-
vocabulary = ["Sample", "Upper", "CaseVocabulary"]
439+
# Check that the check for uppercase in the provided vocabulary is only done at fit
440+
# time and not at transform time (#21251)
441+
vocabulary = ["Sample", "Upper", "Case", "Vocabulary"]
440442
message = (
441443
"Upper case characters found in"
442444
" vocabulary while 'lowercase'"
@@ -445,8 +447,13 @@ def test_countvectorizer_uppercase_in_vocab():
445447
)
446448

447449
vectorizer = CountVectorizer(lowercase=True, vocabulary=vocabulary)
450+
448451
with pytest.warns(UserWarning, match=message):
449-
vectorizer.fit_transform(vocabulary)
452+
vectorizer.fit(vocabulary)
453+
454+
with pytest.warns(None) as record:
455+
vectorizer.transform(vocabulary)
456+
assert not record
450457

451458

452459
def test_tf_transformer_feature_names_out():

sklearn/feature_extraction/text.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,17 +1194,6 @@ def _count_vocab(self, raw_documents, fixed_vocab):
11941194
j_indices = []
11951195
indptr = []
11961196

1197-
if self.lowercase:
1198-
for vocab in vocabulary:
1199-
if any(map(str.isupper, vocab)):
1200-
warnings.warn(
1201-
"Upper case characters found in"
1202-
" vocabulary while 'lowercase'"
1203-
" is True. These entries will not"
1204-
" be matched with any documents"
1205-
)
1206-
break
1207-
12081197
values = _make_int_array()
12091198
indptr.append(0)
12101199
for doc in raw_documents:
@@ -1327,6 +1316,17 @@ def fit_transform(self, raw_documents, y=None):
13271316
min_df = self.min_df
13281317
max_features = self.max_features
13291318

1319+
if self.fixed_vocabulary_ and self.lowercase:
1320+
for term in self.vocabulary:
1321+
if any(map(str.isupper, term)):
1322+
warnings.warn(
1323+
"Upper case characters found in"
1324+
" vocabulary while 'lowercase'"
1325+
" is True. These entries will not"
1326+
" be matched with any documents"
1327+
)
1328+
break
1329+
13301330
vocabulary, X = self._count_vocab(raw_documents, self.fixed_vocabulary_)
13311331

13321332
if self.binary:

0 commit comments

Comments
 (0)
0