FIX an issue w/ large sparse matrix indices in CountVectorizer (#11295) · scikit-learn/scikit-learn@5fc5c6e · GitHub
[go: up one dir, main page]

Skip to content

Commit 5fc5c6e

Browse files
gvacaliuc authored and jnothman committed
FIX an issue w/ large sparse matrix indices in CountVectorizer (#11295)
1 parent fdf2f38 commit 5fc5c6e

File tree

3 files changed

+45
-8
lines changed

3 files changed

+45
-8
lines changed

doc/whats_new/v0.20.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,14 @@ Changelog
5151
combination with ``handle_unknown='ignore'``.
5252
:issue:`12881` by `Joris Van den Bossche`_.
5353

54+
:mod:`sklearn.feature_extraction.text`
55+
......................................
56+
57+
- |Fix| Fixed a bug in :class:`feature_extraction.text.CountVectorizer` which
58+
would result in the sparse feature matrix having conflicting `indptr` and
59+
`indices` precisions under very large vocabularies. :issue:`11295` by
60+
:user:`Gabriel Vacaliuc <gvacaliuc>`.
61+
5462
.. _changes_0_20_2:
5563

5664
Version 0.20.2
</code>

sklearn/feature_extraction/tests/test_text.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@
3636
assert_warns_message, assert_raise_message,
3737
clean_warning_registry, ignore_warnings,
3838
SkipTest, assert_raises, assert_no_warnings,
39-
fails_if_pypy, assert_allclose_dense_sparse)
39+
fails_if_pypy, assert_allclose_dense_sparse,
40+
skip_if_32bit)
4041
from collections import defaultdict
4142
from functools import partial
4243
import pickle
@@ -1144,6 +1145,35 @@ def test_vectorizer_stop_words_inconsistent():
11441145
['hello world'])
11451146

11461147

1148+
@skip_if_32bit
1149+
def test_countvectorizer_sort_features_64bit_sparse_indices():
1150+
"""
1151+
Check that CountVectorizer._sort_features preserves the dtype of its sparse
1152+
feature matrix.
1153+
1154+
This test is skipped on 32bit platforms, see:
1155+
https://github.com/scikit-learn/scikit-learn/pull/11295
1156+
for more details.
1157+
"""
1158+
1159+
X = sparse.csr_matrix((5, 5), dtype=np.int64)
1160+
1161+
# force indices and indptr to int64.
1162+
INDICES_DTYPE = np.int64
1163+
X.indices = X.indices.astype(INDICES_DTYPE)
1164+
X.indptr = X.indptr.astype(INDICES_DTYPE)
1165+
1166+
vocabulary = {
1167+
"scikit-learn": 0,
1168+
"is": 1,
1169+
"great!": 2
1170+
}
1171+
1172+
Xs = CountVectorizer()._sort_features(X, vocabulary)
1173+
1174+
assert INDICES_DTYPE == Xs.indices.dtype
1175+
1176+
11471177
@fails_if_pypy
11481178
@pytest.mark.parametrize('Estimator',
11491179
[CountVectorizer, TfidfVectorizer, HashingVectorizer])

sklearn/feature_extraction/text.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from .stop_words import ENGLISH_STOP_WORDS
3232
from ..utils.validation import check_is_fitted, check_array, FLOAT_DTYPES
3333
from ..utils.fixes import sp_version
34+
from ..utils import _IS_32BIT
3435

3536

3637
__all__ = ['HashingVectorizer',
@@ -871,7 +872,7 @@ def _sort_features(self, X, vocabulary):
871872
Returns a reordered matrix and modifies the vocabulary in place
872873
"""
873874
sorted_features = sorted(vocabulary.items())
874-
map_index = np.empty(len(sorted_features), dtype=np.int32)
875+
map_index = np.empty(len(sorted_features), dtype=X.indices.dtype)
875876
for new_val, (term, old_val) in enumerate(sorted_features):
876877
vocabulary[term] = new_val
877878
map_index[old_val] = new_val
@@ -961,14 +962,12 @@ def _count_vocab(self, raw_documents, fixed_vocab):
961962
" contain stop words")
962963

963964
if indptr[-1] > 2147483648: # = 2**31 - 1
964-
if sp_version >= (0, 14):
965-
indices_dtype = np.int64
966-
else:
965+
if _IS_32BIT:
967966
raise ValueError(('sparse CSR array has {} non-zero '
968967
'elements and requires 64 bit indexing, '
969-
' which is unsupported with scipy {}. '
970-
'Please upgrade to scipy >=0.14')
971-
.format(indptr[-1], '.'.join(sp_version)))
968+
'which is unsupported with 32 bit Python.')
969+
.format(indptr[-1]))
970+
indices_dtype = np.int64
972971

973972
else:
974973
indices_dtype = np.int32

0 commit comments

Comments
 (0)
0