8000 MAINT Clean-up comments and rename variables in `_middle_term_sparse_… · betatim/scikit-learn@4acd91d · GitHub
[go: up one dir, main page]

Skip to content

Commit 4acd91d

Browse files
MAINT Clean-up comments and rename variables in _middle_term_sparse_sparse_{32, 64} (scikit-learn#25449)
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
1 parent 83a8774 commit 4acd91d

File tree

1 file changed

+7
-15
lines changed

1 file changed

+7
-15
lines changed

sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,6 @@ import numpy as np
3838
from scipy.sparse import issparse, csr_matrix
3939
from ...utils._typedefs import DTYPE, SPARSE_INDEX_TYPE
4040

41-
# TODO: If possible optimize this routine to efficiently treat cases where
42-
# `n_samples_X << n_samples_Y` met in practise when X_test consists of a
43-
# few samples, and thus when there's a single chunk of X whose number of
44-
# samples is less than the default chunk size.
45-
46-
# TODO: compare this routine with the similar ones in SciPy, especially
47-
# `csr_matmat` which might implement a better algorithm.
48-
# See: https://github.com/scipy/scipy/blob/e58292e066ba2cb2f3d1e0563ca9314ff1f4f311/scipy/sparse/sparsetools/csr.h#L603-L669 # noqa
4941
cdef void _middle_term_sparse_sparse_64(
5042
const DTYPE_t[:] X_data,
5143
const SPARSE_INDEX_TYPE_t[:] X_indices,
@@ -66,17 +58,17 @@ cdef void _middle_term_sparse_sparse_64(
6658
ITYPE_t i, j, k
6759
ITYPE_t n_X = X_end - X_start
6860
ITYPE_t n_Y = Y_end - Y_start
69-
ITYPE_t X_i_col_idx, X_i_ptr, Y_j_col_idx, Y_j_ptr
61+
ITYPE_t x_col, x_ptr, y_col, y_ptr
7062

7163
for i in range(n_X):
72-
for X_i_ptr in range(X_indptr[X_start+i], X_indptr[X_start+i+1]):
73-
X_i_col_idx = X_indices[X_i_ptr]
64+
for x_ptr in range(X_indptr[X_start+i], X_indptr[X_start+i+1]):
65+
x_col = X_indices[x_ptr]
7466
for j in range(n_Y):
7567
k = i * n_Y + j
76-
for Y_j_ptr in range(Y_indptr[Y_start+j], Y_indptr[Y_start+j+1]):
77-
Y_j_col_idx = Y_indices[Y_j_ptr]
78-
if X_i_col_idx == Y_j_col_idx:
79-
D[k] += -2 * X_data[X_i_ptr] * Y_data[Y_j_ptr]
68+
for y_ptr in range(Y_indptr[Y_start+j], Y_indptr[Y_start+j+1]):
69+
y_col = Y_indices[y_ptr]
70+
if x_col == y_col:
71+
D[k] += -2 * X_data[x_ptr] * Y_data[y_ptr]
8072

8173

8274
{{for name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}

0 commit comments

Comments
 (0)
0