@@ -175,8 +175,6 @@ cdef class GEMMTermComputer{{name_suffix}}:
175
175
ITYPE_t thread_num,
176
176
) nogil:
177
177
cdef:
178
- const {{INPUT_DTYPE_t}}[:, ::1] X_c = self.X[X_start:X_end, :]
179
- const {{INPUT_DTYPE_t}}[:, ::1] Y_c = self.Y[Y_start:Y_end, :]
180
178
DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data()
181
179
182
180
# Careful: LDA, LDB and LDC are given for F-ordered arrays
@@ -187,9 +185,9 @@ cdef class GEMMTermComputer{{name_suffix}}:
187
185
BLAS_Order order = RowMajor
188
186
BLAS_Trans ta = NoTrans
189
187
BLAS_Trans tb = Trans
190
- ITYPE_t m = X_c.shape[0]
191
- ITYPE_t n = Y_c.shape[0]
192
- ITYPE_t K = X_c.shape[1]
188
+ ITYPE_t m = X_end - X_start
189
+ ITYPE_t n = Y_end - Y_start
190
+ ITYPE_t K = self.n_features
193
191
DTYPE_t alpha = - 2.
194
192
{{if upcast_to_float64}}
195
193
DTYPE_t * A = self.X_c_upcast[thread_num].data()
@@ -198,15 +196,15 @@ cdef class GEMMTermComputer{{name_suffix}}:
198
196
# Casting for A and B to remove the const is needed because APIs exposed via
199
197
# scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier.
200
198
# See: https://github.com/scipy/scipy/issues/14262
201
- DTYPE_t * A = <DTYPE_t *> &X_c[0 , 0]
202
- DTYPE_t * B = <DTYPE_t *> &Y_c[0 , 0]
199
+ DTYPE_t * A = <DTYPE_t *> &self.X[X_start , 0]
200
+ DTYPE_t * B = <DTYPE_t *> &self.Y[Y_start , 0]
203
201
{{endif}}
204
- ITYPE_t lda = X_c.shape[1]
205
- ITYPE_t ldb = X_c.shape[1]
202
+ ITYPE_t lda = self.n_features
203
+ ITYPE_t ldb = self.n_features
206
204
DTYPE_t beta = 0.
207
- ITYPE_t ldc = Y_c.shape[0]
205
+ ITYPE_t ldc = Y_end - Y_start
208
206
209
- # dist_middle_terms = `-2 * X_c @ Y_c .T`
207
+ # dist_middle_terms = `-2 * X[X_start:X_end] @ Y[Y_start:Y_end] .T`
210
208
_gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc)
211
209
212
210
return dist_middle_terms
0 commit comments