8000 BUG Corrects tag in TfidfTransformer (#20919) · scikit-learn/scikit-learn@cfc1695 · GitHub
[go: up one dir, main page]

Skip to content

Commit cfc1695

Browse files
authored
BUG Corrects tag in TfidfTransformer (#20919)
1 parent 267b617 commit cfc1695

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

sklearn/feature_extraction/text.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,7 +1559,12 @@ def fit(self, X, y=None):
15591559
self : object
15601560
Fitted transformer.
15611561
"""
1562-
X = self._validate_data(X, accept_sparse=("csr", "csc"))
1562+
# large sparse data is not supported for 32bit platforms because
1563+
# _document_frequency uses np.bincount which works on arrays of
1564+
# dtype NPY_INTP which is int32 for 32bit platforms. See #20923
1565+
X = self._validate_data(
1566+
X, accept_sparse=("csr", "csc"), accept_large_sparse=not _IS_32BIT
1567+
)
15631568
if not sp.issparse(X):
15641569
X = sp.csr_matrix(X)
15651570
dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64
@@ -1648,7 +1653,7 @@ def idf_(self, value):
16481653
)
16491654

16501655
def _more_tags(self):
1651-
return {"X_types": "sparse"}
1656+
return {"X_types": ["2darray", "sparse"]}
16521657

16531658

16541659
class TfidfVectorizer(CountVectorizer):

sklearn/tests/test_common.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from sklearn.pipeline import make_pipeline
3838

3939
from sklearn.utils import IS_PYPY
40+
from sklearn.utils._tags import _DEFAULT_TAGS, _safe_tags
4041
from sklearn.utils._testing import (
4142
SkipTest,
4243
set_random_state,
@@ -308,6 +309,21 @@ def test_search_cv(estimator, check, request):
308309
check(estimator)
309310

310311

312+
@pytest.mark.parametrize(
313+
"estimator", _tested_estimators(), ids=_get_check_estimator_ids
314+
)
315+
def test_valid_tag_types(estimator):
316+
"""Check that estimator tags are valid."""
317+
tags = _safe_tags(estimator)
318+
319+
for name, tag in tags.items():
320+
correct_tags = type(_DEFAULT_TAGS[name])
321+
if name == "_xfail_checks":
322+
# _xfail_checks can be a dictionary
323+
correct_tags = (correct_tags, dict)
324+
assert isinstance(tag, correct_tags)
325+
326+
311327
@pytest.mark.parametrize(
312328
"estimator", _tested_estimators(), ids=_get_check_estimator_ids
313329
)

0 commit comments

Comments
 (0)
0