8000 MAINT Introduce `MiddleTermComputer`, an abstraction generalizing `GE… · scikit-learn/scikit-learn@239e163 · GitHub
[go: up one dir, main page]

Skip to content

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 239e163

Browse files
Vincent-Maladierejjerphanogrisel
authored
MAINT Introduce MiddleTermComputer, an abstraction generalizing GEMMTermComputer (#24807)
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 266d2a2 commit 239e163

File tree

12 files changed

+310
-132
lines changed

12 files changed

+310
-132
lines changed

.gitignore

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ sklearn/metrics/_pairwise_distances_reduction/_base.pxd
9393
sklearn/metrics/_pairwise_distances_reduction/_base.pyx
9494
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd
9595
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx
96-
sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd
97-
sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx
96+
sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd
97+
sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx
9898
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd
9999
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx

setup.cfg

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ ignore =
8383
sklearn/metrics/_pairwise_distances_reduction/_base.pyx
8484
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd
8585
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx
86-
sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd
87-
sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx
86+
sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd
87+
sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx
8888
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd
8989
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx
9090

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
"sklearn.manifold._barnes_hut_tsne",
8989
"sklearn.metrics.cluster._expected_mutual_info_fast",
9090
"sklearn.metrics._pairwise_distances_reduction._datasets_pair",
91-
"sklearn.metrics._pairwise_distances_reduction._gemm_term_computer",
91+
"sklearn.metrics._pairwise_distances_reduction._middle_term_computer",
9292
"sklearn.metrics._pairwise_distances_reduction._base",
9393
"sklearn.metrics._pairwise_distances_reduction._argkmin",
9494
"sklearn.metrics._pairwise_distances_reduction._radius_neighbors",
@@ -316,7 +316,7 @@ def check_package_status(package, min_version):
316316
"extra_compile_args": ["-std=c++11"],
317317
},
318318
{
319-
"sources": ["_gemm_term_computer.pyx.tp", "_gemm_term_computer.pxd.tp"],
319+
"sources": ["_middle_term_computer.pyx.tp", "_middle_term_computer.pxd.tp"],
320320
"language": "c++",
321321
"include_np": True,
322322
"extra_compile_args": ["-std=c++11"],

sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd.tp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ cnp.import_array()
66
{{for name_suffix in ['64', '32']}}
77

88
from ._base cimport BaseDistancesReduction{{name_suffix}}
9-
from ._gemm_term_computer cimport GEMMTermComputer{{name_suffix}}
9+
from ._middle_term_computer cimport MiddleTermComputer{{name_suffix}}
1010

1111
cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
1212
"""float{{name_suffix}} implementation of the ArgKmin."""
@@ -25,7 +25,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
2525
cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
2626
"""EuclideanDistance-specialisation of ArgKmin{{name_suffix}}."""
2727
cdef:
28-
GEMMTermComputer{{name_suffix}} gemm_term_computer
28+
MiddleTermComputer{{name_suffix}} middle_term_computer
2929
const DTYPE_t[::1] X_norm_squared
3030
const DTYPE_t[::1] Y_norm_squared
3131

sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp

Lines changed: 30 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,9 @@ from ._base cimport (
2828
_sqeuclidean_row_norms{{name_suffix}},
2929
)
3030

31-
from ._datasets_pair cimport (
32-
DatasetsPair{{name_suffix}},
33-
DenseDenseDatasetsPair{{name_suffix}},
34-
)
31+
from ._datasets_pair cimport DatasetsPair{{name_suffix}}
3532

36-
from ._gemm_term_computer cimport GEMMTermComputer{{name_suffix}}
33+
from ._middle_term_computer cimport MiddleTermComputer{{name_suffix}}
3734

3835

3936
cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
@@ -66,13 +63,16 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
6663
"""
6764
if (
6865
metric in ("euclidean", "sqeuclidean")
69-
and not issparse(X)
70-
and not issparse(Y)
66+
and not (issparse(X) or issparse(Y))
7167
):
72-
# Specialized implementation with improved arithmetic intensity
73-
# and vector instructions (SIMD) by processing several vectors
74-
# at time to leverage a call to the BLAS GEMM routine as explained
75-
# in more details in the docstring.
68+
# Specialized implementation of ArgKmin for the Euclidean distance.
69+
# This implementation computes the distances by chunk using
70+
# a decomposition of the Squared Euclidean distance.
71+
# This specialisation has an improved arithmetic intensity for both
72+
# the dense and sparse settings, allowing in most case speed-ups of
73+
# several orders of magnitude compared to the generic ArgKmin
74+
# implementation.
75+
# For more information see MiddleTermComputer.
7676
use_squared_distances = metric == "sqeuclidean"
7777
pda = EuclideanArgKmin{{name_suffix}}(
7878
X=X, Y=Y, k=k,
@@ -82,8 +82,8 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
8282
metric_kwargs=metric_kwargs,
8383
)
8484
else:
85-
# Fall back on a generic implementation that handles most scipy
86-
# metrics by computing the distances between 2 vectors at a time.
85+
# Fall back on a generic implementation that handles most scipy
86+
# metrics by computing the distances between 2 vectors at a time.
8787
pda = ArgKmin{{name_suffix}}(
8888
datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs),
8989
k=k,
@@ -347,21 +347,16 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
347347
strategy=strategy,
348348
k=k,
349349
)
350-
# X and Y are checked by the DatasetsPair{{name_suffix}} implemented
351-
# as a DenseDenseDatasetsPair{{name_suffix}}
352350
cdef:
353-
DenseDenseDatasetsPair{{name_suffix}} datasets_pair = (
354-
<DenseDenseDatasetsPair{{name_suffix}}> self.datasets_pair
355-
)
356351
ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk
357352

358-
self.gemm_term_computer = GEMMTermComputer{{name_suffix}}(
359-
datasets_pair.X,
360-
datasets_pair.Y,
353+
self.middle_term_computer = MiddleTermComputer{{name_suffix}}.get_for(
354+
X,
355+
Y,
361356
self.effective_n_threads,
362357
self.chunks_n_threads,
363358
dist_middle_terms_chunks_size,
364-
n_features=datasets_pair.X.shape[1],
359+
n_features=X.shape[1],
365360
chunk_size=self.chunk_size,
366361
)
367362

@@ -373,12 +368,16 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
373368
dtype=np.float64
374369
)
375370
else:
376-
self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}(datasets_pair.Y, self.effective_n_threads)
371+
self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}(
372+
Y, self.effective_n_threads
373+
)
377374

378375
# Do not recompute norms if datasets are identical.
379376
self.X_norm_squared = (
380377
self.Y_norm_squared if X is Y else
381-
_sqeuclidean_row_norms{{name_suffix}}(datasets_pair.X, self.effective_n_threads)
378+
_sqeuclidean_row_norms{{name_suffix}}(
379+
X, self.effective_n_threads
380+
)
382381
)
383382
self.use_squared_distances = use_squared_distances
384383

@@ -393,8 +392,7 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
393392
ITYPE_t thread_num,
394393
) nogil:
395394
ArgKmin{{name_suffix}}._parallel_on_X_parallel_init(self, thread_num)
396-
self.gemm_term_computer._parallel_on_X_parallel_init(thread_num)
397-
395+
self.middle_term_computer._parallel_on_X_parallel_init(thread_num)
398396

399397
@final
400398
cdef void _parallel_on_X_init_chunk(
@@ -404,8 +402,7 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
404402
ITYPE_t X_end,
405403
) nogil:
406404
ArgKmin{{name_suffix}}._parallel_on_X_init_chunk(self, thread_num, X_start, X_end)
407-
self.gemm_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end)
408-
405+
self.middle_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end)
409406

410407
@final
411408
cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
@@ -422,18 +419,16 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
422419
Y_start, Y_end,
423420
thread_num,
424421
)
425-
self.gemm_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
422+
self.middle_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
426423
X_start, X_end, Y_start, Y_end, thread_num,
427424
)
428425

429-
430426
@final
431427
cdef void _parallel_on_Y_init(
432428
self,
433429
) nogil:
434430
ArgKmin{{name_suffix}}._parallel_on_Y_init(self)
435-
self.gemm_term_computer._parallel_on_Y_init()
436-
431+
self.middle_term_computer._parallel_on_Y_init()
437432

438433
@final
439434
cdef void _parallel_on_Y_parallel_init(
@@ -443,8 +438,7 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
443438
ITYPE_t X_end,
444439
) nogil:
445440
ArgKmin{{name_suffix}}._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end)
446-
self.gemm_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end)
447-
441+
self.middle_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end)
448442

449443
@final
450444
cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
@@ -461,11 +455,10 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
461455
Y_start, Y_end,
462456
thread_num,
463457
)
464-
self.gemm_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
458+
self.middle_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
465459
X_start, X_end, Y_start, Y_end, thread_num
466460
)
467461

468-
469462
@final
470463
cdef void _compute_and_reduce_distances_on_chunks(
471464
self,
@@ -477,10 +470,9 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
477470
) nogil:
478471
cdef:
479472
ITYPE_t i, j
480-
DTYPE_t squared_dist_i_j
481473
ITYPE_t n_X = X_end - X_start
482474
ITYPE_t n_Y = Y_end - Y_start
483-
DTYPE_t * dist_middle_terms = self.gemm_term_computer._compute_dist_middle_terms(
475+
DTYPE_t * dist_middle_terms = self.middle_term_computer._compute_dist_middle_terms(
484476
X_start, X_end, Y_start, Y_end, thread_num
485477
)
486478
DTYPE_t * heaps_r_distances = self.heaps_r_distances_chunks[thread_num]

sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,16 @@ from ...utils._typedefs cimport ITYPE_t, DTYPE_t
66

77
cnp.import_array()
88

9-
cpdef DTYPE_t[::1] _sqeuclidean_row_norms64(
10-
const DTYPE_t[:, ::1] X,
11-
ITYPE_t num_threads,
12-
)
13-
14-
cpdef DTYPE_t[::1] _sqeuclidean_row_norms32(
15-
const cnp.float32_t[:, ::1] X,
16-
ITYPE_t num_threads,
17-
)
18-
19-
{{for name_suffix in ['64', '32']}}
9+
{{for name_suffix, INPUT_DTYPE_t in [('64', 'DTYPE_t'), ('32', 'cnp.float32_t')]}}
2010

2111
from ._datasets_pair cimport DatasetsPair{{name_suffix}}
2212

2313

14+
cpdef DTYPE_t[::1] _sqeuclidean_row_norms{{name_suffix}}(
15+
const {{INPUT_DTYPE_t}}[:, ::1] X,
16+
ITYPE_t num_threads,
17+
)
18+
2419
cdef class BaseDistancesReduction{{name_suffix}}:
2520
"""
2621
Base float{{name_suffix}} implementation template of the pairwise-distances

sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
{{py:
2+
3+
implementation_specific_values = [
4+
# Values are the following ones:
5+
#
6+
# name_suffix, INPUT_DTYPE_t, INPUT_DTYPE
7+
#
8+
# We also use the float64 dtype and C-type names as defined in
9+
# `sklearn.utils._typedefs` to maintain consistency.
10+
#
11+
('64', 'DTYPE_t', 'DTYPE'),
12+
('32', 'cnp.float32_t', 'np.float32')
13+
]
14+
15+
}}
116
cimport numpy as cnp
217

318
from cython cimport final
@@ -21,7 +36,7 @@ cnp.import_array()
2136

2237
#####################
2338

24-
cpdef DTYPE_t[::1] _sqeuclidean_row_norms64(
39+
cdef DTYPE_t[::1] _sqeuclidean_row_norms64_dense(
2540
const DTYPE_t[:, ::1] X,
2641
ITYPE_t num_threads,
2742
):
@@ -46,7 +61,7 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms64(
4661
return squared_row_norms
4762

4863

49-
cpdef DTYPE_t[::1] _sqeuclidean_row_norms32(
64+
cdef DTYPE_t[::1] _sqeuclidean_row_norms32_dense(
5065
const cnp.float32_t[:, ::1] X,
5166
ITYPE_t num_threads,
5267
):
@@ -86,10 +101,19 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms32(
86101

87102
return squared_row_norms
88103

89-
{{for name_suffix in ['64', '32']}}
104+
105+
{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}
90106

91107
from ._datasets_pair cimport DatasetsPair{{name_suffix}}
92108

109+
110+
cpdef DTYPE_t[::1] _sqeuclidean_row_norms{{name_suffix}}(
111+
const {{INPUT_DTYPE_t}}[:, ::1] X,
112+
ITYPE_t num_threads,
113+
):
114+
return _sqeuclidean_row_norms{{name_suffix}}_dense(X, num_threads)
115+
116+
93117
cdef class BaseDistancesReduction{{name_suffix}}:
94118
"""
95119
Base float{{name_suffix}} implementation template of the pairwise-distances
@@ -359,7 +383,7 @@ cdef class BaseDistancesReduction{{name_suffix}}:
359383
In this method, EuclideanDistance specialisations of subclass of
360384
BaseDistancesReduction _must_ call:
361385

362-
self.gemm_term_computer._parallel_on_X_init_chunk(
386+
self.middle_term_computer._parallel_on_X_init_chunk(
363387
thread_num, X_start, X_end,
364388
)
365389

@@ -382,7 +406,7 @@ cdef class BaseDistancesReduction{{name_suffix}}:
382406
In this method, EuclideanDistance specialisations of subclass of
383407
BaseDistancesReduction _must_ call:
384408

385-
self.gemm_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
409+
self.middle_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
386410
X_start, X_end, Y_start, Y_end, thread_num,
387411
)
388412

@@ -425,7 +449,7 @@ cdef class BaseDistancesReduction{{name_suffix}}:
425449
In this method, EuclideanDistance specialisations of subclass of
426450
BaseDistancesReduction _must_ call:
427451

428-
self.gemm_term_computer._parallel_on_Y_parallel_init(
452+
self.middle_term_computer._parallel_on_Y_parallel_init(
429453
thread_num, X_start, X_end,
430454
)
431455

@@ -448,7 +472,7 @@ cdef class BaseDistancesReduction{{name_suffix}}:
448472
In this method, EuclideanDistance specialisations of subclass of
449473
BaseDistancesReduction _must_ call:
450474

451-
self.gemm_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
475+
self.middle_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
452476
X_start, X_end, Y_start, Y_end, thread_num,
453477
)
454478

sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def sqeuclidean_row_norms(X, num_threads):
2929
3030
Parameters
3131
----------
32-
X : ndarray of shape (n_samples, n_features)
32+
X : ndarray or CSR matrix of shape (n_samples, n_features)
3333
Input data. Must be c-contiguous.
3434
3535
num_threads : int
@@ -41,9 +41,9 @@ def sqeuclidean_row_norms(X, num_threads):
4141
Arrays containing the squared euclidean norm of each row of X.
4242
"""
4343
if X.dtype == np.float64:
44-
return _sqeuclidean_row_norms64(X, num_threads)
44+
return np.asarray(_sqeuclidean_row_norms64(X, num_threads))
4545
if X.dtype == np.float32:
46-
return _sqeuclidean_row_norms32(X, num_threads)
46+
return np.asarray(_sqeuclidean_row_norms32(X, num_threads))
4747

4848
raise ValueError(
4949
"Only float64 or float32 datasets are supported at this time, "

0 commit comments

Comments
 (0)
0