ENH Let csr_row_norms support multi-thread (#25598) · scikit-learn/scikit-learn@de67a44 · GitHub
[go: up one dir, main page]

Skip to content

Commit de67a44

Browse files
ArturoAmorQ, jeremiedbb, Vincent-Maladiere
authored
ENH Let csr_row_norms support multi-thread (#25598)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com> Co-authored-by: Vincent M <maladiere.vincent@yahoo.fr>
1 parent ae4a1b1 commit de67a44

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

sklearn/utils/sparsefuncs_fast.pyx

+12-9
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@ from libc.math cimport fabs, sqrt
1212
cimport numpy as cnp
1313
import numpy as np
1414
from cython cimport floating
15+
from cython.parallel cimport prange
1516
from numpy.math cimport isnan
1617

18+
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
19+
1720
cnp.import_array()
1821

1922
ctypedef fused integral:
@@ -27,13 +30,14 @@ def csr_row_norms(X):
2730
"""Squared L2 norm of each row in CSR matrix X."""
2831
if X.dtype not in [np.float32, np.float64]:
2932
X = X.astype(np.float64)
30-
return _csr_row_norms(X.data, X.indices, X.indptr)
33+
n_threads = _openmp_effective_n_threads()
34+
return _sqeuclidean_row_norms_sparse(X.data, X.indptr, n_threads)
3135

3236

33-
def _csr_row_norms(
37+
def _sqeuclidean_row_norms_sparse(
3438
const floating[::1] X_data,
35-
const integral[::1] X_indices,
3639
const integral[::1] X_indptr,
40+
int n_threads,
3741
):
3842
cdef:
3943
integral n_samples = X_indptr.shape[0] - 1
@@ -42,14 +46,13 @@ def _csr_row_norms(
4246

4347
dtype = np.float32 if floating is float else np.float64
4448

45-
cdef floating[::1] norms = np.zeros(n_samples, dtype=dtype)
49+
cdef floating[::1] squared_row_norms = np.zeros(n_samples, dtype=dtype)
4650

47-
with nogil:
48-
for i in range(n_samples):
49-
for j in range(X_indptr[i], X_indptr[i + 1]):
50-
norms[i] += X_data[j] * X_data[j]
51+
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
52+
for j in range(X_indptr[i], X_indptr[i + 1]):
53+
squared_row_norms[i] += X_data[j] * X_data[j]
5154

52-
return np.asarray(norms)
55+
return np.asarray(squared_row_norms)
5356

5457

5558
def csr_mean_variance_axis0(X, weights=None, return_sum_weights=False):

0 commit comments

Comments
 (0)
0