8000 Extend changes to the HashingVectorizer · scikit-learn/scikit-learn@52c77cd · GitHub
[go: up one dir, main page]

Skip to content

Commit 52c77cd

Browse files
committed
Extend changes to the HashingVectorizer
1 parent 558f3ee commit 52c77cd

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

sklearn/feature_extraction/_hashing.pyx

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ cimport numpy as np
99
import numpy as np
1010

1111
from sklearn.utils.murmurhash cimport murmurhash3_bytes_s32
12+
from sklearn.utils.fixes import sp_version
1213

1314
np.import_array()
1415

@@ -33,12 +34,12 @@ def transform(raw_X, Py_ssize_t n_features, dtype, bint alternate_sign=1):
3334
cdef array.array indices
3435
cdef array.array indptr
3536
indices = array.array("i")
36-
indptr = array.array("i", [0])
37+
indptr = array.array("l", [0])
3738

3839
# Since Python array does not understand Numpy dtypes, we grow the indices
3940
# and values arrays ourselves. Use a Py_ssize_t capacity for safety.
4041
cdef Py_ssize_t capacity = 8192 # arbitrary
41-
cdef np.int32_t size = 0
42+
cdef np.intp_t size = 0
4243
cdef np.ndarray values = np.empty(capacity, dtype=dtype)
4344

4445
for x in raw_X:
@@ -79,4 +80,10 @@ def transform(raw_X, Py_ssize_t n_features, dtype, bint alternate_sign=1):
7980
indptr[len(indptr) - 1] = size
8081

8182
indices_a = np.frombuffer(indices, dtype=np.int32)
82-
return (indices_a, np.frombuffer(indptr, dtype=np.int32), values[:size])
83+
indptr_a = np.frombuffer(indptr, dtype=np.int64)
84+
85+
if indptr[-1] > 2147483648: # = 2**31
86+
indices_a = indices_a.astype(np.int64)
87+
else:
88+
indptr_a = indptr_a.astype(np.int32)
89+
return (indices_a, indptr_a, values[:size])

sklearn/feature_extraction/text.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .hashing import FeatureHasher
3131
from .stop_words import ENGLISH_STOP_WORDS
3232
from ..utils.validation import check_is_fitted
33+
from ..utils.fixes import sp_version
3334

3435
__all__ = ['CountVectorizer',
3536
'ENGLISH_STOP_WORDS',
@@ -762,7 +763,7 @@ def _count_vocab(self, raw_documents, fixed_vocab):
762763

763764
analyze = self.build_analyzer()
764765
j_indices = []
765-
# indptr stores indices into j_indices, which can be large
766+
# indptr can overflow in 32 bit, always use 64 bit
766767
indptr = _make_int_array(dtype='l')
767768
values = _make_int_array()
768769
indptr.append(0)
@@ -790,8 +791,20 @@ def _count_vocab(self, raw_documents, fixed_vocab):
790791
raise ValueError("empty vocabulary; perhaps the documents only"
791792
" contain stop words")
792793

793-
j_indices = np.asarray(j_indices, dtype=np.intc)
794-
indptr = np.frombuffer(indptr, dtype=np.int_)
794+
if indptr[-1] > 2147483648: # = 2**31
795+
if sp_version >= (0, 14):
796+
indices_dtype = np.int_
797+
else:
798+
raise ValueError(('sparse CSR array has {} non-zero '
799+
'elements and require 64 bit indexing, '
800+
'which is unsupported with scipy {}. '
801+
'Please upgrade to scipy >=0.14')
802+
.format(indptr[-1], '.'.join(sp_version)))
803+
804+
else:
805+
indices_dtype = np.intc
806+
j_indices = np.asarray(j_indices, dtype=indices_dtype)
807+
indptr = np.frombuffer(indptr, dtype=np.int_).astype(indices_dtype)
795808
values = np.frombuffer(values, dtype=np.intc)
796809

797810
X = sp.csr_matrix((values, j_indices, indptr),

0 commit comments

Comments
 (0)
0