8000 Revert "Do not slice memoryviews in _compute_dist_middle_terms" · scikit-learn/scikit-learn@379ffae · GitHub
[go: up one dir, main page]

Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 379ffae

Browse files
committed
Revert "Do not slice memoryviews in _compute_dist_middle_terms"
This reverts commit f2e917b.
1 parent f2e917b commit 379ffae

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx.tp

Lines changed: 11 additions & 9 deletions
< 8000 /tr>
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ cdef class GEMMTermComputer{{name_suffix}}:
175175
ITYPE_t thread_num,
176176
) nogil:
177177
cdef:
178+
const {{INPUT_DTYPE_t}}[:, ::1] X_c = self.X[X_start:X_end, :]
179+
const {{INPUT_DTYPE_t}}[:, ::1] Y_c = self.Y[Y_start:Y_end, :]
178180
DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data()
179181

180182
# Careful: LDA, LDB and LDC are given for F-ordered arrays
@@ -185,9 +187,9 @@ cdef class GEMMTermComputer{{name_suffix}}:
185187
BLAS_Order order = RowMajor
186188
BLAS_Trans ta = NoTrans
187189
BLAS_Trans tb = Trans
188-
ITYPE_t m = X_end - X_start
189-
ITYPE_t n = Y_end - Y_start
190-
ITYPE_t K = self.n_features
190+
ITYPE_t m = X_c.shape[0]
191+
ITYPE_t n = Y_c.shape[0]
192+
ITYPE_t K = X_c.shape[1]
191193
DTYPE_t alpha = - 2.
192194
{{if upcast_to_float64}}
193195
DTYPE_t * A = self.X_c_upcast[thread_num].data()
@@ -196,15 +198,15 @@ cdef class GEMMTermComputer{{name_suffix}}:
196198
# Casting for A and B to remove the const is needed because APIs exposed via
197199
# scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier.
198200
# See: https://github.com/scipy/scipy/issues/14262
199-
DTYPE_t * A = <DTYPE_t *> &self.X[X_start, 0]
200-
DTYPE_t * B = <DTYPE_t *> &self.Y[Y_start, 0]
201+
DTYPE_t * A = <DTYPE_t *> &X_c[0, 0]
202+
DTYPE_t * B = <DTYPE_t *> &Y_c[0, 0]
201203
{{endif}}
202-
ITYPE_t lda = self.n_features
203-
ITYPE_t ldb = self.n_features
204+
ITYPE_t lda = X_c.shape[1]
205+
ITYPE_t ldb = X_c.shape[1]
204206
DTYPE_t beta = 0.
205-
ITYPE_t ldc = Y_end - Y_start
207+
ITYPE_t ldc = Y_c.shape[0]
206208

207-
# dist_middle_terms = `-2 * X[X_start:X_end] @ Y[Y_start:Y_end].T`
209+
# dist_middle_terms = `-2 * X_c @ Y_c.T`
208210
_gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc)
209211

210212
return dist_middle_terms

0 commit comments

Comments
 (0)
0