diff --git a/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp index 6a4d879667d3a..bab9952e22e2d 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp @@ -32,6 +32,7 @@ from sklearn import get_config from sklearn.utils import check_scalar from ...utils._openmp_helpers import _openmp_effective_n_threads from ...utils._typedefs import DTYPE, SPARSE_INDEX_TYPE +from ...utils.sparsefuncs_fast import _sqeuclidean_row_norms_sparse cnp.import_array() @@ -103,23 +104,6 @@ cdef DTYPE_t[::1] _sqeuclidean_row_norms32_dense( return squared_row_norms -cdef DTYPE_t[::1] _sqeuclidean_row_norms64_sparse( - const DTYPE_t[:] X_data, - const SPARSE_INDEX_TYPE_t[:] X_indptr, - ITYPE_t num_threads, -): - cdef: - ITYPE_t n = X_indptr.shape[0] - 1 - SPARSE_INDEX_TYPE_t X_i_ptr, idx = 0 - DTYPE_t[::1] squared_row_norms = np.zeros(n, dtype=DTYPE) - - for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads): - for X_i_ptr in range(X_indptr[idx], X_indptr[idx+1]): - squared_row_norms[idx] += X_data[X_i_ptr] * X_data[X_i_ptr] - - return squared_row_norms - - {{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} from ._datasets_pair cimport DatasetsPair{{name_suffix}} @@ -131,10 +115,10 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms{{name_suffix}}( ): if issparse(X): # TODO: remove this instruction which is a cast in the float32 case - # by moving squared row norms computations in MiddleTermComputer. + # by moving squared row norms computations in MiddleTermComputer. X_data = np.asarray(X.data, dtype=DTYPE) X_indptr = np.asarray(X.indptr, dtype=SPARSE_INDEX_TYPE) - return _sqeuclidean_row_norms64_sparse(X_data, X_indptr, num_threads) + return _sqeuclidean_row_norms_sparse(X_data, X_indptr, num_threads) else: return _sqeuclidean_row_norms{{name_suffix}}_dense(X, num_threads)