8000 PERF revert openmp use in csr_row_norms (#26275) · scikit-learn/scikit-learn@092caed · GitHub
[go: up one dir, main page]

Skip to content

Commit 092caed

Browse files
authored
PERF revert openmp use in csr_row_norms (#26275)
1 parent 523c135 commit 092caed

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp

+19-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ from libcpp.vector cimport vector
55

66
from ...utils._cython_blas cimport _dot
77
from ...utils._openmp_helpers cimport omp_get_thread_num
8-
from ...utils._typedefs cimport intp_t, float32_t, float64_t
8+
from ...utils._typedefs cimport intp_t, float32_t, float64_t, int32_t
99

1010
import numpy as np
1111

@@ -14,7 +14,6 @@ from numbers import Integral
1414
from sklearn import get_config
1515
from sklearn.utils import check_scalar
1616
from ...utils._openmp_helpers import _openmp_effective_n_threads
17-
from ...utils.sparsefuncs_fast import _sqeuclidean_row_norms_sparse
1817

1918
#####################
2019

@@ -84,6 +83,23 @@ cdef float64_t[::1] _sqeuclidean_row_norms32_dense(
8483
return squared_row_norms
8584

8685

86+
cdef float64_t[::1] _sqeuclidean_row_norms64_sparse(
87+
const float64_t[:] X_data,
88+
const int32_t[:] X_indptr,
89+
intp_t num_threads,
90+
):
91+
cdef:
92+
intp_t n = X_indptr.shape[0] - 1
93+
int32_t X_i_ptr, idx = 0
94+
float64_t[::1] squared_row_norms = np.zeros(n, dtype=np.float64)
95+
96+
for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads):
97+
for X_i_ptr in range(X_indptr[idx], X_indptr[idx+1]):
98+
squared_row_norms[idx] += X_data[X_i_ptr] * X_data[X_i_ptr]
99+
100+
return squared_row_norms
101+
102+
87103
{{for name_suffix in ["64", "32"]}}
88104

89105
from ._datasets_pair cimport DatasetsPair{{name_suffix}}
@@ -98,7 +114,7 @@ cpdef float64_t[::1] _sqeuclidean_row_norms{{name_suffix}}(
98114
# by moving squared row norms computations in MiddleTermComputer.
99115
X_data = np.asarray(X.data, dtype=np.float64)
100116
X_indptr = np.asarray(X.indptr, dtype=np.int32)
101-
return _sqeuclidean_row_norms_sparse(X_data, X_indptr, num_threads)
117+
return _sqeuclidean_row_norms64_sparse(X_data, X_indptr, num_threads)
102118
else:
103119
return _sqeuclidean_row_norms{{name_suffix}}_dense(X, num_threads)
104120

sklearn/utils/sparsefuncs_fast.pyx

+5-9
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@ from libc.math cimport fabs, sqrt, isnan
1111
cimport numpy as cnp
1212
import numpy as np
1313
from cython cimport floating
14-
from cython.parallel cimport prange
15-
16-
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
1714

1815
cnp.import_array()
1916

@@ -28,14 +25,12 @@ def csr_row_norms(X):
2825
"""Squared L2 norm of each row in CSR matrix X."""
2926
if X.dtype not in [np.float32, np.float64]:
3027
X = X.astype(np.float64)
31-
n_threads = _openmp_effective_n_threads()
32-
return _sqeuclidean_row_norms_sparse(X.data, X.indptr, n_threads)
28+
return _sqeuclidean_row_norms_sparse(X.data, X.indptr)
3329

3430

3531
def _sqeuclidean_row_norms_sparse(
3632
const floating[::1] X_data,
3733
const integral[::1] X_indptr,
38-
int n_threads,
3934
):
4035
cdef:
4136
integral n_samples = X_indptr.shape[0] - 1
@@ -45,9 +40,10 @@ def _sqeuclidean_row_norms_sparse(
4540

4641
cdef floating[::1] squared_row_norms = np.zeros(n_samples, dtype=dtype)
4742

48-
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
49-
for j in range(X_indptr[i], X_indptr[i + 1]):
50-
squared_row_norms[i] += X_data[j] * X_data[j]
43+
with nogil:
44+
for i in range(n_samples):
45+
for j in range(X_indptr[i], X_indptr[i + 1]):
46+
squared_row_norms[i] += X_data[j] * X_data[j]
5147

5248
return np.asarray(squared_row_norms)
5349

0 commit comments

Comments
 (0)
0