@@ -28,12 +28,9 @@ from ._base cimport (
28
28
_sqeuclidean_row_norms{{name_suffix}},
29
29
)
30
30
31
- from ._datasets_pair cimport (
32
- DatasetsPair{{name_suffix}},
33
- DenseDenseDatasetsPair{{name_suffix}},
34
- )
31
+ from ._datasets_pair cimport DatasetsPair{{name_suffix}}
35
32
36
- from ._gemm_term_computer cimport GEMMTermComputer {{name_suffix}}
33
+ from ._middle_term_computer cimport MiddleTermComputer {{name_suffix}}
37
34
38
35
39
36
cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
@@ -66,13 +63,16 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
66
63
"""
67
64
if (
68
65
metric in ("euclidean", "sqeuclidean")
69
- and not issparse(X)
70
- and not issparse(Y)
66
+ and not (issparse(X) or issparse(Y))
71
67
):
72
- # Specialized implementation with improved arithmetic intensity
73
- # and vector instructions (SIMD) by processing several vectors
74
- # at time to leverage a call to the BLAS GEMM routine as explained
75
- # in more details in the docstring.
68
+ # Specialized implementation of ArgKmin for the Euclidean distance.
69
+ # This implementation computes the distances by chunk using
70
+ # a decomposition of the Squared Euclidean distance.
71
+ # This specialisation has an improved arithmetic intensity for both
72
+ # the dense and sparse settings, allowing in most cases speed-ups of
73
+ # several orders of magnitude compared to the generic ArgKmin
74
+ # implementation.
75
+ # For more information see MiddleTermComputer.
76
76
use_squared_distances = metric == "sqeuclidean"
77
77
pda = EuclideanArgKmin{{name_suffix}}(
78
78
X=X, Y=Y, k=k,
@@ -82,8 +82,8 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
82
82
metric_kwargs=metric_kwargs,
83
83
)
84
84
else:
85
- # Fall back on a generic implementation that handles most scipy
86
- # metrics by computing the distances between 2 vectors at a time.
85
+ # Fall back on a generic implementation that handles most scipy
86
+ # metrics by computing the distances between 2 vectors at a time.
87
87
pda = ArgKmin{{name_suffix}}(
88
88
datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs),
89
89
k=k,
@@ -347,21 +347,16 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
347
347
strategy=strategy,
348
348
k=k,
349
349
)
350
- # X and Y are checked by the DatasetsPair{{name_suffix}} implemented
351
- # as a DenseDenseDatasetsPair{{name_suffix}}
352
350
cdef:
353
- DenseDenseDatasetsPair{{name_suffix}} datasets_pair = (
354
- <DenseDenseDatasetsPair{{name_suffix}}> self.datasets_pair
355
- )
356
351
ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk
357
352
358
- self.gemm_term_computer = GEMMTermComputer {{name_suffix}}(
359
- datasets_pair. X,
360
- datasets_pair. Y,
353
+ self.middle_term_computer = MiddleTermComputer {{name_suffix}}.get_for (
354
+ X,
355
+ Y,
361
356
self.effective_n_threads,
362
357
self.chunks_n_threads,
363
358
dist_middle_terms_chunks_size,
364
- n_features=datasets_pair. X.shape[1],
359
+ n_features=X.shape[1],
365
360
chunk_size=self.chunk_size,
366
361
)
367
362
@@ -373,12 +368,16 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
373
368
dtype=np.float64
374
369
)
375
370
else:
376
- self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}(datasets_pair.Y, self.effective_n_threads)
371
+ self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}(
372
+ Y, self.effective_n_threads
373
+ )
377
374
378
375
# Do not recompute norms if datasets are identical.
379
376
self.X_norm_squared = (
380
377
self.Y_norm_squared if X is Y else
381
- _sqeuclidean_row_norms{{name_suffix}}(datasets_pair.X, self.effective_n_threads)
378
+ _sqeuclidean_row_norms{{name_suffix}}(
379
+ X, self.effective_n_threads
380
+ )
382
381
)
383
382
self.use_squared_distances = use_squared_distances
384
383
@@ -393,8 +392,7 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
393
392
ITYPE_t thread_num,
394
393
) nogil:
395
394
ArgKmin{{name_suffix}}._parallel_on_X_parallel_init(self, thread_num)
396
- self.gemm_term_computer._parallel_on_X_parallel_init(thread_num)
397
-
395
+ self.middle_term_computer._parallel_on_X_parallel_init(thread_num)
398
396
399
397
@final
400
398
cdef void _parallel_on_X_init_chunk(
@@ -404,8 +402,7 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
404
402
ITYPE_t X_end,
405
403
) nogil:
406
404
ArgKmin{{name_suffix}}._parallel_on_X_init_chunk(self, thread_num, X_start, X_end)
407
- self.gemm_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end)
408
-
405
+ self.middle_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end)
409
406
410
407
@final
411
408
cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
@@ -422,18 +419,16 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
422
419
Y_start, Y_end,
423
420
thread_num,
424
421
)
425
- self.gemm_term_computer ._parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
422
+ self.middle_term_computer ._parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
426
423
X_start, X_end, Y_start, Y_end, thread_num,
427
424
)
428
425
429
-
430
426
@final
431
427
cdef void _parallel_on_Y_init(
432
428
self,
433
429
) nogil:
434
430
ArgKmin{{name_suffix}}._parallel_on_Y_init(self)
435
- self.gemm_term_computer._parallel_on_Y_init()
436
-
431
+ self.middle_term_computer._parallel_on_Y_init()
437
432
438
433
@final
439
434
cdef void _parallel_on_Y_parallel_init(
@@ -443,8 +438,7 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
443
438
ITYPE_t X_end,
444
439
) nogil:
445
440
ArgKmin{{name_suffix}}._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end)
446
- self.gemm_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end)
447
-
441
+ self.middle_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end)
448
442
449
443
@final
450
444
cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
@@ -461,11 +455,10 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
461
455
Y_start, Y_end,
462
456
thread_num,
463
457
)
464
- self.gemm_term_computer ._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
458
+ self.middle_term_computer ._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
465
459
X_start, X_end, Y_start, Y_end, thread_num
466
460
)
467
461
468
-
469
462
@final
470
463
cdef void _compute_and_reduce_distances_on_chunks(
471
464
self,
@@ -477,10 +470,9 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
477
470
) nogil:
478
471
cdef:
479
472
ITYPE_t i, j
480
- DTYPE_t squared_dist_i_j
481
473
ITYPE_t n_X = X_end - X_start
482
474
ITYPE_t n_Y = Y_end - Y_start
483
- DTYPE_t * dist_middle_terms = self.gemm_term_computer ._compute_dist_middle_terms(
475
+ DTYPE_t * dist_middle_terms = self.middle_term_computer ._compute_dist_middle_terms(
484
476
X_start, X_end, Y_start, Y_end, thread_num
485
477
)
486
478
DTYPE_t * heaps_r_distances = self.heaps_r_distances_chunks[thread_num]
0 commit comments