fixes an issue discussed in #7762 · scikit-learn/scikit-learn@317a169 · GitHub
[go: up one dir, main page]

Skip to content

Commit 317a169

Browse files
committed
fixes an issue discussed in #7762
* added test to show example of issue in #7762
* fixes error caused by manually manipulating sparse indices in `CountVectorizer`
1 parent d990f72 commit 317a169

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

sklearn/feature_extraction/tests/test_text.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,3 +1140,27 @@ def test_vectorizer_stop_words_inconsistent():
11401140
vec.set_params(stop_words=["you've", "you", "you'll", 'blah', 'AND'])
11411141
assert_warns_message(UserWarning, message, vec.fit_transform,
11421142
['hello world'])
1143+
1144+
1145+
def test_countvectorizer_sort_features_64bit_sparse_indices():
1146+
"""
1147+
Check that CountVectorizer._sort_features preserves the dtype of its sparse
1148+
feature matrix.
1149+
"""
1150+
1151+
X = sparse.csr_matrix((5, 5), dtype=np.int64)
1152+
1153+
# force indices and indptr to int64.
1154+
INDICES_DTYPE = np.int64
1155+
X.indices = X.indices.astype(INDICES_DTYPE)
1156+
X.indptr = X.indptr.astype(INDICES_DTYPE)
1157+
1158+
vocabulary = {
1159+
"scikit-learn": 0,
1160+
"is": 1,
1161+
"great!": 2
1162+
}
1163+
1164+
Xs = CountVectorizer()._sort_features(X, vocabulary)
1165+
1166+
assert INDICES_DTYPE == Xs.indices.dtype

sklearn/feature_extraction/text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,7 @@ def _sort_features(self, X, vocabulary):
852852
Returns a reordered matrix and modifies the vocabulary in place
853853
"""
854854
sorted_features = sorted(six.iteritems(vocabulary))
855-
map_index = np.empty(len(sorted_features), dtype=np.int32)
855+
map_index = np.empty(len(sorted_features), dtype=X.indices.dtype)
856856
for new_val, (term, old_val) in enumerate(sorted_features):
857857
vocabulary[term] = new_val
858858
map_index[old_val] = new_val

0 commit comments

Comments
 (0)
0