MAINT Introduce FastEuclideanPairwiseArgKmin · scikit-learn/scikit-learn@e897695 · GitHub

Commit e897695

MAINT Introduce FastEuclideanPairwiseArgKmin
1 parent de166e0 commit e897695

2 files changed: +259, -7 lines

sklearn/metrics/_pairwise_distances_reduction.pyx

Lines changed: 217 additions & 7 deletions
@@ -605,13 +605,26 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
         # For future work, this might be an entrypoint to specialise operations
         # for various back-end and/or hardware and/or datatypes, and/or fused
         # {sparse, dense}-datasetspair etc.
-
-        pda = PairwiseDistancesArgKmin(
-            datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs),
-            k=k,
-            chunk_size=chunk_size,
-            strategy=strategy,
-        )
+        if (
+            metric in ("euclidean", "sqeuclidean")
+            and not issparse(X)
+            and not issparse(Y)
+        ):
+            use_squared_distances = metric == "sqeuclidean"
+            pda = FastEuclideanPairwiseDistancesArgKmin(
+                X=X, Y=Y, k=k,
+                use_squared_distances=use_squared_distances,
+                chunk_size=chunk_size,
+                strategy=strategy,
+                metric_kwargs=metric_kwargs,
+            )
+        else:  # Fall back on the default
+            pda = PairwiseDistancesArgKmin(
+                datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs),
+                k=k,
+                chunk_size=chunk_size,
+                strategy=strategy,
+            )

         # Limit the number of threads in second level of nested parallelism for BLAS
         # to avoid threads over-subscription (in GEMM for instance).
@@ -819,3 +832,200 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
             return np.asarray(self.argkmin_distances), np.asarray(self.argkmin_indices)

         return np.asarray(self.argkmin_indices)
+
+
+cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin):
+    """Fast specialized alternative for PairwiseDistancesArgKmin on EuclideanDistance.
+
+    The full pairwise squared distances matrix is computed as follows:
+
+        ||X - Y||² = ||X||² - 2 X.Y^T + ||Y||²
+
+    The middle term gets computed efficiently below using BLAS Level 3 GEMM.
+
+    Notes
+    -----
+    This implementation has a superior arithmetic intensity and hence
+    better running time when the alternative is IO bound, but it can suffer
+    from numerical instability caused by catastrophic cancellation potentially
+    introduced by the subtraction in the arithmetic expression above.
+
+    PairwiseDistancesArgKmin with EuclideanDistance must be used when higher
+    numerical precision is needed.
+    """
+
+    cdef:
+        const DTYPE_t[:, ::1] X
+        const DTYPE_t[:, ::1] Y
+        const DTYPE_t[::1] X_norm_squared
+        const DTYPE_t[::1] Y_norm_squared
+
+        # Buffers for GEMM
+        DTYPE_t ** dist_middle_terms_chunks
+        bint use_squared_distances
+
+    @classmethod
+    def is_usable_for(cls, X, Y, metric) -> bool:
+        return (PairwiseDistancesArgKmin.is_usable_for(X, Y, metric) and
+                not _in_unstable_openblas_configuration())
+
+    def __init__(
+        self,
+        X,
+        Y,
+        ITYPE_t k,
+        bint use_squared_distances=False,
+        chunk_size=None,
+        strategy=None,
+        metric_kwargs=None,
+    ):
+        if metric_kwargs is not None and len(metric_kwargs) > 0:
+            warnings.warn(
+                f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't "
+                f"usable for this case ({self.__class__.__name__}) and will be ignored.",
+                UserWarning,
+                stacklevel=3,
+            )
+
+        super().__init__(
+            # The datasets pair here is used for exact distances computations
+            datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"),
+            k=k,
+            chunk_size=chunk_size,
+        )
+        # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair
+        cdef:
+            DenseDenseDatasetsPair datasets_pair = <DenseDenseDatasetsPair> self.datasets_pair
+        self.X, self.Y = datasets_pair.X, datasets_pair.Y
+
+        if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs:
+            self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared", None)
+        else:
+            self.Y_norm_squared = _sqeuclidean_row_norms(self.Y, self.effective_n_threads)
+
+        # Do not recompute norms if datasets are identical.
+        self.X_norm_squared = (
+            self.Y_norm_squared if X is Y else
+            _sqeuclidean_row_norms(self.X, self.effective_n_threads)
+        )
+        self.use_squared_distances = use_squared_distances
+
+        # Temporary datastructures used in threads
+        self.dist_middle_terms_chunks = <DTYPE_t **> malloc(
+            sizeof(DTYPE_t *) * self.chunks_n_threads
+        )
+
+    def __dealloc__(self):
+        if self.dist_middle_terms_chunks is not NULL:
+            free(self.dist_middle_terms_chunks)
+
+    @final
+    cdef void compute_exact_distances(self) nogil:
+        if not self.use_squared_distances:
+            PairwiseDistancesArgKmin.compute_exact_distances(self)
+
+    @final
+    cdef void _parallel_on_X_parallel_init(
+        self,
+        ITYPE_t thread_num,
+    ) nogil:
+        PairwiseDistancesArgKmin._parallel_on_X_parallel_init(self, thread_num)
+
+        # Temporary buffer for the `-2 * X_c @ Y_c.T` term
+        self.dist_middle_terms_chunks[thread_num] = <DTYPE_t *> malloc(
+            self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t)
+        )
+
+    @final
+    cdef void _parallel_on_X_parallel_finalize(
+        self,
+        ITYPE_t thread_num
+    ) nogil:
+        PairwiseDistancesArgKmin._parallel_on_X_parallel_finalize(self, thread_num)
+        free(self.dist_middle_terms_chunks[thread_num])
+
+    @final
+    cdef void _parallel_on_Y_parallel_init(
+        self,
+    ) nogil:
+        cdef ITYPE_t thread_num
+        PairwiseDistancesArgKmin._parallel_on_Y_parallel_init(self)
+
+        for thread_num in range(self.chunks_n_threads):
+            # Temporary buffer for the `-2 * X_c @ Y_c.T` term
+            self.dist_middle_terms_chunks[thread_num] = <DTYPE_t *> malloc(
+                self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t)
+            )
+
+    @final
+    cdef void _parallel_on_Y_finalize(
+        self,
+    ) nogil:
+        cdef ITYPE_t thread_num
+        PairwiseDistancesArgKmin._parallel_on_Y_finalize(self)
+
+        for thread_num in range(self.chunks_n_threads):
+            free(self.dist_middle_terms_chunks[thread_num])
+
+    @final
+    cdef void _compute_and_reduce_distances_on_chunks(
+        self,
+        ITYPE_t X_start,
+        ITYPE_t X_end,
+        ITYPE_t Y_start,
+        ITYPE_t Y_end,
+        ITYPE_t thread_num,
+    ) nogil:
+        cdef:
+            ITYPE_t i, j
+
+            const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :]
+            const DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :]
+            DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num]
+            DTYPE_t *heaps_r_distances = self.heaps_r_distances_chunks[thread_num]
+            ITYPE_t *heaps_indices = self.heaps_indices_chunks[thread_num]
+
+            # Careful: LDA, LDB and LDC are given for F-ordered arrays
+            # in BLAS documentations, for instance:
+            # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html #noqa
+            #
+            # Here, we use their counterpart values to work with C-ordered arrays.
+            BLAS_Order order = RowMajor
+            BLAS_Trans ta = NoTrans
+            BLAS_Trans tb = Trans
+            ITYPE_t m = X_c.shape[0]
+            ITYPE_t n = Y_c.shape[0]
+            ITYPE_t K = X_c.shape[1]
+            DTYPE_t alpha = - 2.
+            # Casting for A and B to remove the const is needed because APIs exposed via
+            # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier.
+            DTYPE_t * A = <DTYPE_t*> & X_c[0, 0]
+            ITYPE_t lda = X_c.shape[1]
+            DTYPE_t * B = <DTYPE_t*> & Y_c[0, 0]
+            ITYPE_t ldb = X_c.shape[1]
+            DTYPE_t beta = 0.
+            DTYPE_t * C = dist_middle_terms
+            ITYPE_t ldc = Y_c.shape[0]
+
+        # dist_middle_terms = `-2 * X_c @ Y_c.T`
+        _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, C, ldc)
+
+        # Pushing the distances and their associated indices on heaps
+        # which keep track of the argkmin.
+        for i in range(X_c.shape[0]):
+            for j in range(Y_c.shape[0]):
+                heap_push(
+                    heaps_r_distances + i * self.k,
+                    heaps_indices + i * self.k,
+                    self.k,
+                    # Using the squared euclidean distance as the rank-preserving distance:
+                    #
+                    #         ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||²
+                    #
+                    (
+                        self.X_norm_squared[i + X_start] +
+                        dist_middle_terms[i * Y_c.shape[0] + j] +
+                        self.Y_norm_squared[j + Y_start]
+                    ),
+                    j + Y_start,
+                )
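
Note on the decomposition used above: the new class ranks neighbors with the expanded form ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||², so the only dense per-chunk work is one GEMM for the middle term. Below is a minimal NumPy sketch of the same idea (plain NumPy standing in for the Cython/BLAS code, np.argpartition standing in for the chunked heaps; this sketch is not part of the commit):

    import numpy as np

    rng = np.random.RandomState(0)
    X = rng.rand(100, 10)   # query rows
    Y = rng.rand(200, 10)   # candidate rows
    k = 5

    # Row-wise squared norms, computed once per dataset.
    X_norm_squared = np.einsum("ij,ij->i", X, X)
    Y_norm_squared = np.einsum("ij,ij->i", Y, Y)

    # The middle term, -2 * X @ Y.T, is a single matrix-matrix product (GEMM).
    surrogate = X_norm_squared[:, None] - 2.0 * (X @ Y.T) + Y_norm_squared[None, :]

    # Rank-preserving reduction: indices of the k smallest surrogate distances per row.
    argkmin = np.argpartition(surrogate, k, axis=1)[:, :k]

    # Reference result computed from exact squared distances.
    exact = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
    reference = np.argpartition(exact, k, axis=1)[:, :k]

    # On well-separated random data the two selections agree; for nearly equal
    # distances the expanded form can disagree because of the catastrophic
    # cancellation mentioned in the docstring above.
    print(np.array_equal(np.sort(argkmin, axis=1), np.sort(reference, axis=1)))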

sklearn/metrics/tests/test_pairwise_distances_reduction.py

Lines changed: 42 additions & 0 deletions
@@ -7,6 +7,7 @@
 from sklearn.metrics._pairwise_distances_reduction import (
     PairwiseDistancesReduction,
     PairwiseDistancesArgKmin,
+    FastEuclideanPairwiseDistancesArgKmin,
     _sqeuclidean_row_norms,
 )

@@ -139,6 +140,47 @@ def test_argkmin_factory_method_wrong_usages():
         )


+@fails_if_unstable_openblas
+@pytest.mark.filterwarnings("ignore:Constructing a DIA matrix")
+@pytest.mark.parametrize(
+    "PairwiseDistancesReduction, FastPairwiseDistancesReduction",
+    [
+        (PairwiseDistancesArgKmin, FastEuclideanPairwiseDistancesArgKmin),
+    ],
+)
+def test_pairwise_distances_reduction_factory_method(
+    PairwiseDistancesReduction, FastPairwiseDistancesReduction
+):
+    # Test all the combinations of DatasetsPair for creation
+    rng = np.random.RandomState(1)
+    X = rng.rand(100, 10)
+    Y = rng.rand(100, 10)
+    metric = "euclidean"
+
+    # Dummy value for k or radius
+    dummy_arg = 5
+
+    with pytest.raises(
+        ValueError, match="Only dense datasets are supported for X and Y."
+    ):
+        PairwiseDistancesReduction.compute(
+            csr_matrix(X),
+            csr_matrix(Y),
+            dummy_arg,
+            metric,
+        )
+
+    with pytest.raises(
+        ValueError, match="Only dense datasets are supported for X and Y."
+    ):
+        PairwiseDistancesReduction.compute(X, csr_matrix(Y), dummy_arg, metric=metric)
+
+    with pytest.raises(
+        ValueError, match="Only dense datasets are supported for X and Y."
+    ):
+        PairwiseDistancesReduction.compute(csr_matrix(X), Y, dummy_arg, metric=metric)
+
+
 @fails_if_unstable_openblas
 @pytest.mark.parametrize("seed", range(5))
 @pytest.mark.parametrize("n_samples", [100, 1000])

0 commit comments