8000 Do not slice memoryviews in _compute_dist_middle_terms · scikit-learn/scikit-learn@f2e917b · GitHub
[go: up one dir, main page]

Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit f2e917b

Browse files
committed
Do not slice memoryviews in _compute_dist_middle_terms
See the reasons here: #17299
1 parent 8ddef01 commit f2e917b

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx.tp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,6 @@ cdef class GEMMTermComputer{{name_suffix}}:
175175
ITYPE_t thread_num,
176176
) nogil:
177177
cdef:
178-
const {{INPUT_DTYPE_t}}[:, ::1] X_c = self.X[X_start:X_end, :]
179-
const {{INPUT_DTYPE_t}}[:, ::1] Y_c = self.Y[Y_start:Y_end, :]
180178
DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data()
181179

182180
# Careful: LDA, LDB and LDC are given for F-ordered arrays
@@ -187,9 +185,9 @@ cdef class GEMMTermComputer{{name_suffix}}:
187185
BLAS_Order order = RowMajor
188186
BLAS_Trans ta = NoTrans
189187
BLAS_Trans tb = Trans
190-
ITYPE_t m = X_c.shape[0]
191-
ITYPE_t n = Y_c.shape[0]
192-
ITYPE_t K = X_c.shape[1]
188+
ITYPE_t m = X_end - X_start
189+
ITYPE_t n = Y_end - Y_start
190+
ITYPE_t K = self.n_features
193191
DTYPE_t alpha = - 2.
194192
{{if upcast_to_float64}}
195193
DTYPE_t * A = self.X_c_upcast[thread_num].data()
@@ -198,15 +196,15 @@ cdef class GEMMTermComputer{{name_suffix}}:
198196
# Casting for A and B to remove the const is needed because APIs exposed via
199197
# scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier.
200198
# See: https://github.com/scipy/scipy/issues/14262
201-
DTYPE_t * A = <DTYPE_t *> &X_c[0, 0]
202-
DTYPE_t * B = <DTYPE_t *> &Y_c[0, 0]
199+
DTYPE_t * A = <DTYPE_t *> &self.X[X_start, 0]
200+
DTYPE_t * B = <DTYPE_t *> &self.Y[Y_start, 0]
203201
{{endif}}
204-
ITYPE_t lda = X_c.shape[1]
205-
ITYPE_t ldb = X_c.shape[1]
202+
ITYPE_t lda = self.n_features
203+
ITYPE_t ldb = self.n_features
206204
DTYPE_t beta = 0.
207-
ITYPE_t ldc = Y_c.shape[0]
205+
ITYPE_t ldc = Y_end - Y_start
208206

209-
# dist_middle_terms = `-2 * X_c @ Y_c.T`
207+
# dist_middle_terms = `-2 * X[X_start:X_end] @ Y[Y_start:Y_end].T`
210208
_gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc)
211209

212210
return dist_middle_terms

0 commit comments

Comments
 (0)
0