8000 Merge pull request #6588 from yenchenlin1994/use-fused-types-in-assig… · scikit-learn/scikit-learn@d7cf4b0 · GitHub
[go: up one dir, main page]

Skip to content

Commit d7cf4b0

Browse files
committed
Merge pull request #6588 from yenchenlin1994/use-fused-types-in-assign_rows_csr
[MRG+1] Make assign_rows_csr support Cython fused types
2 parents e5c366f + 42d49c8 commit d7cf4b0

File tree

2 files changed

+19
-17
lines changed

2 files changed

+19
-17
lines changed

sklearn/utils/sparsefuncs_fast.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ cimport numpy as np
1111
import numpy as np
1212
import scipy.sparse as sp
1313
cimport cython
14+
from cython cimport floating
1415

1516
np.import_array()
1617

@@ -360,11 +361,11 @@ cdef void add_row_csr(np.ndarray[np.float64_t, ndim=1] data,
360361
def assign_rows_csr(X,
361362
np.ndarray[np.npy_intp, ndim=1] X_rows,
362363
np.ndarray[np.npy_intp, ndim=1] out_rows,
363-
np.ndarray[np.float64_t, ndim=2, mode="c"] out):
364+
np.ndarray[floating, ndim=2, mode="c"] out):
364365
"""Densify selected rows of a CSR matrix into a preallocated array.
365366
366367
Like out[out_rows] = X[X_rows].toarray() but without copying.
367-
Only supported for dtype=np.float64.
368+
No-copy supported for both dtype=np.float32 and dtype=np.float64.
368369
369370
Parameters
370371
----------
@@ -378,7 +379,7 @@ def assign_rows_csr(X,
378379
# but int is what scipy.sparse uses.
379380
int i, ind, j
380381
np.npy_intp rX
381-
np.ndarray[DOUBLE, ndim=1] data = X.data
382+
np.ndarray[floating, ndim=1] data = X.data
382383
np.ndarray[int, ndim=1] indices = X.indices, indptr = X.indptr
383384

384385
if X_rows.shape[0] != out_rows.shape[0]:

sklearn/utils/tests/test_sparsefuncs.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -179,20 +179,21 @@ def test_mean_variance_illegal_axis():
179179

180180

181181
def test_densify_rows():
182-
X = sp.csr_matrix([[0, 3, 0],
183-
[2, 4, 0],
184-
[0, 0, 0],
185-
[9, 8, 7],
186-
[4, 0, 5]], dtype=np.float64)
187-
X_rows = np.array([0, 2, 3], dtype=np.intp)
188-
out = np.ones((6, X.shape[1]), dtype=np.float64)
189-
out_rows = np.array([1, 3, 4], dtype=np.intp)
190-
191-
expect = np.ones_like(out)
192-
expect[out_rows] = X[X_rows, :].toarray()
193-
194-
assign_rows_csr(X, X_rows, out_rows, out)
195-
assert_array_equal(out, expect)
182+
for dtype in (np.float32, np.float64):
183+
X = sp.csr_matrix([[0, 3, 0],
184+
[2, 4, 0],
185+
[0, 0, 0],
186+
[9, 8, 7],
187+
[4, 0, 5]], dtype=dtype)
188+
X_rows = np.array([0, 2, 3], dtype=np.intp)
189+
out = np.ones((6, X.shape[1]), dtype=dtype)
190+
out_rows = np.array([1, 3, 4], dtype=np.intp)
191+
192+
expect = np.ones_like(out)
193+
expect[out_rows] = X[X_rows, :].toarray()
194+
195+
assign_rows_csr(X, X_rows, out_rows, out)
196+
assert_array_equal(out, expect)
196197

197198

198199
def test_inplace_column_scale():

0 commit comments

Comments
 (0)
0