8000 MAINT Remove redundant sparse square euclidian distances function (#2… · scikit-learn/scikit-learn@ef5c087 · GitHub
[go: up one dir, main page]

Skip to content

Commit ef5c087

Browse files
authored
MAINT Remove redundant sparse square euclidian distances function (#25731)
1 parent 41b4c63 commit ef5c087

File tree

1 file changed

+3
-19
lines changed

1 file changed

+3
-19
lines changed

sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ from sklearn import get_config
3232
from sklearn.utils import check_scalar
3333
from ...utils._openmp_helpers import _openmp_effective_n_threads
3434
from ...utils._typedefs import DTYPE, SPARSE_INDEX_TYPE
35+
from ...utils.sparsefuncs_fast import _sqeuclidean_row_norms_sparse
3536

3637
cnp.import_array()
3738

@@ -103,23 +104,6 @@ cdef DTYPE_t[::1] _sqeuclidean_row_norms32_dense(
103104
return squared_row_norms
104105

105106

106-
cdef DTYPE_t[::1] _sqeuclidean_row_norms64_sparse(
107-
const DTYPE_t[:] X_data,
108-
const SPARSE_INDEX_TYPE_t[:] X_indptr,
109-
ITYPE_t num_threads,
110-
):
111-
cdef:
112-
ITYPE_t n = X_indptr.shape[0] - 1
113-
SPARSE_INDEX_TYPE_t X_i_ptr, idx = 0
114-
DTYPE_t[::1] squared_row_norms = np.zeros(n, dtype=DTYPE)
115-
116-
for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads):
117-
for X_i_ptr in range(X_indptr[idx], X_indptr[idx+1]):
118-
squared_row_norms[idx] += X_data[X_i_ptr] * X_data[X_i_ptr]
119-
120-
return squared_row_norms
121-
122-
123107
{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}
124108

125109
from ._datasets_pair cimport DatasetsPair{{name_suffix}}
@@ -131,10 +115,10 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms{{name_suffix}}(
131115
):
132116
if issparse(X):
133117
# TODO: remove this instruction which is a cast in the float32 case
134-
# by moving squared row norms computations in MiddleTermComputer.
118+
# by moving squared row norms computations in MiddleTermComputer.
135119
X_data = np.asarray(X.data, dtype=DTYPE)
136120
X_indptr = np.asarray(X.indptr, dtype=SPARSE_INDEX_TYPE)
137-
return _sqeuclidean_row_norms64_sparse(X_data, X_indptr, num_threads)
121+
return _sqeuclidean_row_norms_sparse(X_data, X_indptr, num_threads)
138122
else:
139123
return _sqeuclidean_row_norms{{name_suffix}}_dense(X, num_threads)
140124

0 commit comments

Comments
 (0)
0