@@ -32,6 +32,7 @@ from sklearn import get_config
32
32
from sklearn.utils import check_scalar
33
33
from ...utils._openmp_helpers import _openmp_effective_n_threads
34
34
from ...utils._typedefs import DTYPE, SPARSE_INDEX_TYPE
35
+ from ...utils.sparsefuncs_fast import _sqeuclidean_row_norms_sparse
35
36
36
37
cnp.import_array()
37
38
@@ -103,23 +104,6 @@ cdef DTYPE_t[::1] _sqeuclidean_row_norms32_dense(
103
104
return squared_row_norms
104
105
105
106
106
- cdef DTYPE_t[::1] _sqeuclidean_row_norms64_sparse(
107
- const DTYPE_t[:] X_data,
108
- const SPARSE_INDEX_TYPE_t[:] X_indptr,
109
- ITYPE_t num_threads,
110
- ):
111
- cdef:
112
- ITYPE_t n = X_indptr.shape[0] - 1
113
- SPARSE_INDEX_TYPE_t X_i_ptr, idx = 0
114
- DTYPE_t[::1] squared_row_norms = np.zeros(n, dtype=DTYPE)
115
-
116
- for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads):
117
- for X_i_ptr in range(X_indptr[idx], X_indptr[idx+1]):
118
- squared_row_norms[idx] += X_data[X_i_ptr] * X_data[X_i_ptr]
119
-
120
- return squared_row_norms
121
-
122
-
123
107
{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}
124
108
125
109
from ._datasets_pair cimport DatasetsPair{{name_suffix}}
@@ -131,10 +115,10 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms{{name_suffix}}(
131
115
):
132
116
if issparse(X):
133
117
# TODO: remove this instruction which is a cast in the float32 case
134
- # by moving squared row norms computations in MiddleTermComputer.
118
+ # by moving squared row norms computations in MiddleTermComputer.
135
119
X_data = np.asarray(X.data, dtype=DTYPE)
136
120
X_indptr = np.asarray(X.indptr, dtype=SPARSE_INDEX_TYPE)
137
- return _sqeuclidean_row_norms64_sparse(X_data, X_indptr, num_threads)
121
+ return _sqeuclidean_row_norms_sparse (X_data, X_indptr, num_threads)
138
122
else:
139
123
return _sqeuclidean_row_norms{{name_suffix}}_dense(X, num_threads)
140
124
0 commit comments