8000 fixes an issue discussed in #7762 · scikit-learn/scikit-learn@359b5c0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 359b5c0

Browse files
committed
fixes an issue discussed in #7762
* added test to show example of issue in #7762 * fixes error caused by manually manipulating sparse indices in `CountVectorizer`
1 parent d990f72 commit 359b5c0

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

sklearn/feature_extraction/tests/test_text.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,3 +1140,28 @@ def test_vectorizer_stop_words_inconsistent():
11401140
vec.set_params(stop_words=["you've", "you", "you'll", 'blah', 'AND'])
11411141
assert_warns_message(UserWarning, message, vec.fit_transform,
11421142
['hello world'])
1143+
1144+
1145+
@pytest.mark.parametrize("vec", [CountVectorizer()])
1146+
def test_countvectorizer_sort_features_64bit_sparse_indices(vec):
1147+
# If a count vectorizer has to store >= 2**31 count values, the sparse
1148+
# storage matrix has 64bit indices / indptrs. This requires ~2*8*2**31
1149+
# bytes of memory in practice, so we just test the method that would
1150+
# hypothetically fail.
1151+
1152+
X = sparse.csr_matrix((5, 5), dtype=np.int64)
1153+
1154+
# force indices and indptr to int64.
1155+
INDEX_DTYPE = np.int64
1156+
X.indices = X.indices.astype(INDEX_DTYPE, copy=False)
1157+
X.indptr = X.indptr.astype(INDEX_DTYPE, copy=False)
1158+
1159+
vocabulary = {
1160+
"scikit-learn": 0,
1161+
"is": 1,
1162+
"great!": 2
1163+
}
1164+
1165+
vec._sort_features(X, vocabulary)
1166+
1167+
assert_equal(INDEX_DTYPE, X.indices.dtype)

sklearn/feature_extraction/text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,7 @@ def _sort_features(self, X, vocabulary):
852852
Returns a reordered matrix and modifies the vocabulary in place
853853
"""
854854
sorted_features = sorted(six.iteritems(vocabulary))
855-
map_index = np.empty(len(sorted_features), dtype=np.int32)
855+
map_index = np.empty(len(sorted_features), dtype=X.indices.dtype)
856856
for new_val, (term, old_val) in enumerate(sorted_features):
857857
vocabulary[term] = new_val
858858
map_index[old_val] = new_val

0 commit comments

Comments
 (0)
0