MAINT Correctly resize buffers for upcasting · scikit-learn/scikit-learn@f0fc839 · GitHub
[go: up one dir, main page]

Skip to content

Commit f0fc839

Browse files
committed
MAINT Correctly resize buffers for upcasting
1 parent 7b0bcd3 commit f0fc839

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

sklearn/metrics/_pairwise_distances_reduction.pyx.tp

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -622,11 +622,13 @@ cdef class GEMMTermComputer{{bitness}}:
622622
ITYPE_t chunks_n_threads
623623
ITYPE_t dist_middle_terms_chunks_size
624624
ITYPE_t n_features
625+
ITYPE_t chunk_size
625626

626627
# Buffers for the `-2 * X_c @ Y_c.T` term computed via GEMM
627628
vector[vector[DTYPE_t]] dist_middle_terms_chunks
628629

629630
{{if need_upcast}}
631+
# Buffers for upcasting chunks of X and Y from 32bit to 64bit
630632
vector[vector[DTYPE_t]] X_c_upcast
631633
vector[vector[DTYPE_t]] Y_c_upcast
632634
{{endif}}
@@ -638,24 +640,28 @@ cdef class GEMMTermComputer{{bitness}}:
638640
ITYPE_t chunks_n_threads,
639641
ITYPE_t dist_middle_terms_chunks_size,
640642
ITYPE_t n_features,
643+
ITYPE_t chunk_size,
641644
):
642645
self.X = X
643646
self.Y = Y
644647
self.effective_n_threads = effective_n_threads
645648
self.chunks_n_threads = chunks_n_threads
646649
self.dist_middle_terms_chunks_size = dist_middle_terms_chunks_size
647650
self.n_features = n_features
651+
self.chunk_size = chunk_size
648652

649653
self.dist_middle_terms_chunks = vector[vector[DTYPE_t]](self.effective_n_threads)
650654

651655
{{if need_upcast}}
656+
# We populate the buffer for upcasting chunks of X and Y from 32bit to 64bit.
652657
self.X_c_upcast = vector[vector[DTYPE_t]](self.effective_n_threads)
653658
self.Y_c_upcast = vector[vector[DTYPE_t]](self.effective_n_threads)
654659

655-
# Buffers for upcasting chunks of X and Y from 32bit to 64bit.
660+
upcast_buffer_n_elements = self.chunk_size * n_features
661+
656662
for thread_num in range(self.effective_n_threads):
657-
self.X_c_upcast[thread_num].resize(self.dist_middle_terms_chunks_size)
658-
self.Y_c_upcast[thread_num].resize(self.dist_middle_terms_chunks_size)
663+
self.X_c_upcast[thread_num].resize(upcast_buffer_n_elements)
664+
self.Y_c_upcast[thread_num].resize(upcast_buffer_n_elements)
659665
{{endif}}
660666

661667

@@ -1556,7 +1562,8 @@ cdef class FastEuclideanPairwiseDistancesArgKmin{{bitness}}(PairwiseDistancesArg
15561562
self.effective_n_threads,
15571563
self.chunks_n_threads,
15581564
dist_middle_terms_chunks_size,
1559-
n_features=datasets_pair.X.shape[1]
1565+
n_features=datasets_pair.X.shape[1],
1566+
chunk_size=self.chunk_size,
15601567
)
15611568

15621569
if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs:
@@ -2171,7 +2178,8 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood{{bitness}}(PairwiseD
21712178
self.effective_n_threads,
21722179
self.chunks_n_threads,
21732180
dist_middle_terms_chunks_size,
2174-
n_features=datasets_pair.X.shape[1]
2181+
n_features=datasets_pair.X.shape[1],
2182+
chunk_size=self.chunk_size,
21752183
)
21762184

21772185
if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs:

0 commit comments

Comments
 (0)
0