8000 FIX raise error for max_df and min_df greater than 1 in Vectorizer (#… · samronsin/scikit-learn@5e8e5a6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5e8e5a6

Browse files
AlekLefebvreAlek Lefebvreglemaitre
authored andcommitted
FIX raise error for max_df and min_df greater than 1 in Vectorizer (scikit-learn#20752)
Co-authored-by: Alek Lefebvre <info@aleklefebvre.ca> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 8545cc8 commit 5e8e5a6

File tree

3 files changed

+52
-10
lines changed

3 files changed

+52
-10
lines changed

doc/whats_new/v1.1.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,21 @@ Changelog
4646
:pr:`21032` by :user:`Guillaume Lemaitre <glemaitre>`.
4747

4848
:mod:`sklearn.ensemble`
49-
...........................
49+
.......................
5050

5151
- |Fix| Fixed a bug that could produce a segfault in rare cases for
5252
:class:`ensemble.HistGradientBoostingClassifier` and
5353
:class:`ensemble.HistGradientBoostingRegressor`.
5454
:pr:`21130` :user:`Christian Lorentzen <lorentzenchr>`.
5555

56+
:mod:`sklearn.feature_extraction`
57+
.................................
58+
59+
- |Fix| Fixed a bug in :class:`feature_extraction.CountVectorizer` and
60+
:class:`feature_extraction.TfidfVectorizer` by raising an
61+
error when 'min_idf' or 'max_idf' are floating-point numbers greater than 1.
62+
:pr:`20752` by :user:`Alek Lefebvre <AlekLefebvre>`.
63+
5664
:mod:`sklearn.linear_model`
5765
...........................
5866

sklearn/feature_extraction/tests/test_text.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,31 @@ def test_vectorizer_min_df():
832832
assert len(vect.stop_words_) == 5
833833

834834

835+
@pytest.mark.parametrize(
836+
"params, err_type, message",
837+
(
838+
({"max_df": 2.0}, ValueError, "max_df == 2.0, must be <= 1.0."),
839+
({"min_df": 1.5}, ValueError, "min_df == 1.5, must be <= 1.0."),
840+
({"max_df": -2}, ValueError, "max_df == -2, must be >= 0."),
841+
({"min_df": -10}, ValueError, "min_df == -10, must be >= 0."),
842+
({"min_df": 3, "max_df": 2.0}, ValueError, "max_df == 2.0, must be <= 1.0."),
843+
({"min_df": 1.5, "max_df": 50}, ValueError, "min_df == 1.5, must be <= 1.0."),
844+
({"max_features": -10}, ValueError, "max_features == -10, must be >= 0."),
845+
(
846+
{"max_features": 3.5},
847+
TypeError,
848+
"max_features must be an instance of <class 'numbers.Integral'>, not <class"
849+
" 'float'>",
850+
),
851+
),
852+
)
853+
def test_vectorizer_params_validation(params, err_type, message):
854+
with pytest.raises(err_type, match=message):
855+
test_data = ["abc", "dea", "eat"]
856+
vect = CountVectorizer(**params, analyzer="char")
857+
vect.fit(test_data)
858+
859+
835860
# TODO: Remove in 1.2 when get_feature_names is removed.
836861
@pytest.mark.filterwarnings("ignore::FutureWarning:sklearn")
837862
@pytest.mark.parametrize("get_names", ["get_feature_names", "get_feature_names_out"])

sklearn/feature_extraction/text.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from ..preprocessing import normalize
3030
from ._hash import FeatureHasher
3131
from ._stop_words import ENGLISH_STOP_WORDS
32-
from ..utils.validation import check_is_fitted, check_array, FLOAT_DTYPES
32+
from ..utils.validation import check_is_fitted, check_array, FLOAT_DTYPES, check_scalar
3333
from ..utils.deprecation import deprecated
3434
from ..utils import _IS_32BIT
3535
from ..utils.fixes import _astype_copy_false
@@ -1120,15 +1120,7 @@ def __init__(
11201120
self.stop_words = stop_words
11211121
self.max_df = max_df
11221122
self.min_df = min_df
1123-
if max_df < 0 or min_df < 0:
1124-
raise ValueError("negative value for max_df or min_df")
11251123
self.max_features = max_features
1126-
if max_features is not None:
1127-
if not isinstance(max_features, numbers.Integral) or max_features <= 0:
1128-
raise ValueError(
1129-
"max_features=%r, neither a positive integer nor None"
1130-
% max_features
1131-
)
11321124
self.ngram_range = ngram_range
11331125
self.vocabulary = vocabulary
11341126
self.binary = binary
@@ -1265,6 +1257,23 @@ def _count_vocab(self, raw_documents, fixed_vocab):
12651257
X.sort_indices()
12661258
return vocabulary, X
12671259

1260+
def _validate_params(self):
1261+
"""Validation of min_df, max_df and max_features"""
1262+
super()._validate_params()
1263+
1264+
if self.max_features is not None:
1265+
check_scalar(self.max_features, "max_features", numbers.Integral, min_val=0)
1266+
1267+
if isinstance(self.min_df, numbers.Integral):
1268+
check_scalar(self.min_df, "min_df", numbers.Integral, min_val=0)
1269+
else:
1270+
check_scalar(self.min_df, "min_df", numbers.Real, min_val=0.0, max_val=1.0)
1271+
1272+
if isinstance(self.max_df, numbers.Integral):
1273+
check_scalar(self.max_df, "max_df", numbers.Integral, min_val=0)
1274+
else:
1275+
check_scalar(self.max_df, "max_df", numbers.Real, min_val=0.0, max_val=1.0)
1276+
12681277
def fit(self, raw_documents, y=None):
12691278
"""Learn a vocabulary dictionary of all tokens in the raw documents.
12701279

0 commit comments

Comments
 (0)
0