ENH Add 64 bit indices support in csr_row_norms and inplace L2/L1 csr… · scikit-learn/scikit-learn@d551227 · GitHub
[go: up one dir, main page]

Skip to content

Commit d551227

Browse files
rth authored and jnothman committed
ENH Add 64 bit indices support in csr_row_norms and inplace L2/L1 csr norm (#9663)
1 parent 2c1d079 commit d551227

File tree

3 files changed

+43
-25
lines changed

3 files changed

+43
-25
lines changed

sklearn/utils/sparsefuncs_fast.pyx

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ from cython cimport floating
1818

1919
np.import_array()
2020

21+
ctypedef fused integral:
22+
int
23+
long long
2124

2225
ctypedef np.float64_t DOUBLE
2326

@@ -30,11 +33,11 @@ def csr_row_norms(X):
3033

3134
def _csr_row_norms(np.ndarray[floating, ndim=1, mode="c"] X_data,
3235
shape,
33-
np.ndarray[int, ndim=1, mode="c"] X_indices,
34-
np.ndarray[int, ndim=1, mode="c"] X_indptr):
36+
np.ndarray[integral, ndim=1, mode="c"] X_indices,
37+
np.ndarray[integral, ndim=1, mode="c"] X_indptr):
3538
cdef:
36-
unsigned int n_samples = shape[0]
37-
unsigned int n_features = shape[1]
39+
unsigned long long n_samples = shape[0]
40+
unsigned long long n_features = shape[1]
3841
np.ndarray[DOUBLE, ndim=1, mode="c"] norms
3942

4043
np.npy_intp i, j
@@ -326,17 +329,16 @@ def inplace_csr_row_normalize_l1(X):
326329

327330
def _inplace_csr_row_normalize_l1(np.ndarray[floating, ndim=1] X_data,
328331
shape,
329-
np.ndarray[int, ndim=1] X_indices,
330-
np.ndarray[int, ndim=1] X_indptr):
331-
cdef unsigned int n_samples = shape[0]
332-
cdef unsigned int n_features = shape[1]
332+
np.ndarray[integral, ndim=1] X_indices,
333+
np.ndarray[integral, ndim=1] X_indptr):
334+
cdef unsigned long long n_samples = shape[0]
335+
cdef unsigned long long n_features = shape[1]
333336

334337
# the column indices for row i are stored in:
335338
# indices[indptr[i]:indptr[i+1]]
336339
# and their corresponding values are stored in:
337340
# data[indptr[i]:indptr[i+1]]
338-
cdef unsigned int i
339-
cdef unsigned int j
341+
cdef np.npy_intp i, j
340342
cdef double sum_
341343

342344
for i in xrange(n_samples):
@@ -361,13 +363,12 @@ def inplace_csr_row_normalize_l2(X):
361363

362364
def _inplace_csr_row_normalize_l2(np.ndarray[floating, ndim=1] X_data,
363365
shape,
364-
np.ndarray[int, ndim=1] X_indices,
365-
np.ndarray[int, ndim=1] X_indptr):
366-
cdef unsigned int n_samples = shape[0]
367-
cdef unsigned int n_features = shape[1]
366+
np.ndarray[integral, ndim=1] X_indices,
367+
np.ndarray[integral, ndim=1] X_indptr):
368+
cdef integral n_samples = shape[0]
369+
cdef integral n_features = shape[1]
368370

369-
cdef unsigned int i
370-
cdef unsigned int j
371+
cdef np.npy_intp i, j
371372
cdef double sum_
372373

373374
for i in xrange(n_samples):

sklearn/utils/tests/test_extmath.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,19 @@ def test_row_norms():
206206
precision)
207207
assert_array_almost_equal(np.sqrt(sq_norm), row_norms(X), precision)
208208

209-
Xcsr = sparse.csr_matrix(X, dtype=dtype)
210-
assert_array_almost_equal(sq_norm, row_norms(Xcsr, squared=True),
211-
precision)
212-
assert_array_almost_equal(np.sqrt(sq_norm), row_norms(Xcsr), precision)
209+
for csr_index_dtype in [np.int32, np.int64]:
210+
Xcsr = sparse.csr_matrix(X, dtype=dtype)
211+
# csr_matrix will use int32 indices by default,
212+
# up-casting those to int64 when necessary
213+
if csr_index_dtype is np.int64:
214+
Xcsr.indptr = Xcsr.indptr.astype(csr_index_dtype)
215+
Xcsr.indices = Xcsr.indices.astype(csr_index_dtype)
216+
assert Xcsr.indices.dtype == csr_index_dtype
217+
assert Xcsr.indptr.dtype == csr_index_dtype
218+
assert_array_almost_equal(sq_norm, row_norms(Xcsr, squared=True),
219+
precision)
220+
assert_array_almost_equal(np.sqrt(sq_norm), row_norms(Xcsr),
221+
precision)
213222

214223

215224
def test_randomized_svd_low_rank_with_noise():

sklearn/utils/tests/test_sparsefuncs.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -478,8 +478,16 @@ def test_inplace_normalize():
478478
for dtype in (np.float64, np.float32):
479479
X = rs.randn(10, 5).astype(dtype)
480480
X_csr = sp.csr_matrix(X)
481-
inplace_csr_row_normalize(X_csr)
482-
assert_equal(X_csr.dtype, dtype)
483-
if inplace_csr_row_normalize is inplace_csr_row_normalize_l2:
484-
X_csr.data **= 2
485-
assert_array_almost_equal(np.abs(X_csr).sum(axis=1), ones)
481+
for index_dtype in [np.int32, np.int64]:
482+
# csr_matrix will use int32 indices by default,
483+
# up-casting those to int64 when necessary
484+
if index_dtype is np.int64:
485+
X_csr.indptr = X_csr.indptr.astype(index_dtype)
486+
X_csr.indices = X_csr.indices.astype(index_dtype)
487+
assert X_csr.indices.dtype == index_dtype
488+
assert X_csr.indptr.dtype == index_dtype
489+
inplace_csr_row_normalize(X_csr)
490+
assert_equal(X_csr.dtype, dtype)
491+
if inplace_csr_row_normalize is inplace_csr_row_normalize_l2:
492+
X_csr.data **= 2
493+
assert_array_almost_equal(np.abs(X_csr).sum(axis=1), ones)

0 commit comments

Comments
 (0)
0