MAINT Introduce `MiddleTermComputer`, an abstraction generalizing `GEMMTermComputer` by Vincent-Maladiere · Pull Request #24807 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

MAINT Introduce MiddleTermComputer, an abstraction generalizing GEMMTermComputer #24807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ sklearn/metrics/_pairwise_distances_reduction/_base.pxd
sklearn/metrics/_pairwise_distances_reduction/_base.pyx
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx
sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd
sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx
sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd
sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ ignore =
sklearn/metrics/_pairwise_distances_reduction/_base.pyx
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx
sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd
sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx
sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd
sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
"sklearn.manifold._barnes_hut_tsne",
"sklearn.metrics.cluster._expected_mutual_info_fast",
"sklearn.metrics._pairwise_distances_reduction._datasets_pair",
"sklearn.metrics._pairwise_distances_reduction._gemm_term_computer",
"sklearn.metrics._pairwise_distances_reduction._middle_term_computer",
"sklearn.metrics._pairwise_distances_reduction._base",
"sklearn.metrics._pairwise_distances_reduction._argkmin",
"sklearn.metrics._pairwise_distances_reduction._radius_neighbors",
Expand Down Expand Up @@ -316,7 +316,7 @@ def check_package_status(package, min_version):
"extra_compile_args": ["-std=c++11"],
},
{
"sources": ["_gemm_term_computer.pyx.tp", "_gemm_term_computer.pxd.tp"],
"sources": ["_middle_term_computer.pyx.tp", "_middle_term_computer.pxd.tp"],
"language": "c++",
"include_np": True,
"extra_compile_args": ["-std=c++11"],
Expand Down
4 changes: 2 additions & 2 deletions sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd.tp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ cnp.import_array()
{{for name_suffix in ['64', '32']}}

from ._base cimport BaseDistancesReduction{{name_suffix}}
from ._gemm_term_computer cimport GEMMTermComputer{{name_suffix}}
from ._middle_term_computer cimport MiddleTermComputer{{name_suffix}}

cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
"""float{{name_suffix}} implementation of the ArgKmin."""
Expand All @@ -25,7 +25,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
"""EuclideanDistance-specialisation of ArgKmin{{name_suffix}}."""
cdef:
GEMMTermComputer{{name_suffix}} gemm_term_computer
MiddleTermComputer{{name_suffix}} middle_term_computer
const DTYPE_t[::1] X_norm_squared
const DTYPE_t[::1] Y_norm_squared

Expand Down
68 changes: 30 additions & 38 deletions sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,9 @@ from ._base cimport (
_sqeuclidean_row_norms{{name_suffix}},
)

from ._datasets_pair cimport (
DatasetsPair{{name_suffix}},
DenseDenseDatasetsPair{{name_suffix}},
)
from ._datasets_pair cimport DatasetsPair{{name_suffix}}

from ._gemm_term_computer cimport GEMMTermComputer{{name_suffix}}
from ._middle_term_computer cimport MiddleTermComputer{{name_suffix}}


cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
Expand Down Expand Up @@ -66,13 +63,16 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
"""
if (
metric in ("euclidean", "sqeuclidean")
and not issparse(X)
and not issparse(Y)
and not (issparse(X) or issparse(Y))
):
# Specialized implementation with improved arithmetic intensity
# and vector instructions (SIMD) by processing several vectors
# at a time to leverage a call to the BLAS GEMM routine as explained
# in more details in the docstring.
# Specialized implementation of ArgKmin for the Euclidean distance.
# This implementation computes the distances by chunk using
# a decomposition of the Squared Euclidean distance.
# This specialisation has an improved arithmetic intensity for both
# the dense and sparse settings, allowing in most cases speed-ups of
# several orders of magnitude compared to the generic ArgKmin
# implementation.
# For more information see MiddleTermComputer.
use_squared_distances = metric == "sqeuclidean"
pda = EuclideanArgKmin{{name_suffix}}(
X=X, Y=Y, k=k,
Expand All @@ -82,8 +82,8 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
metric_kwargs=metric_kwargs,
)
else:
# Fall back on a generic implementation that handles most scipy
# metrics by computing the distances between 2 vectors at a time.
# Fall back on a generic implementation that handles most scipy
# metrics by computing the distances between 2 vectors at a time.
pda = ArgKmin{{name_suffix}}(
datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs),
k=k,
Expand Down Expand Up @@ -347,21 +347,16 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
strategy=strategy,
k=k,
)
# X and Y are checked by the DatasetsPair{{name_suffix}} implemented
# as a DenseDenseDatasetsPair{{name_suffix}}
cdef:
DenseDenseDatasetsPair{{name_suffix}} datasets_pair = (
<DenseDenseDatasetsPair{{name_suffix}}> self.datasets_pair
)
ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk

self.gemm_term_computer = GEMMTermComputer{{name_suffix}}(
datasets_pair.X,
datasets_pair.Y,
self.middle_term_computer = MiddleTermComputer{{name_suffix}}.get_for(
X,
Y,
self.effective_n_threads,
self.chunks_n_threads,
dist_middle_terms_chunks_size,
n_features=datasets_pair.X.shape[1],
n_features=X.shape[1],
chunk_size=self.chunk_size,
)

Expand All @@ -373,12 +368,16 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
dtype=np.float64
)
else:
self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}(datasets_pair.Y, self.effective_n_threads)
self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}(
Y, self.effective_n_threads
)

# Do not recompute norms if datasets are identical.
self.X_norm_squared = (
self.Y_norm_squared if X is Y else
_sqeuclidean_row_norms{{name_suffix}}(datasets_pair.X, self.effective_n_threads)
_sqeuclidean_row_norms{{name_suffix}}(
X, self.effective_n_threads
)
)
self.use_squared_distances = use_squared_distances

Expand All @@ -393,8 +392,7 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
ITYPE_t thread_num,
) nogil:
ArgKmin{{name_suffix}}._parallel_on_X_parallel_init(self, thread_num)
self.gemm_term_computer._parallel_on_X_parallel_init(thread_num)

self.middle_term_computer._parallel_on_X_parallel_init(thread_num)

@final
cdef void _parallel_on_X_init_chunk(
Expand All @@ -404,8 +402,7 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
ITYPE_t X_end,
) nogil:
ArgKmin{{name_suffix}}._parallel_on_X_init_chunk(self, thread_num, X_start, X_end)
self.gemm_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end)

self.middle_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end)

@final
cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
Expand All @@ -422,18 +419,16 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
Y_start, Y_end,
thread_num,
)
self.gemm_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
self.middle_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
X_start, X_end, Y_start, Y_end, thread_num,
)


@final
cdef void _parallel_on_Y_init(
self,
) nogil:
ArgKmin{{name_suffix}}._parallel_on_Y_init(self)
self.gemm_term_computer._parallel_on_Y_init()

self.middle_term_computer._parallel_on_Y_init()

@final
cdef void _parallel_on_Y_parallel_init(
Expand All @@ -443,8 +438,7 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
ITYPE_t X_end,
) nogil:
ArgKmin{{name_suffix}}._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end)
self.gemm_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end)

self.middle_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end)

@final
cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
Expand All @@ -461,11 +455,10 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
Y_start, Y_end,
thread_num,
)
self.gemm_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
self.middle_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
X_start, X_end, Y_start, Y_end, thread_num
)


@final
cdef void _compute_and_reduce_distances_on_chunks(
self,
Expand All @@ -477,10 +470,9 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
) nogil:
cdef:
ITYPE_t i, j
DTYPE_t squared_dist_i_j
ITYPE_t n_X = X_end - X_start
ITYPE_t n_Y = Y_end - Y_start
DTYPE_t * dist_middle_terms = self.gemm_term_computer._compute_dist_middle_terms(
DTYPE_t * dist_middle_terms = self.middle_term_computer._compute_dist_middle_terms(
X_start, X_end, Y_start, Y_end, thread_num
)
DTYPE_t * heaps_r_distances = self.heaps_r_distances_chunks[thread_num]
Expand Down
17 changes: 6 additions & 11 deletions sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,16 @@ from ...utils._typedefs cimport ITYPE_t, DTYPE_t

cnp.import_array()

cpdef DTYPE_t[::1] _sqeuclidean_row_norms64(
const DTYPE_t[:, ::1] X,
ITYPE_t num_threads,
)

cpdef DTYPE_t[::1] _sqeuclidean_row_norms32(
const cnp.float32_t[:, ::1] X,
ITYPE_t num_threads,
)

{{for name_suffix in ['64', '32']}}
{{for name_suffix, INPUT_DTYPE_t in [('64', 'DTYPE_t'), ('32', 'cnp.float32_t')]}}

from ._datasets_pair cimport DatasetsPair{{name_suffix}}


cpdef DTYPE_t[::1] _sqeuclidean_row_norms{{name_suffix}}(
const {{INPUT_DTYPE_t}}[:, ::1] X,
ITYPE_t num_threads,
)

cdef class BaseDistancesReduction{{name_suffix}}:
"""
Base float{{name_suffix}} implementation template of the pairwise-distances
Expand Down
38 changes: 31 additions & 7 deletions sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
{{py:

implementation_specific_values = [
# Values are the following ones:
#
# name_suffix, INPUT_DTYPE_t, INPUT_DTYPE
#
# We also use the float64 dtype and C-type names as defined in
# `sklearn.utils._typedefs` to maintain consistency.
#
('64', 'DTYPE_t', 'DTYPE'),
('32', 'cnp.float32_t', 'np.float32')
]

}}
cimport numpy as cnp

from cython cimport final
Expand All @@ -21,7 +36,7 @@ cnp.import_array()

#####################

cpdef DTYPE_t[::1] _sqeuclidean_row_norms64(
cdef DTYPE_t[::1] _sqeuclidean_row_norms64_dense(
const DTYPE_t[:, ::1] X,
ITYPE_t num_threads,
):
Expand All @@ -46,7 +61,7 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms64(
return squared_row_norms


cpdef DTYPE_t[::1] _sqeuclidean_row_norms32(
cdef DTYPE_t[::1] _sqeuclidean_row_norms32_dense(
const cnp.float32_t[:, ::1] X,
ITYPE_t num_threads,
):
Expand Down Expand Up @@ -86,10 +101,19 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms32(

return squared_row_norms

{{for name_suffix in ['64', '32']}}

{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}

from ._datasets_pair cimport DatasetsPair{{name_suffix}}


cpdef DTYPE_t[::1] _sqeuclidean_row_norms{{name_suffix}}(
const {{INPUT_DTYPE_t}}[:, ::1] X,
ITYPE_t num_threads,
):
return _sqeuclidean_row_norms{{name_suffix}}_dense(X, num_threads)


cdef class BaseDistancesReduction{{name_suffix}}:
"""
Base float{{name_suffix}} implementation template of the pairwise-distances
Expand Down Expand Up @@ -359,7 +383,7 @@ cdef class BaseDistancesReduction{{name_suffix}}:
In this method, EuclideanDistance specialisations of subclass of
BaseDistancesReduction _must_ call:

self.gemm_term_computer._parallel_on_X_init_chunk(
self.middle_term_computer._parallel_on_X_init_chunk(
thread_num, X_start, X_end,
)

Expand All @@ -382,7 +406,7 @@ cdef class BaseDistancesReduction{{name_suffix}}:
In this method, EuclideanDistance specialisations of subclass of
BaseDistancesReduction _must_ call:

self.gemm_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
self.middle_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
X_start, X_end, Y_start, Y_end, thread_num,
)

Expand Down Expand Up @@ -425,7 +449,7 @@ cdef class BaseDistancesReduction{{name_suffix}}:
In this method, EuclideanDistance specialisations of subclass of
BaseDistancesReduction _must_ call:

self.gemm_term_computer._parallel_on_Y_parallel_init(
self.middle_term_computer._parallel_on_Y_parallel_init(
thread_num, X_start, X_end,
)

Expand All @@ -448,7 +472,7 @@ cdef class BaseDistancesReduction{{name_suffix}}:
In this method, EuclideanDistance specialisations of subclass of
BaseDistancesReduction _must_ call:

self.gemm_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
self.middle_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
X_start, X_end, Y_start, Y_end, thread_num,
)

Expand Down
6 changes: 3 additions & 3 deletions sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def sqeuclidean_row_norms(X, num_threads):

Parameters
----------
X : ndarray of shape (n_samples, n_features)
X : ndarray or CSR matrix of shape (n_samples, n_features)
Input data. Must be c-contiguous.

num_threads : int
Expand All @@ -41,9 +41,9 @@ def sqeuclidean_row_norms(X, num_threads):
Arrays containing the squared euclidean norm of each row of X.
"""
if X.dtype == np.float64:
return _sqeuclidean_row_norms64(X, num_threads)
return np.asarray(_sqeuclidean_row_norms64(X, num_threads))
if X.dtype == np.float32:
return _sqeuclidean_row_norms32(X, num_threads)
return np.asarray(_sqeuclidean_row_norms32(X, num_threads))

raise ValueError(
"Only float64 or float32 datasets are supported at this time, "
Expand Down
Loading