@@ -12,8 +12,11 @@ from libc.math cimport fabs, sqrt
12
12
cimport numpy as cnp
13
13
import numpy as np
14
14
from cython cimport floating
15
+ from cython.parallel cimport prange
15
16
from numpy.math cimport isnan
16
17
18
+ from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
19
+
17
20
cnp.import_array()
18
21
19
22
ctypedef fused integral:
@@ -27,13 +30,14 @@ def csr_row_norms(X):
27
30
""" Squared L2 norm of each row in CSR matrix X."""
28
31
if X.dtype not in [np.float32, np.float64]:
29
32
X = X.astype(np.float64)
30
- return _csr_row_norms(X.data, X.indices, X.indptr)
33
+ n_threads = _openmp_effective_n_threads()
34
+ return _sqeuclidean_row_norms_sparse(X.data, X.indptr, n_threads)
31
35
32
36
33
- def _csr_row_norms (
37
+ def _sqeuclidean_row_norms_sparse (
34
38
const floating[::1] X_data ,
35
- const integral[::1] X_indices ,
36
39
const integral[::1] X_indptr ,
40
+ int n_threads ,
37
41
):
38
42
cdef:
39
43
integral n_samples = X_indptr.shape[0 ] - 1
@@ -42,14 +46,13 @@ def _csr_row_norms(
42
46
43
47
dtype = np.float32 if floating is float else np.float64
44
48
45
- cdef floating[::1 ] norms <
A23C
span class="pl-k">= np.zeros(n_samples, dtype = dtype)
49
+ cdef floating[::1 ] squared_row_norms = np.zeros(n_samples, dtype = dtype)
46
50
47
- with nogil:
48
- for i in range (n_samples):
49
- for j in range (X_indptr[i], X_indptr[i + 1 ]):
50
- norms[i] += X_data[j] * X_data[j]
51
+ for i in prange(n_samples, schedule = ' static' , nogil = True , num_threads = n_threads):
52
+ for j in range (X_indptr[i], X_indptr[i + 1 ]):
53
+ squared_row_norms[i] += X_data[j] * X_data[j]
51
54
52
- return np.asarray(norms )
55
+ return np.asarray(squared_row_norms )
53
56
54
57
55
58
def csr_mean_variance_axis0 (X , weights = None , return_sum_weights = False ):
0 commit comments