@@ -605,13 +605,26 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
605
605
# For future work, this might be an entrypoint to specialise operations
606
606
# for various back-end and/or hardware and/or datatypes, and/or fused
607
607
# {sparse, dense}-datasetspair etc.
608
-
609
- pda = PairwiseDistancesArgKmin(
610
- datasets_pair = DatasetsPair.get_for(X, Y, metric, metric_kwargs),
611
- k = k,
612
- chunk_size = chunk_size,
613
- strategy = strategy,
614
- )
608
+ if (
609
+ metric in (" euclidean" , " sqeuclidean" )
610
+ and not issparse(X)
611
+ and not issparse(Y)
612
+ ):
613
+ use_squared_distances = metric == " sqeuclidean"
614
+ pda = FastEuclideanPairwiseDistancesArgKmin(
615
+ X = X, Y = Y, k = k,
616
+ use_squared_distances = use_squared_distances,
617
+ chunk_size = chunk_size,
618
+ strategy = strategy,
619
+ metric_kwargs = metric_kwargs,
620
+ )
621
+ else : # Fall back on the default
622
+ pda = PairwiseDistancesArgKmin(
623
+ datasets_pair = DatasetsPair.get_for(X, Y, metric, metric_kwargs),
624
+ k = k,
625
+ chunk_size = chunk_size,
626
+ strategy = strategy,
627
+ )
615
628
616
629
# Limit the number of threads in second level of nested parallelism for BLAS
617
630
# to avoid threads over-subscription (in GEMM for instance).
@@ -819,3 +832,200 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
819
832
return np.asarray(self .argkmin_distances), np.asarray(self .argkmin_indices)
820
833
821
834
return np.asarray(self .argkmin_indices)
835
+
836
+
837
cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin):
    """Fast specialized alternative for PairwiseDistancesArgKmin on EuclideanDistance.

    The full pairwise squared distances matrix is computed as follows:

                  ||X - Y||² = ||X||² - 2 X.Y^T + ||Y||²

    The middle term gets computed efficiently below using BLAS Level 3 GEMM.

    Notes
    -----
    This implementation has a superior arithmetic intensity and hence
    better running time when the alternative is IO bound, but it can suffer
    from numerical instability caused by catastrophic cancellation potentially
    introduced by the subtraction in the arithmetic expression above.

    PairwiseDistancesArgKmin with EuclideanDistance must be used when higher
    numerical precision is needed.
    """

    cdef:
        # Dense, C-contiguous views on the two datasets and on their
        # squared row norms.
        const DTYPE_t[:, ::1] X
        const DTYPE_t[:, ::1] Y
        const DTYPE_t[::1] X_norm_squared
        const DTYPE_t[::1] Y_norm_squared

        # Buffers for GEMM: one `-2 * X_c @ Y_c.T` chunk per thread
        DTYPE_t ** dist_middle_terms_chunks
        # When True, the squared distances are returned as-is (see
        # `compute_exact_distances`).
        bint use_squared_distances

    @classmethod
    def is_usable_for(cls, X, Y, metric) -> bool:
        """Return True if this specialisation can be used for (X, Y, metric).

        This requires the generic PairwiseDistancesArgKmin to be usable and
        the OpenBLAS configuration not to be reported as unstable.
        """
        return (PairwiseDistancesArgKmin.is_usable_for(X, Y, metric) and
                not _in_unstable_openblas_configuration())

    def __init__(
        self,
        X,
        Y,
        ITYPE_t k,
        bint use_squared_distances=False,
        chunk_size=None,
        strategy=None,
        metric_kwargs=None,
    ):
        """Set up the datasets, their squared row norms and the GEMM buffers.

        Parameters
        ----------
        X, Y
            The two datasets; they are wrapped in a DatasetsPair built for
            the "euclidean" metric.
        k : ITYPE_t
            Number of nearest neighbours, forwarded to the parent class.
        use_squared_distances : bint, default=False
            If True, the squared distances are kept as the final distances.
        chunk_size, strategy
            Forwarded unchanged to the parent class.
        metric_kwargs : dict or None, default=None
            Only the "Y_norm_squared" entry is used (precomputed squared row
            norms of Y). Any other entry is reported with a UserWarning and
            ignored. The dict is not modified.
        """
        cdef DenseDenseDatasetsPair datasets_pair

        # Split the keyword arguments between the one entry that is used
        # here and the ones that are not. The caller's dict is left intact.
        Y_norm_squared = None
        ignored_kwargs = {}
        if metric_kwargs is not None:
            Y_norm_squared = metric_kwargs.get("Y_norm_squared")
            ignored_kwargs = {
                name: value
                for name, value in metric_kwargs.items()
                if name != "Y_norm_squared"
            }

        # Only warn about entries that are really discarded.
        if ignored_kwargs:
            warnings.warn(
                f"Some metric_kwargs have been passed ({ignored_kwargs}) but aren't"
                f" usable for this case ({self.__class__.__name__}) and will be ignored.",
                UserWarning,
                stacklevel=3,
            )

        super().__init__(
            # The datasets pair here is used for exact distances computations
            datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"),
            k=k,
            chunk_size=chunk_size,
            # The strategy requested by the caller must reach the parent,
            # otherwise it is silently replaced by the parent's default.
            strategy=strategy,
        )
        # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair
        datasets_pair = <DenseDenseDatasetsPair> self.datasets_pair
        self.X, self.Y = datasets_pair.X, datasets_pair.Y

        # Use the provided squared norms of Y when there are some, and
        # compute them otherwise (including when None was passed explicitly).
        if Y_norm_squared is not None:
            self.Y_norm_squared = Y_norm_squared
        else:
            self.Y_norm_squared = _sqeuclidean_row_norms(self.Y, self.effective_n_threads)

        # Do not recompute norms if datasets are identical.
        self.X_norm_squared = (
            self.Y_norm_squared if X is Y else
            _sqeuclidean_row_norms(self.X, self.effective_n_threads)
        )
        self.use_squared_distances = use_squared_distances

        # Temporary datastructures used in threads
        self.dist_middle_terms_chunks = <DTYPE_t **> malloc(
            sizeof(DTYPE_t *) * self.chunks_n_threads
        )

    def __dealloc__(self):
        # Only the array of per-thread pointers is released here; the
        # per-thread buffers are released by the `*_finalize` methods.
        if self.dist_middle_terms_chunks is not NULL:
            free(self.dist_middle_terms_chunks)

    @final
    cdef void compute_exact_distances(self) nogil:
        """Delegate to the parent unless squared distances were requested."""
        if not self.use_squared_distances:
            PairwiseDistancesArgKmin.compute_exact_distances(self)

    @final
    cdef void _parallel_on_X_parallel_init(
        self,
        ITYPE_t thread_num,
    ) nogil:
        """Allocate the GEMM buffer of thread `thread_num`."""
        PairwiseDistancesArgKmin._parallel_on_X_parallel_init(self, thread_num)

        # Temporary buffer for the `-2 * X_c @ Y_c.T` term
        self.dist_middle_terms_chunks[thread_num] = <DTYPE_t *> malloc(
            self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t)
        )

    @final
    cdef void _parallel_on_X_parallel_finalize(
        self,
        ITYPE_t thread_num
    ) nogil:
        """Release the GEMM buffer of thread `thread_num`."""
        PairwiseDistancesArgKmin._parallel_on_X_parallel_finalize(self, thread_num)
        free(self.dist_middle_terms_chunks[thread_num])

    @final
    cdef void _parallel_on_Y_parallel_init(
        self,
    ) nogil:
        """Allocate one GEMM buffer per thread."""
        cdef ITYPE_t thread_num
        PairwiseDistancesArgKmin._parallel_on_Y_parallel_init(self)

        for thread_num in range(self.chunks_n_threads):
            # Temporary buffer for the `-2 * X_c @ Y_c.T` term
            self.dist_middle_terms_chunks[thread_num] = <DTYPE_t *> malloc(
                self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t)
            )

    @final
    cdef void _parallel_on_Y_finalize(
        self,
    ) nogil:
        """Release the GEMM buffers of all threads."""
        cdef ITYPE_t thread_num
        PairwiseDistancesArgKmin._parallel_on_Y_finalize(self)

        for thread_num in range(self.chunks_n_threads):
            free(self.dist_middle_terms_chunks[thread_num])

    @final
    cdef void _compute_and_reduce_distances_on_chunks(
        self,
        ITYPE_t X_start,
        ITYPE_t X_end,
        ITYPE_t Y_start,
        ITYPE_t Y_end,
        ITYPE_t thread_num,
    ) nogil:
        """Push the squared distances of one (X chunk, Y chunk) pair on the heaps.

        The rows X[X_start:X_end] and Y[Y_start:Y_end] are combined with a
        single GEMM call writing in the buffer of thread `thread_num`, then
        each squared distance is pushed on the heap of its row of X.
        """
        cdef:
            ITYPE_t i, j

            const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :]
            const DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :]
            DTYPE_t * dist_middle_terms = self.dist_middle_terms_chunks[thread_num]
            DTYPE_t * heaps_r_distances = self.heaps_r_distances_chunks[thread_num]
            ITYPE_t * heaps_indices = self.heaps_indices_chunks[thread_num]

            # Careful: LDA, LDB and LDC are given for F-ordered arrays
            # in BLAS documentations, for instance:
            # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html #noqa
            #
            # Here, we use their counterpart values to work with C-ordered arrays.
            BLAS_Order order = RowMajor
            BLAS_Trans ta = NoTrans
            BLAS_Trans tb = Trans
            ITYPE_t m = X_c.shape[0]
            ITYPE_t n = Y_c.shape[0]
            ITYPE_t K = X_c.shape[1]
            DTYPE_t alpha = - 2.
            # Casting for A and B to remove the const is needed because APIs exposed via
            # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier.
            DTYPE_t * A = <DTYPE_t*> & X_c[0, 0]
            ITYPE_t lda = X_c.shape[1]
            DTYPE_t * B = <DTYPE_t*> & Y_c[0, 0]
            # Row stride of B, hence taken from Y_c (same value as for X_c
            # because both datasets have the same number of features).
            ITYPE_t ldb = Y_c.shape[1]
            DTYPE_t beta = 0.
            DTYPE_t * C = dist_middle_terms
            ITYPE_t ldc = Y_c.shape[0]

        # dist_middle_terms = `-2 * X_c @ Y_c.T`
        _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, C, ldc)

        # Pushing the distances and their associated indices on heaps
        # which keep track of the argkmin.
        for i in range(m):
            for j in range(n):
                heap_push(
                    heaps_r_distances + i * self.k,
                    heaps_indices + i * self.k,
                    self.k,
                    # Using the squared euclidean distance as the rank-preserving distance:
                    #
                    #             ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||²
                    #
                    (
                        self.X_norm_squared[i + X_start] +
                        dist_middle_terms[i * n + j] +
                        self.Y_norm_squared[j + Y_start]
                    ),
                    j + Y_start,
                )
0 commit comments