Rewrite the 64 bit index support of CSR arrays · scikit-learn/scikit-learn@8d302e4 · GitHub

Commit 8d302e4

Rewrite the 64 bit index support of CSR arrays
1 parent 564f8b7 commit 8d302e4

File tree

2 files changed: 44 additions & 7 deletions
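
For context on what the diff addresses: in a CSR matrix the last entry of indptr equals the total number of stored values, so once a matrix accumulates more than 2**31 - 1 non-zeros a 32-bit indptr can no longer hold that count, and scipy only accepts 64-bit index arrays from version 0.14 onward. A minimal illustration of the limit (plain NumPy, not part of the commit):

import numpy as np

# The last indptr entry of a CSR matrix is its total count of stored values.
# Past the int32 range, 32-bit index arrays can no longer represent it.
nnz = 3 * 10**9                          # hypothetical non-zero count
print(nnz > np.iinfo(np.int32).max)      # True: does not fit in int32
print(np.array([nnz], dtype=np.int64))   # fits comfortably in a 64-bit indptr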

sklearn/feature_extraction/_hashing.pyx

Lines changed: 27 additions & 3 deletions
@@ -1,6 +1,7 @@
 # Author: Lars Buitinck
 # License: BSD 3 clause
 
+import sys
 import array
 from cpython cimport array
 cimport cython
@@ -9,6 +10,7 @@ cimport numpy as np
 import numpy as np
 
 from sklearn.utils.murmurhash cimport murmurhash3_bytes_s32
+from sklearn.utils.fixes import sp_version
 
 np.import_array()
 
@@ -33,12 +35,20 @@ def transform(raw_X, Py_ssize_t n_features, dtype, bint alternate_sign=1):
     cdef array.array indices
     cdef array.array indptr
     indices = array.array("i")
-    indptr = array.array("i", [0])
+    if sys.version_info >= (3, 3):
+        indices_array_dtype = "q"
+        indices_np_dtype = np.longlong
+    else:
+        # On Windows with PY2.7 long int would still correspond to 32 bit.
+        indices_array_dtype = "l"
+        indices_np_dtype = np.int_
+
+    indptr = array.array(indices_array_dtype, [0])
 
     # Since Python array does not understand Numpy dtypes, we grow the indices
     # and values arrays ourselves. Use a Py_ssize_t capacity for safety.
     cdef Py_ssize_t capacity = 8192        # arbitrary
-    cdef np.int32_t size = 0
+    cdef np.int64_t size = 0
     cdef np.ndarray values = np.empty(capacity, dtype=dtype)
 
     for x in raw_X:
@@ -79,4 +89,18 @@ def transform(raw_X, Py_ssize_t n_features, dtype, bint alternate_sign=1):
     indptr[len(indptr) - 1] = size
 
     indices_a = np.frombuffer(indices, dtype=np.int32)
-    return (indices_a, np.frombuffer(indptr, dtype=np.int32), values[:size])
+    indptr_a = np.frombuffer(indptr, dtype=indices_np_dtype)
+
+    if indptr[-1] > 2147483648:  # = 2**31
+        if sp_version < (0, 14):
+            raise ValueError(('sparse CSR array has {} non-zero '
+                              'elements and requires 64 bit indexing, '
+                              ' which is unsupported with scipy {}. '
+                              'Please upgrade to scipy >=0.14')
+                             .format(indptr[-1], '.'.join(sp_version)))
+        # both indices and indptr have the same dtype in CSR arrays
+        indices_a = indices_a.astype(np.int64)
+    else:
+        indptr_a = indptr_a.astype(np.int32)
+
+    return (indices_a, indptr_a, values[:size])
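
The shape of the change above, stripped of its Cython specifics, is roughly as follows. This is an illustrative sketch, not scikit-learn code (the build_csr helper is an assumption): grow the index arrays in 64-bit containers, then downcast to int32 whenever the stored-value count still fits, which is what the indptr[-1] check decides.

import numpy as np
import scipy.sparse as sp

def build_csr(values, indices, indptr, shape):
    # Hypothetical helper mirroring the post-processing added in the diff:
    # keep 64-bit index arrays only when the non-zero count requires them.
    indices = np.asarray(indices, dtype=np.int64)
    indptr = np.asarray(indptr, dtype=np.int64)
    if indptr[-1] <= np.iinfo(np.int32).max:
        # Small enough: downcast so older scipy versions accept the arrays.
        indices = indices.astype(np.int32)
        indptr = indptr.astype(np.int32)
    return sp.csr_matrix((values, indices, indptr), shape=shape)

X = build_csr(np.array([1.0, 2.0, 3.0]), [0, 2, 1], [0, 2, 3], shape=(2, 3))
print(X.indptr.dtype)   # int32 for this tiny example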

sklearn/feature_extraction/text.py

Lines changed: 17 additions & 4 deletions
@@ -30,6 +30,7 @@
 from .hashing import FeatureHasher
 from .stop_words import ENGLISH_STOP_WORDS
 from ..utils.validation import check_is_fitted
+from ..utils.fixes import sp_version
 
 __all__ = ['CountVectorizer',
            'ENGLISH_STOP_WORDS',
@@ -784,8 +785,8 @@ def _count_vocab(self, raw_documents, fixed_vocab):
 
         analyze = self.build_analyzer()
         j_indices = []
-        # indptr stores indices into j_indices, which can be large
-        indptr = _make_int_array(dtype='l')
+        indptr = []
+
         values = _make_int_array()
         indptr.append(0)
         for doc in raw_documents:
@@ -812,8 +813,20 @@ def _count_vocab(self, raw_documents, fixed_vocab):
             raise ValueError("empty vocabulary; perhaps the documents only"
                              " contain stop words")
 
-        j_indices = np.asarray(j_indices, dtype=np.intc)
-        indptr = np.frombuffer(indptr, dtype=np.int_)
+        if indptr[-1] > 2147483648:  # = 2**31
+            if sp_version >= (0, 14):
+                indices_dtype = np.int64
+            else:
+                raise ValueError(('sparse CSR array has {} non-zero '
+                                  'elements and requires 64 bit indexing, '
+                                  ' which is unsupported with scipy {}. '
+                                  'Please upgrade to scipy >=0.14')
+                                 .format(indptr[-1], '.'.join(sp_version)))
+
+        else:
+            indices_dtype = np.int32
+        j_indices = np.asarray(j_indices, dtype=indices_dtype)
+        indptr = np.asarray(indptr, dtype=indices_dtype)
         values = np.frombuffer(values, dtype=np.intc)
 
         X = sp.csr_matrix((values, j_indices, indptr),
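
A quick way to see which branch _count_vocab took is to inspect the index dtypes of the matrix it returns. A small usage sketch (tiny corpus, so indptr[-1] stays far below 2**31 and the int32 branch applies):

from sklearn.feature_extraction.text import CountVectorizer

docs = ["the cat sat on the mat", "the dog ate my homework"]
X = CountVectorizer().fit_transform(docs)
# With only a handful of non-zeros the 64-bit path is not needed,
# so both index arrays remain 32-bit.
print(X.indices.dtype, X.indptr.dtype)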

0 commit comments
