8000 MAINT `PairwiseDistancesReduction`: Do not slice memoryviews in `_com… · scikit-learn/scikit-learn@ad14f91 · GitHub
[go: up one dir, main page]

Skip to content

Commit ad14f91

Browse files
authored
MAINT PairwiseDistancesReduction: Do not slice memoryviews in _compute_dist_middle_terms (#24715)
1 parent 98024ce commit ad14f91

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -368,8 +368,6 @@ cdef class DenseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_
368368
ITYPE_t thread_num,
369369
) nogil:
370370
cdef:
371-
const {{INPUT_DTYPE_t}}[:, ::1] X_c = self.X[X_start:X_end, :]
372-
const {{INPUT_DTYPE_t}}[:, ::1] Y_c = self.Y[Y_start:Y_end, :]
373371
DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data()
374372

375373
# Careful: LDA, LDB and LDC are given for F-ordered arrays
@@ -380,9 +378,9 @@ cdef class DenseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_
380378
BLAS_Order order = RowMajor
381379
BLAS_Trans ta = NoTrans
382380
BLAS_Trans tb = Trans
383-
ITYPE_t m = X_c.shape[0]
384-
ITYPE_t n = Y_c.shape[0]
385-
ITYPE_t K = X_c.shape[1]
381+
ITYPE_t m = X_end - X_start
382+
ITYPE_t n = Y_end - Y_start
383+
ITYPE_t K = self.n_features
386384
DTYPE_t alpha = - 2.
387385
{{if upcast_to_float64}}
388386
DTYPE_t * A = self.X_c_upcast[thread_num].data()
@@ -391,15 +389,15 @@ cdef class DenseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_
391389
# Casting for A and B to remove the const is needed because APIs exposed via
392390
# scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier.
393391
# See: https://github.com/scipy/scipy/issues/14262
394-
DTYPE_t * A = <DTYPE_t *> &X_c[0, 0]
395-
DTYPE_t * B = <DTYPE_t *> &Y_c[0, 0]
392+
DTYPE_t * A = <DTYPE_t *> &self.X[X_start, 0]
393+
DTYPE_t * B = <DTYPE_t *> &self.Y[Y_start, 0]
396394
{{endif}}
397-
ITYPE_t lda = X_c.shape[1]
398-
ITYPE_t ldb = X_c.shape[1]
395+
ITYPE_t lda = self.n_features
396+
ITYPE_t ldb = self.n_features
399397
DTYPE_t beta = 0.
400-
ITYPE_t ldc = Y_c.shape[0]
398+
ITYPE_t ldc = Y_end - Y_start
401399

402-
# dist_middle_terms = `-2 * X_c @ Y_c.T`
400+
# dist_middle_terms = `-2 * X[X_start:X_end] @ Y[Y_start:Y_end].T`
403401
_gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc)
404402

405403
return dist_middle_terms

0 commit comments

Comments
 (0)
0