8000 MAINT Group vector fixtures together · scikit-learn/scikit-learn@c4d8c4a · GitHub
[go: up one dir, main page]

Skip to content

Commit c4d8c4a

Browse files
committed
MAINT Group vector fixtures together
1 parent 390f624 commit c4d8c4a

File tree

1 file changed

+25
-24
lines changed

1 file changed

+25
-24
lines changed

sklearn/metrics/_pairwise_distances_reduction.pyx

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -108,30 +108,6 @@ cdef class StdVectorSentinelITYPE(StdVectorSentinel):
108108
return sentinel
109109

110110

111-
cpdef DTYPE_t[::1] _sqeuclidean_row_norms(
112-
const DTYPE_t[:, ::1] X,
113-
ITYPE_t num_threads,
114-
):
115-
"""Compute the squared euclidean norm of the rows of X in parallel.
116-
117-
This is faster than using np.einsum("ij, ij->i") even when using a single thread.
118 10000 -
"""
119-
cdef:
120-
# Casting for X to remove the const qualifier is needed because APIs
121-
# exposed via scipy.linalg.cython_blas aren't reflecting the arguments'
122-
# const qualifier.
123-
# See: https://github.com/scipy/scipy/issues/14262
124-
DTYPE_t * X_ptr = <DTYPE_t *> &X[0, 0]
125-
ITYPE_t idx = 0
126-
ITYPE_t n = X.shape[0]
127-
ITYPE_t d = X.shape[1]
128-
DTYPE_t[::1] squared_row_norms = np.empty(n, dtype=DTYPE)
129-
130-
for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads):
131-
squared_row_norms[idx] = _dot(d, X_ptr + idx * d, 1, X_ptr + idx * d, 1)
132-
133-
return squared_row_norms
134-
135111
cdef np.ndarray vector_to_nd_array(vector_DITYPE_t * vect_ptr):
136112
"""Create a numpy ndarray given a C++ vector.
137113
@@ -176,6 +152,31 @@ cdef np.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays(
176152

177153
#####################
178154

155+
cpdef DTYPE_t[::1] _sqeuclidean_row_norms(
156+
const DTYPE_t[:, ::1] X,
157+
ITYPE_t num_threads,
158+
):
159+
"""Compute the squared euclidean norm of the rows of X in parallel.
160+
161+
This is faster than using np.einsum("ij, ij->i") even when using a single thread.
162+
"""
163+
cdef:
164+
# Casting for X to remove the const qualifier is needed because APIs
165+
# exposed via scipy.linalg.cython_blas aren't reflecting the arguments'
166+
# const qualifier.
167+
# See: https://github.com/scipy/scipy/issues/14262
168+
DTYPE_t * X_ptr = <DTYPE_t *> &X[0, 0]
169+
ITYPE_t idx = 0
170+
ITYPE_t n = X.shape[0]
171+
ITYPE_t d = X.shape[1]
172+
DTYPE_t[::1] squared_row_norms = np.empty(n, dtype=DTYPE)
173+
174+
for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads):
175+
squared_row_norms[idx] = _dot(d, X_ptr + idx * d, 1, X_ptr + idx * d, 1)
176+
177+
return squared_row_norms
178+
179+
#####################
179180

180181
cdef class PairwiseDistancesReduction:
181182
"""Abstract base class for pairwise distance computation & reduction.

0 commit comments

Comments
 (0)
0