diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index 40076c0b275dd..21c150e49de37 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -43,6 +43,14 @@ Fixed models between sparse and dense input. :pr:`21195` by :user:`Jérémie du Boisberranger <jeremiedbb>`. +:mod:`sklearn.feature_extraction` +................................. + +- |Efficiency| Fixed an efficiency regression introduced in version 1.0.0 in the + `transform` method of :class:`feature_extraction.text.CountVectorizer` which no + longer checks for uppercase characters in the provided vocabulary. :pr:`21251` + by :user:`Jérémie du Boisberranger <jeremiedbb>`. + :mod:`sklearn.linear_model` ........................... diff --git a/sklearn/feature_extraction/tests/test_text.py b/sklearn/feature_extraction/tests/test_text.py index 6abd731b4559a..da32e855fabb6 100644 --- a/sklearn/feature_extraction/tests/test_text.py +++ b/sklearn/feature_extraction/tests/test_text.py @@ -436,7 +436,9 @@ def test_countvectorizer_custom_token_pattern_with_several_group(): def test_countvectorizer_uppercase_in_vocab(): - vocabulary = ["Sample", "Upper", "CaseVocabulary"] + # Check that the check for uppercase in the provided vocabulary is only done at fit + # time and not at transform time (#21251) + vocabulary = ["Sample", "Upper", "Case", "Vocabulary"] message = ( "Upper case characters found in" " vocabulary while 'lowercase'" @@ -445,8 +447,13 @@ def test_countvectorizer_uppercase_in_vocab(): ) vectorizer = CountVectorizer(lowercase=True, vocabulary=vocabulary) + with pytest.warns(UserWarning, match=message): - vectorizer.fit_transform(vocabulary) + vectorizer.fit(vocabulary) + + with pytest.warns(None) as record: + vectorizer.transform(vocabulary) + assert not record def test_tf_transformer_feature_names_out(): diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py index a64bbc3ff3737..9124edb455b71 100644 --- a/sklearn/feature_extraction/text.py +++ b/sklearn/feature_extraction/text.py 
@@ -1194,17 +1194,6 @@ def _count_vocab(self, raw_documents, fixed_vocab): j_indices = [] indptr = [] - if self.lowercase: - for vocab in vocabulary: - if any(map(str.isupper, vocab)): - warnings.warn( - "Upper case characters found in" - " vocabulary while 'lowercase'" - " is True. These entries will not" - " be matched with any documents" - ) - break - values = _make_int_array() indptr.append(0) for doc in raw_documents: @@ -1327,6 +1316,17 @@ def fit_transform(self, raw_documents, y=None): min_df = self.min_df max_features = self.max_features + if self.fixed_vocabulary_ and self.lowercase: + for term in self.vocabulary: + if any(map(str.isupper, term)): + warnings.warn( + "Upper case characters found in" + " vocabulary while 'lowercase'" + " is True. These entries will not" + " be matched with any documents" + ) + break + vocabulary, X = self._count_vocab(raw_documents, self.fixed_vocabulary_) if self.binary: