8000 Make csr_row_norms support fused types · scikit-learn/scikit-learn@65e91f7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 65e91f7

Browse files
committed
Make csr_row_norms support fused types
1 parent 417788c commit 65e91f7

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

sklearn/utils/sparsefuncs_fast.pyx

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,29 @@ ctypedef np.float64_t DOUBLE
2323

2424
def csr_row_norms(X):
2525
"""L2 norm of each row in CSR matrix X."""
26+
if X.dtype != np.float32:
27+
X = X.astype(np.float64)
28+
return _csr_row_norms(X.data, X.shape, X.indices, X.indptr)
29+
30+
31+
def _csr_row_norms(np.ndarray[floating, ndim=1, mode="c"] X_data,
32+
shape,
33+
np.ndarray[int, ndim=1, mode="c"] X_indices,
34+
np.ndarray[int, ndim=1, mode="c"] X_indptr):
2635
cdef:
27-
unsigned int n_samples = X.shape[0]
28-
unsigned int n_features = X.shape[1]
36+
unsigned int n_samples = shape[0]
37+
unsigned int n_features = shape[1]
2938
np.ndarray[DOUBLE, ndim=1, mode="c"] norms
30-
np.ndarray[DOUBLE, ndim=1, mode="c"] data
31-
np.ndarray[int, ndim=1, mode="c"] indices = X.indices
32-
np.ndarray[int, ndim=1, mode="c"] indptr = X.indptr
3339

3440
np.npy_intp i, j
3541
double sum_
3642

3743
norms = np.zeros(n_samples, dtype=np.float64)
38-
data = np.asarray(X.data, dtype=np.float64) # might copy!
3944

4045
for i in range(n_samples):
4146
sum_ = 0.0
42-
for j in range(indptr[i], indptr[i + 1]):
43-
sum_ += data[j] * data[j]
47+
for j in range(X_indptr[i], X_indptr[i + 1]):
48+
sum_ += X_data[j] * X_data[j]
4449
norms[i] = sum_
4550

4651
return norms

0 commit comments

Comments
 (0)
0