Revert "FIX an issue w/ large sparse matrix indices in CountVectorize… · xhluca/scikit-learn@2716628 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2716628

Browse files
author
Xing
committed
Revert "FIX an issue w/ large sparse matrix indices in CountVectorizer (scikit-learn#11295)"
This reverts commit 1da72e7.
1 parent 710c54b commit 2716628

File tree

3 files changed

+8
-45
lines changed

3 files changed

+8
-45
lines changed

doc/whats_new/v0.20.rst

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,6 @@ Changelog
5151
combination with ``handle_unknown='ignore'``.
5252
:issue:`12881` by `Joris Van den Bossche`_.
5353

54-
:mod:`sklearn.feature_extraction.text`
55-
......................................
56-
57-
- |Fix| Fixed a bug in :class:`feature_extraction.text.CountVectorizer` which
58-
would result in the sparse feature matrix having conflicting `indptr` and
59-
`indices` precisions under very large vocabularies. :issue:`11295` by
60-
:user:`Gabriel Vacaliuc <gvacaliuc>`.
61-
6254
.. _changes_0_20_2:
6355

6456
Version 0.20.2

sklearn/feature_extraction/tests/test_text.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@
3636
assert_warns_message, assert_raise_message,
3737
clean_warning_registry, ignore_warnings,
3838
SkipTest, assert_raises, assert_no_warnings,
39-
fails_if_pypy, assert_allclose_dense_sparse,
40-
skip_if_32bit)
39+
fails_if_pypy, assert_allclose_dense_sparse)
4140
from collections import defaultdict
4241
from functools import partial
4342
import pickle
@@ -1145,35 +1144,6 @@ def test_vectorizer_stop_words_inconsistent():
11451144
['hello world'])
11461145

11471146

1148-
@skip_if_32bit
1149-
def test_countvectorizer_sort_features_64bit_sparse_indices():
1150-
"""
1151-
Check that CountVectorizer._sort_features preserves the dtype of its sparse
1152-
feature matrix.
1153-
1154-
This test is skipped on 32bit platforms, see:
1155-
https://github.com/scikit-learn/scikit-learn/pull/11295
1156-
for more details.
1157-
"""
1158-
1159-
X = sparse.csr_matrix((5, 5), dtype=np.int64)
1160-
1161-
# force indices and indptr to int64.
1162-
INDICES_DTYPE = np.int64
1163-
X.indices = X.indices.astype(INDICES_DTYPE)
1164-
X.indptr = X.indptr.astype(INDICES_DTYPE)
1165-
1166-
vocabulary = {
1167-
"scikit-learn": 0,
1168-
"is": 1,
1169-
"great!": 2
1170-
}
1171-
1172-
Xs = CountVectorizer()._sort_features(X, vocabulary)
1173-
1174-
assert INDICES_DTYPE == Xs.indices.dtype
1175-
1176-
11771147
@fails_if_pypy
11781148
@pytest.mark.parametrize('Estimator',
11791149
[CountVectorizer, TfidfVectorizer, HashingVectorizer])

sklearn/feature_extraction/text.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from .stop_words import ENGLISH_STOP_WORDS
3232
from ..utils.validation import check_is_fitted, check_array, FLOAT_DTYPES
3333
from ..utils.fixes import sp_version
34-
from ..utils import _IS_32BIT
3534

3635

3736
__all__ = ['HashingVectorizer',
@@ -872,7 +871,7 @@ def _sort_features(self, X, vocabulary):
872871
Returns a reordered matrix and modifies the vocabulary in place
873872
"""
874873
sorted_features = sorted(vocabulary.items())
875-
map_index = np.empty(len(sorted_features), dtype=X.indices.dtype)
874+
map_index = np.empty(len(sorted_features), dtype=np.int32)
876875
for new_val, (term, old_val) in enumerate(sorted_features):
877876
vocabulary[term] = new_val
878877
map_index[old_val] = new_val
@@ -962,12 +961,14 @@ def _count_vocab(self, raw_documents, fixed_vocab):
962961
" contain stop words")
963962

964963
if indptr[-1] > 2147483648: # = 2**31 - 1
965-
if _IS_32BIT:
964+
if sp_version >= (0, 14):
965+
indices_dtype = np.int64
966+
else:
966967
raise ValueError(('sparse CSR array has {} non-zero '
967968
'elements and requires 64 bit indexing, '
968-
'which is unsupported with 32 bit Python.')
969-
.format(indptr[-1]))
970-
indices_dtype = np.int64
969+
' which is unsupported with scipy {}. '
970+
'Please upgrade to scipy >=0.14')
971+
.format(indptr[-1], '.'.join(sp_version)))
971972

972973
else:
973974
indices_dtype = np.int32

0 commit comments

Comments
 (0)
0