@@ -622,11 +622,13 @@ cdef class GEMMTermComputer{{bitness}}:
622
622
ITYPE_t chunks_n_threads
623
623
ITYPE_t dist_middle_terms_chunks_size
624
624
ITYPE_t n_features
625
+ ITYPE_t chunk_size
625
626
626
627
# Buffers for the `-2 * X_c @ Y_c.T` term computed via GEMM
627
628
vector[vector[DTYPE_t]] dist_middle_terms_chunks
628
629
629
630
{{if need_upcast}}
631
+ # Buffers for upcasting chunks of X and Y from 32bit to 64bit
630
632
vector[vector[DTYPE_t]] X_c_upcast
631
633
vector[vector[DTYPE_t]] Y_c_upcast
632
634
{{endif}}
@@ -638,24 +640,28 @@ cdef class GEMMTermComputer{{bitness}}:
638
640
ITYPE_t chunks_n_threads,
639
641
ITYPE_t dist_middle_terms_chunks_size,
640
642
ITYPE_t n_features,
643
+ ITYPE_t chunk_size,
641
644
):
642
645
self.X = X
643
646
self.Y = Y
644
647
self.effective_n_threads = effective_n_threads
645
648
self.chunks_n_threads = chunks_n_threads
646
649
self.dist_middle_terms_chunks_size = dist_middle_terms_chunks_size
647
650
self.n_features = n_features
651
+ self.chunk_size = chunk_size
648
652
649
653
self.dist_middle_terms_chunks = vector[vector[DTYPE_t]](self.effective_n_threads)
650
654
651
655
{{if need_upcast}}
656
+ # Allocate and size the per-thread buffers used for upcasting chunks of X and Y from 32bit to 64bit.
652
657
self.X_c_upcast = vector[vector[DTYPE_t]](self.effective_n_threads)
653
658
self.Y_c_upcast = vector[vector[DTYPE_t]](self.effective_n_threads)
654
659
655
- # Buffers for upcasting chunks of X and Y from 32bit to 64bit.
660
+ upcast_buffer_n_elements = self.chunk_size * n_features
661
+
656
662
for thread_num in range(self.effective_n_threads):
657
- self.X_c_upcast[thread_num].resize(self.dist_middle_terms_chunks_size )
658
- self.Y_c_upcast[thread_num].resize(self.dist_middle_terms_chunks_size )
663
+ self.X_c_upcast[thread_num].resize(upcast_buffer_n_elements )
664
+ self.Y_c_upcast[thread_num].resize(upcast_buffer_n_elements )
659
665
{{endif}}
660
666
661
667
@@ -1556,7 +1562,8 @@ cdef class FastEuclideanPairwiseDistancesArgKmin{{bitness}}(PairwiseDistancesArg
1556
1562
self.effective_n_threads,
1557
1563
self.chunks_n_threads,
1558
1564
dist_middle_terms_chunks_size,
1559
- n_features=datasets_pair.X.shape[1]
1565
+ n_features=datasets_pair.X.shape[1],
1566
+ chunk_size=self.chunk_size,
1560
1567
)
1561
1568
1562
1569
if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs:
@@ -2171,7 +2178,8 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood{{bitness}}(PairwiseD
2171
2178
self.effective_n_threads,
2172
2179
self.chunks_n_threads,
2173
2180
dist_middle_terms_chunks_size,
2174
- n_features=datasets_pair.X.shape[1]
2181
+ n_features=datasets_pair.X.shape[1],
2182
+ chunk_size=self.chunk_size,
2175
2183
)
2176
2184
2177
2185
if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs:
0 commit comments