From 0383b6d1e0a7ed19617aeef328ca1162ebdc1736 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 22 Jun 2022 17:03:46 +0200 Subject: [PATCH 01/36] Introduce PairwiseDistances --- .../metrics/_pairwise_distances_reduction.pyx | 552 +++++++++++++++++- sklearn/metrics/pairwise.py | 5 +- 2 files changed, 552 insertions(+), 5 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 9606eb1273ce8..d1fa982f1cfa4 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -206,10 +206,14 @@ class PairwiseDistancesReduction: ------- True if the PairwiseDistancesReduction can be used, else False. """ - dtypes_validity = X.dtype == Y.dtype and Y.dtype == np.float64 - return (get_config().get("enable_cython_pairwise_dist", True) and - not issparse(X) and not issparse(Y) and dtypes_validity and - metric in cls.valid_metrics()) + try: + Y = X if Y is None else Y + dtypes_validity = X.dtype == Y.dtype and Y.dtype == np.float64 + return (get_config().get("enable_cython_pairwise_dist", True) and + not issparse(X) and not issparse(Y) and dtypes_validity and + metric in cls.valid_metrics()) + except Exception: + return False @classmethod @abstractmethod @@ -502,6 +506,244 @@ class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." ) +class PairwiseDistancesArgKmin(PairwiseDistancesReduction): + """Compute the argkmin of row vectors of X on the ones of Y. + + For each row vector of X, computes the indices of k first the rows + vectors of Y with the smallest distances. + + PairwiseDistancesArgKmin is typically used to perform + bruteforce k-nearest neighbors queries. + + This class is not meant to be instanciated, one should only use + its :meth:`compute` classmethod which handles allocation and + deallocation consistently. + """ + + @classmethod + def compute( + cls, + X, + Y, + k, + metric="euclidean", + chunk_size=None, + metric_kwargs=None, + strategy=None, + return_distance=False, + ): + """Compute the argkmin reduction. + + Parameters + ---------- + X : ndarray or CSR matrix of shape (n_samples_X, n_features) + Input data. + + Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) + Input data. + + k : int + The k for the argkmin reduction. + + metric : str, default='euclidean' + The distance metric to use for argkmin. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + chunk_size : int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None + The chunking strategy defining which dataset parallelization are made on. + + For both strategies the computations happens with two nested loops, + respectively on chunks of X and chunks of Y. + Strategies differs on which loop (outer or inner) is made to run + in parallel with the Cython `prange` construct: + + - 'parallel_on_X' dispatches chunks of X uniformly on threads. + Each thread then iterates on all the chunks of Y. This strategy is + embarrassingly parallel and comes with no datastructures synchronisation. + + - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. + + - 'auto' relies on a simple heuristic to choose between + 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, + 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` + is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity + for parallelism and is therefore more efficient despite the synchronization + step at each iteration of the outer loop on chunks of `X`. + + - None (default) looks-up in scikit-learn configuration for + `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + + return_distance : boolean, default=False + Return distances between each X vector and its + argkmin if set to True. + + Returns + ------- + If return_distance=False: + - argkmin_indices : ndarray of shape (n_samples_X, k) + Indices of the argkmin for each vector in X. + + If return_distance=True: + - argkmin_distances : ndarray of shape (n_samples_X, k) + Distances to the argkmin for each vector in X. + - argkmin_indices : ndarray of shape (n_samples_X, k) + Indices of the argkmin for each vector in X. + + Notes + ----- + This classmethod is responsible for introspecting the arguments + values to dispatch to the most appropriate implementation of + :class:`PairwiseDistancesArgKmin`. + + This allows decoupling the API entirely from the implementation details + whilst maintaining RAII: all temporarily allocated datastructures necessary + for the concrete implementation are therefore freed when this classmethod + returns. + """ + # Note (jjerphan): Some design thoughts for future extensions. + # This factory comes to handle specialisations for the given arguments. + # For future work, this might can be an entrypoint to specialise operations + # for various backend and/or hardware and/or datatypes, and/or fused + # {sparse, dense}-datasetspair etc. + if X.dtype == Y.dtype == np.float64: + return PairwiseDistancesArgKmin64.compute( + X=X, + Y=Y, + k=k, + metric=metric, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + return_distance=return_distance, + ) + raise ValueError( + f"Only 64bit float datasets are supported at this time, " + f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." + ) + + +class PairwiseDistances(PairwiseDistancesReduction): + """Compute the pairwise distances matrix for two sets of vectors. + + The distance function `dist` depends on the values of the `metric` + and `metric_kwargs` parameters. + + This class is not meant to be instanciated, one should only use + its :meth:`compute` classmethod which handles allocation and + deallocation consistently. + """ + + @classmethod + def compute( + cls, + X, + Y, + metric="euclidean", + chunk_size=None, + metric_kwargs=None, + strategy=None, + ): + """Return pairwise distances matrix for the given arguments. + + Parameters + ---------- + X : ndarray or CSR matrix of shape (n_samples_X, n_features) + Input data. + + Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) + Input data. + + metric : str, default='euclidean' + The distance metric to use. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + chunk_size : int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None + The chunking strategy defining which dataset parallelization are made on. + + For both strategies the computations happens with two nested loops, + respectively on chunks of X and chunks of Y. + Strategies differs on which loop (outer or inner) is made to run + in parallel with the Cython `prange` construct: + + - 'parallel_on_X' dispatches chunks of X uniformly on threads. + Each thread then iterates on all the chunks of Y. This strategy is + embarrassingly parallel and comes with no datastructures synchronisation. + + - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. + + - 'auto' relies on a simple heuristic to choose between + 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, + 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` + is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity + for parallelism and is therefore more efficient despite the synchronization + step at each iteration of the outer loop on chunks of `X`. + + - None (default) looks-up in scikit-learn configuration for + `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + + Returns + ------- + pairwise_distances_matrix : ndarray of shape (n_samples_X, n_samples_Y) + The pairwise distances matrix. + + Notes + ----- + This public classmethod is responsible for introspecting the arguments + values to dispatch to the private dtype-specialized implementation of + :class:`PairwiseDistances`. + + All temporarily allocated datastructures necessary for the concrete + implementation are therefore freed when this classmethod returns. + + This allows entirely decoupling the API entirely from the + implementation details whilst maintaining RAII. + """ + # Note (jjerphan): Some design thoughts for future extensions. + # This factory comes to handle specialisations for the given arguments. + # For future work, this might can be an entrypoint to specialise operations + # for various backend and/or hardware and/or datatypes, and/or fused + # {sparse, dense}-datasetspair etc. + Y = X if Y is None else Y + if X.dtype == Y.dtype == np.float64: + return PairwiseDistances64.compute( + X=X, + Y=Y, + metric=metric, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + ) + raise ValueError( + f"Only 64bit float datasets are supported at this time, " + f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." + ) + ##################### # dtype-specialized implementations @@ -1990,3 +2232,305 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesR if squared_dist_i_j <= self.r_radius: deref(self.neigh_distances_chunks[thread_num])[i + X_start].push_back(squared_dist_i_j) deref(self.neigh_indices_chunks[thread_num])[i + X_start].push_back(j + Y_start) + + +cdef class PairwiseDistances64(PairwiseDistancesReduction64): + """64bit implementation of PairwiseDistances.""" + + cdef: + DTYPE_t[:, ::1] pairwise_distances_matrix + + @classmethod + def compute( + cls, + X, + Y, + str metric="euclidean", + chunk_size=None, + dict metric_kwargs=None, + str strategy=None, + ): + """Compute the pairwise-distances matrix. + + This classmethod is responsible for introspecting the arguments + values to dispatch to the most appropriate implementation of + :class:`PairwiseDistances64`. + + This allows decoupling the API entirely from the implementation details + whilst maintaining RAII: all temporarily allocated datastructures necessary + for the concrete implementation are therefore freed when this classmethod + returns. + + No instance should directly be created outside of this class method. + """ + if ( + metric in ("euclidean", "sqeuclidean") + and not issparse(X) + and not issparse(Y) + ): + # Specialized implementation with improved arithmetic intensity + # and vector instructions (SIMD) by processing several vectors + # at time to leverage a call to the BLAS GEMM routine as explained + # in more details in the docstring. + use_squared_distances = metric == "sqeuclidean" + pdr = FastEuclideanPairwiseDistances64( + X=X, Y=Y, + use_squared_distances=use_squared_distances, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + ) + else: + # Fall back on a generic implementation that handles most scipy + # metrics by computing the distances between 2 vectors at a time. + pdr = PairwiseDistances64( + datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + ) + + # Limit the number of threads in second level of nested parallelism for BLAS + # to avoid threads over-subscription (in GEMM for instance). + with threadpool_limits(limits=1, user_api="blas"): + if pdr.execute_in_parallel_on_Y: + pdr._parallel_on_Y() + else: + pdr._parallel_on_X() + + return pdr._finalize_results() + + + def __init__( + self, + DatasetsPair datasets_pair, + chunk_size=None, + strategy=None, + sort_results=False, + metric_kwargs=None, + ): + super().__init__( + datasets_pair=datasets_pair, + chunk_size=chunk_size, + strategy=strategy, + ) + + # Distance matrix which will be complete and returned to the caller. + self.pairwise_distances_matrix = np.empty( + (self.n_samples_X, self.n_samples_Y), dtype=DTYPE, + ) + + def _finalize_results(self): + self.compute_exact_distances() + return np.asarray(self.pairwise_distances_matrix) + + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + DTYPE_t r_dist_i_j + + for i in range(X_start, X_end): + for j in range(Y_start, Y_end): + r_dist_i_j = self.datasets_pair.surrogate_dist(i, j) + self.pairwise_distances_matrix[X_start + i, Y_start + j] = r_dist_i_j + + cdef void compute_exact_distances(self) nogil: + """Convert rank-preserving distances to pairwise distances in parallel.""" + cdef: + ITYPE_t i, j + + for i in prange(self.n_samples_X, nogil=True, schedule='static', + num_threads=self.effective_n_threads): + for j in range(self.n_samples_Y): + self.pairwise_distances_matrix[i, j] = ( + self.datasets_pair.distance_metric._rdist_to_dist( + # Guard against eventual -0., causing nan production. + max(self.pairwise_distances_matrix[i, j], 0.) + ) + ) + +cdef class FastEuclideanPairwiseDistances64(PairwiseDistances64): + """EuclideanDistance-specialized 64bit implementation for PairwiseDistances.""" + cdef: + GEMMTermComputer64 gemm_term_computer + const DTYPE_t[::1] X_norm_squared + const DTYPE_t[::1] Y_norm_squared + + bint use_squared_distances + + @classmethod + def is_usable_for(cls, X, Y, metric) -> bool: + return (PairwiseDistances64.is_usable_for(X, Y, metric) + and not _in_unstable_openblas_configuration()) + + def __init__( + self, + X, + Y, + bint use_squared_distances=False, + chunk_size=None, + strategy=None, + metric_kwargs=None, + ): + if ( + metric_kwargs is not None and + len(metric_kwargs) > 0 and + "Y_norm_squared" not in metric_kwargs + ): + warnings.warn( + f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " + f"usable for this case (FastEuclideanPairwiseDistances64) and will be ignored.", + UserWarning, + stacklevel=3, + ) + + super().__init__( + # The datasets pair here is used for exact distances computations + datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), + chunk_size=chunk_size, + strategy=strategy, + metric_kwargs=metric_kwargs, + ) + # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair + cdef: + DenseDenseDatasetsPair datasets_pair = self.datasets_pair + ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk + + self.gemm_term_computer = GEMMTermComputer64( + datasets_pair.X, + datasets_pair.Y, + self.effective_n_threads, + self.chunks_n_threads, + dist_middle_terms_chunks_size, + n_features=datasets_pair.X.shape[1], + chunk_size=self.chunk_size, + ) + + if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: + self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") + else: + self.Y_norm_squared = _sqeuclidean_row_norms64(datasets_pair.Y, self.effective_n_threads) + + # Do not recompute norms if datasets are identical. + self.X_norm_squared = ( + self.Y_norm_squared if X is Y else + _sqeuclidean_row_norms64(datasets_pair.X, self.effective_n_threads) + ) + self.use_squared_distances = use_squared_distances + + + @final + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num, + ) nogil: + PairwiseDistances64._parallel_on_X_parallel_init(self, thread_num) + self.gemm_term_computer._parallel_on_X_parallel_init(thread_num) + + @final + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + PairwiseDistances64._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) + self.gemm_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end) + + @final + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + PairwiseDistances64._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + self.gemm_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num, + ) + + @final + cdef void _parallel_on_Y_init( + self, + ) nogil: + cdef ITYPE_t thread_num + PairwiseDistances64._parallel_on_Y_init(self) + self.gemm_term_computer._parallel_on_Y_init() + + @final + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + PairwiseDistances64._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) + self.gemm_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end) + + @final + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + PairwiseDistances64._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + self.gemm_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num + ) + + @final + cdef void compute_exact_distances(self) nogil: + if not self.use_squared_distances: + PairwiseDistances64.compute_exact_distances(self) + + @final + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + DTYPE_t squared_dist_i_j + ITYPE_t n_X = X_end - X_start + ITYPE_t n_Y = Y_end - Y_start + DTYPE_t *dist_middle_terms = self.gemm_term_computer._compute_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num + ) + + for i in range(n_X): + for j in range(n_Y): + # Using the squared euclidean distance as the rank-preserving distance: + # + # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² + # + self.pairwise_distances_matrix[i + X_start, j + Y_start] = ( + self.X_norm_squared[i + X_start] + + dist_middle_terms[i * n_Y + j] + + self.Y_norm_squared[j + Y_start] + ) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 3e3dabbdacde6..3e8397ca5f6c2 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -30,7 +30,7 @@ from ..utils.fixes import delayed from ..utils.fixes import sp_version, parse_version -from ._pairwise_distances_reduction import PairwiseDistancesArgKmin +from ._pairwise_distances_reduction import PairwiseDistancesArgKmin, PairwiseDistances from ._pairwise_fast import _chi2_kernel_fast, _sparse_manhattan from ..exceptions import DataConversionWarning @@ -1945,6 +1945,9 @@ def pairwise_distances( % (metric, _VALID_METRICS) ) + if PairwiseDistances.is_usable_for(X, Y, metric=metric): + return PairwiseDistances.compute(X, Y, metric=metric, metric_kwargs=kwds) + if metric == "precomputed": X, _ = check_pairwise_arrays( X, Y, precomputed=True, force_all_finite=force_all_finite From 5889e6fd8441ae954aa6f69466027ba3e2eb31c1 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 4 Jul 2022 10:06:10 +0200 Subject: [PATCH 02/36] WIP --- .../metrics/_pairwise_distances_reduction/_dispatcher.py | 6 ++++-- sklearn/metrics/pairwise.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 7a16c61f03a8f..4b750678ae710 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -12,6 +12,7 @@ from ._radius_neighborhood import PairwiseDistancesRadiusNeighborhood64 from ... import get_config +from ...utils.validation import _is_arraylike_not_scalar def sqeuclidean_row_norms(X, num_threads): @@ -80,12 +81,13 @@ def is_usable_for(cls, X, Y, metric) -> bool: ------- True if the PairwiseDistancesReduction can be used, else False. """ - dtypes_validity = X.dtype == Y.dtype == np.float64 return ( get_config().get("enable_cython_pairwise_dist", True) + and _is_arraylike_not_scalar(X) + and _is_arraylike_not_scalar(Y) + and X.dtype == Y.dtype == np.float64 and not issparse(X) and not issparse(Y) - and dtypes_validity and metric in cls.valid_metrics() ) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 2264c172bb16a..7233e80bb2c2b 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -1945,8 +1945,10 @@ def pairwise_distances( % (metric, _VALID_METRICS) ) - if PairwiseDistances.is_usable_for(X, Y, metric=metric): - return PairwiseDistances.compute(X, Y, metric=metric, metric_kwargs=kwds) + if PairwiseDistances.is_usable_for(X, X if Y is None else Y, metric=metric): + return PairwiseDistances.compute( + X, X if Y is None else Y, metric=metric, metric_kwargs=kwds + ) if metric == "precomputed": X, _ = check_pairwise_arrays( From 9101daf845a5abb9295ac4f705d5945fd0a9835a Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 11 Jul 2022 10:29:55 +0200 Subject: [PATCH 03/36] fixup! Introduce PairwiseDistances --- .../_pairwise_distances_reduction/__init__.py | 4 +- .../_pairwise_distances_reduction/_base.pxd | 1 + .../_datasets_pair.pxd | 4 +- .../_datasets_pair.pyx | 5 +- .../_dispatcher.py | 2 + .../_pairwise_distances.pyx | 84 ++++++++++++------- sklearn/metrics/pairwise.py | 33 ++------ 7 files changed, 73 insertions(+), 60 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/__init__.py b/sklearn/metrics/_pairwise_distances_reduction/__init__.py index cbeee703cdb74..943afc720d1fe 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/__init__.py +++ b/sklearn/metrics/_pairwise_distances_reduction/__init__.py @@ -85,7 +85,6 @@ # to optimally handle the Euclidean distance case using the Generalized Matrix # Multiplication (see the docstring of :class:`GEMMTermComputer64` for details). - from ._dispatcher import ( PairwiseDistancesReduction, PairwiseDistances, @@ -94,10 +93,13 @@ sqeuclidean_row_norms, ) +from ._pairwise_distances import _precompute_metric_params + __all__ = [ "PairwiseDistancesReduction", "PairwiseDistancesArgKmin", "PairwiseDistances", "PairwiseDistancesRadiusNeighborhood", "sqeuclidean_row_norms", + "_precompute_metric_params", ] diff --git a/sklearn/metrics/_pairwise_distances_reduction/_base.pxd b/sklearn/metrics/_pairwise_distances_reduction/_base.pxd index 9f6ad45cb839a..998455e638c1d 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_base.pxd +++ b/sklearn/metrics/_pairwise_distances_reduction/_base.pxd @@ -40,6 +40,7 @@ cdef class PairwiseDistancesReduction64: ITYPE_t n_samples_X, X_n_samples_chunk, X_n_chunks, X_n_samples_last_chunk ITYPE_t n_samples_Y, Y_n_samples_chunk, Y_n_chunks, Y_n_samples_last_chunk + bint X_is_Y bint execute_in_parallel_on_Y @final diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd index de6458f8c6f26..ba792273fcb8b 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd @@ -3,7 +3,9 @@ from ...metrics._dist_metrics cimport DistanceMetric cdef class DatasetsPair: - cdef DistanceMetric distance_metric + cdef: + DistanceMetric distance_metric + readonly bint X_is_Y cdef ITYPE_t n_samples_X(self) nogil diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx index abef1bed098ed..22c8a25a05e1d 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx @@ -97,8 +97,9 @@ cdef class DatasetsPair: return DenseDenseDatasetsPair(X, Y, distance_metric) - def __init__(self, DistanceMetric distance_metric): + def __init__(self, DistanceMetric distance_metric, bint X_is_Y): self.distance_metric = distance_metric + self.X_is_Y = X_is_Y cdef ITYPE_t n_samples_X(self) nogil: """Number of samples in X.""" @@ -141,7 +142,7 @@ cdef class DenseDenseDatasetsPair(DatasetsPair): """ def __init__(self, X, Y, DistanceMetric distance_metric): - super().__init__(distance_metric) + super().__init__(distance_metric, X_is_Y=X is Y) # Arrays have already been checked self.X = X self.Y = Y diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 4b750678ae710..7036eaeb67689 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -85,6 +85,8 @@ def is_usable_for(cls, X, Y, metric) -> bool: get_config().get("enable_cython_pairwise_dist", True) and _is_arraylike_not_scalar(X) and _is_arraylike_not_scalar(Y) + and not isinstance(X, (tuple, list)) + and not isinstance(Y, (tuple, list)) and X.dtype == Y.dtype == np.float64 and not issparse(X) and not issparse(Y) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx index c6a4baabef2b0..ef75c6d7e2b4b 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx @@ -22,12 +22,38 @@ import warnings from scipy.sparse import issparse from sklearn.utils import _in_unstable_openblas_configuration -from sklearn.utils.fixes import threadpool_limits +from sklearn.utils.fixes import threadpool_limits, sp_version, parse_version from ...utils._typedefs import ITYPE, DTYPE cnp.import_array() +def _precompute_metric_params(X, Y, metric=None, **kwds): + """Precompute data-derived metric parameters if not provided.""" + if metric == "seuclidean" and "V" not in kwds: + # There is a bug in scipy < 1.5 that will cause a crash if + # X.dtype != np.double (float64). See PR #15730 + dtype = np.float64 if sp_version < parse_version("1.5") else None + if X is Y: + V = np.var(X, axis=0, ddof=1, dtype=dtype) + else: + raise ValueError( + "The 'V' parameter is required for the seuclidean metric " + "when Y is passed." + ) + return {"V": V} + if metric == "mahalanobis" and "VI" not in kwds: + if X is Y: + VI = np.linalg.inv(np.cov(X.T)).T + else: + raise ValueError( + "The 'VI' parameter is required for the mahalanobis metric " + "when Y is passed." + ) + return {"VI": VI} + return {} + + cdef class PairwiseDistances64(PairwiseDistancesReduction64): """64bit implementation of PairwiseDistances.""" @@ -55,7 +81,7 @@ cdef class PairwiseDistances64(PairwiseDistancesReduction64): No instance should directly be created outside of this class method. """ if ( - metric in ("euclidean", "sqeuclidean") + metric in ("euclidean", "l2", "sqeuclidean") and not issparse(X) and not issparse(Y) ): @@ -72,8 +98,12 @@ cdef class PairwiseDistances64(PairwiseDistancesReduction64): strategy=strategy, ) else: - # Fall back on a generic implementation that handles most scipy - # metrics by computing the distances between 2 vectors at a time. + # Precompute data-derived distance metric parameters + params = _precompute_metric_params(X, Y, metric=metric, **metric_kwargs) + metric_kwargs.update(**params) + + # Fall back on a generic implementation that handles most scipy + # metrics by computing the distances between 2 vectors at a time. pdr = PairwiseDistances64( datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), chunk_size=chunk_size, @@ -112,8 +142,15 @@ cdef class PairwiseDistances64(PairwiseDistancesReduction64): ) def _finalize_results(self): - self.compute_exact_distances() - return np.asarray(self.pairwise_distances_matrix) + # If Y is X, then catastrophic cancellation might + # have occurred for computations of term on the diagonal + # which must be null. We enforce nullity of those term + # by zeroing the diagonal. + distance_matrix = np.asarray(self.pairwise_distances_matrix) + if self.datasets_pair.X_is_Y: + np.fill_diagonal(distance_matrix, 0) + + return distance_matrix cdef void _compute_and_reduce_distances_on_chunks( self, @@ -125,27 +162,13 @@ cdef class PairwiseDistances64(PairwiseDistancesReduction64): ) nogil: cdef: ITYPE_t i, j - DTYPE_t r_dist_i_j + DTYPE_t dist_i_j for i in range(X_start, X_end): for j in range(Y_start, Y_end): - r_dist_i_j = self.datasets_pair.surrogate_dist(i, j) - self.pairwise_distances_matrix[X_start + i, Y_start + j] = r_dist_i_j + dist_i_j = self.datasets_pair.dist(i, j) + self.pairwise_distances_matrix[X_start + i, Y_start + j] = dist_i_j - cdef void compute_exact_distances(self) nogil: - """Convert rank-preserving distances to pairwise distances in parallel.""" - cdef: - ITYPE_t i, j - - for i in prange(self.n_samples_X, nogil=True, schedule='static', - num_threads=self.effective_n_threads): - for j in range(self.n_samples_Y): - self.pairwise_distances_matrix[i, j] = ( - self.datasets_pair.distance_metric._rdist_to_dist( - # Guard against eventual -0., causing nan production. - max(self.pairwise_distances_matrix[i, j], 0.) - ) - ) cdef class FastEuclideanPairwiseDistances64(PairwiseDistances64): """EuclideanDistance-specialized 64bit implementation for PairwiseDistances.""" @@ -205,7 +228,7 @@ cdef class FastEuclideanPairwiseDistances64(PairwiseDistances64): # Do not recompute norms if datasets are identical. self.X_norm_squared = ( - self.Y_norm_squared if X is Y else + self.Y_norm_squared if self.datasets_pair.X_is_Y else _sqeuclidean_row_norms64(datasets_pair.X, self.effective_n_threads) ) self.use_squared_distances = use_squared_distances @@ -285,10 +308,6 @@ cdef class FastEuclideanPairwiseDistances64(PairwiseDistances64): X_start, X_end, Y_start, Y_end, thread_num ) - @final - cdef void compute_exact_distances(self) nogil: - if not self.use_squared_distances: - PairwiseDistances64.compute_exact_distances(self) @final cdef void _compute_and_reduce_distances_on_chunks( @@ -319,3 +338,12 @@ cdef class FastEuclideanPairwiseDistances64(PairwiseDistances64): + dist_middle_terms[i * n_Y + j] + self.Y_norm_squared[j + Y_start] ) + + def _finalize_results(self): + distance_matrix = PairwiseDistances64._finalize_results(self) + # Squared Euclidean distances have been used for efficiency. + # We remap them to Euclidean distances here before finalizing + # results. + if not self.use_squared_distances: + return np.sqrt(distance_matrix) + return PairwiseDistances64._finalize_results(self) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 7233e80bb2c2b..e835a885516f5 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -28,9 +28,12 @@ from ..preprocessing import normalize from ..utils._mask import _get_mask from ..utils.fixes import delayed -from ..utils.fixes import sp_version, parse_version -from ._pairwise_distances_reduction import PairwiseDistancesArgKmin, PairwiseDistances +from ._pairwise_distances_reduction import ( + PairwiseDistancesArgKmin, + PairwiseDistances, + _precompute_metric_params, +) from ._pairwise_fast import _chi2_kernel_fast, _sparse_manhattan from ..exceptions import DataConversionWarning @@ -1628,32 +1631,6 @@ def _check_chunk_size(reduced, chunk_size): ) -def _precompute_metric_params(X, Y, metric=None, **kwds): - """Precompute data-derived metric parameters if not provided.""" - if metric == "seuclidean" and "V" not in kwds: - # There is a bug in scipy < 1.5 that will cause a crash if - # X.dtype != np.double (float64). See PR #15730 - dtype = np.float64 if sp_version < parse_version("1.5") else None - if X is Y: - V = np.var(X, axis=0, ddof=1, dtype=dtype) - else: - raise ValueError( - "The 'V' parameter is required for the seuclidean metric " - "when Y is passed." - ) - return {"V": V} - if metric == "mahalanobis" and "VI" not in kwds: - if X is Y: - VI = np.linalg.inv(np.cov(X.T)).T - else: - raise ValueError( - "The 'VI' parameter is required for the mahalanobis metric " - "when Y is passed." - ) - return {"VI": VI} - return {} - - def pairwise_distances_chunked( X, Y=None, From dff1aa2654cb5a082ccb1690196144d674b1eccb Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 12 Jul 2022 00:03:19 +0200 Subject: [PATCH 04/36] Post-merge fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ಠ_ರೃ --- .../metrics/_pairwise_distances_reduction.pyx | 2125 ----------------- .../_pairwise_distances_reduction/__init__.py | 1 + 2 files changed, 1 insertion(+), 2125 deletions(-) delete mode 100644 sklearn/metrics/_pairwise_distances_reduction.pyx diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx deleted file mode 100644 index 829d93388efcc..0000000000000 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ /dev/null @@ -1,2125 +0,0 @@ -# Pairwise Distances Reductions -# ============================= -# -# Author: Julien Jerphanion -# -# Overview -# -------- -# -# This module provides routines to compute pairwise distances between a set -# of row vectors of X and another set of row vectors of Y and apply a -# reduction on top. The canonical example is the brute-force computation -# of the top k nearest neighbors by leveraging the arg-k-min reduction. -# -# The reduction takes a matrix of pairwise distances between rows of X and Y -# as input and outputs an aggregate data-structure for each row of X. The -# aggregate values are typically smaller than the number of rows in Y, hence -# the term reduction. -# -# For computational reasons, the reduction are performed on the fly on chunks -# of rows of X and Y so as to keep intermediate data-structures in CPU cache -# and avoid unnecessary round trips of large distance arrays with the RAM -# that would otherwise severely degrade the speed by making the overall -# processing memory-bound. -# -# Finally, the routines follow a generic parallelization template to process -# chunks of data with OpenMP loops (via Cython prange), either on rows of X -# or rows of Y depending on their respective sizes. -# -# -# Dispatching to specialized implementations -# ------------------------------------------ -# -# Dispatchers are meant to be used in the Python code. Under the hood, a -# dispatcher must only define the logic to choose at runtime to the correct -# dtype-specialized :class:`PairwiseDistancesReduction` implementation based -# on the dtype of X and of Y. -# -# -# High-level diagram -# ------------------ -# -# Legend: -# -# A ---⊳ B: A inherits from B -# A ---x B: A dispatches on B -# -# -# (base dispatcher) -# PairwiseDistancesReduction -# ∆ -# | -# | -# +-----------------+-----------------+ -# | | -# (dispatcher) (dispatcher) -# PairwiseDistancesArgKmin PairwiseDistancesRadiusNeighbors -# | | -# | | -# | | -# | (64bit implem.) | -# | PairwiseDistancesReduction64 | -# | ∆ | -# | | | -# | | | -# | +-----------------+-----------------+ | -# | | | | -# | | | | -# x | | x -# PairwiseDistancesArgKmin64 PairwiseDistancesRadiusNeighbors64 -# | ∆ ∆ | -# | | | | -# x | | | -# FastEuclideanPairwiseDistancesArgKmin64 | | -# | | -# | x -# FastEuclideanPairwiseDistancesRadiusNeighbors64 -# -# For instance :class:`PairwiseDistancesArgKmin`, dispatches to -# :class:`PairwiseDistancesArgKmin64` if X and Y are both dense NumPy arrays -# with a float64 dtype. -# -# In addition, if the metric parameter is set to "euclidean" or "sqeuclidean", -# :class:`PairwiseDistancesArgKmin64` further dispatches to -# :class:`FastEuclideanPairwiseDistancesArgKmin64` a specialized subclass -# to optimally handle the Euclidean distance case using the Generalized Matrix -# Multiplication (see the docstring of :class:`GEMMTermComputer64` for details). -from abc import abstractmethod - -cimport numpy as cnp -import numpy as np -import warnings - -from .. import get_config -from libc.stdlib cimport free, malloc -from libc.float cimport DBL_MAX -from libcpp.memory cimport shared_ptr, make_shared -from libcpp.vector cimport vector -from cython cimport final -from cython.operator cimport dereference as deref -from cython.parallel cimport parallel, prange - -from ._dist_metrics cimport DatasetsPair, DenseDenseDatasetsPair -from ..utils._cython_blas cimport ( - BLAS_Order, - BLAS_Trans, - ColMajor, - NoTrans, - RowMajor, - Trans, - _dot, - _gemm, -) -from ..utils._heap cimport heap_push -from ..utils._sorting cimport simultaneous_sort -from ..utils._openmp_helpers cimport _openmp_thread_num -from ..utils._typedefs cimport ITYPE_t, DTYPE_t -from ..utils._vector_sentinel cimport vector_to_nd_array - -from numbers import Integral, Real -from typing import List -from scipy.sparse import issparse -from ._dist_metrics import BOOL_METRICS, METRIC_MAPPING -from ..utils import check_scalar, _in_unstable_openblas_configuration -from ..utils.fixes import threadpool_limits -from ..utils._openmp_helpers import _openmp_effective_n_threads -from ..utils._typedefs import ITYPE, DTYPE - -cnp.import_array() - -# TODO: change for `libcpp.algorithm.move` once Cython 3 is used -# Introduction in Cython: -# https://github.com/cython/cython/blob/05059e2a9b89bf6738a7750b905057e5b1e3fe2e/Cython/Includes/libcpp/algorithm.pxd#L47 #noqa -cdef extern from "" namespace "std" nogil: - OutputIt move[InputIt, OutputIt](InputIt first, InputIt last, OutputIt d_first) except + #noqa - -###################### -## std::vector to np.ndarray coercion -# As type covariance is not supported for C++ containers via Cython, -# we need to redefine fused types. -ctypedef fused vector_DITYPE_t: - vector[ITYPE_t] - vector[DTYPE_t] - - -ctypedef fused vector_vector_DITYPE_t: - vector[vector[ITYPE_t]] - vector[vector[DTYPE_t]] - - -cdef cnp.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( - shared_ptr[vector_vector_DITYPE_t] vecs -): - """Coerce a std::vector of std::vector to a ndarray of ndarray.""" - cdef: - ITYPE_t n = deref(vecs).size() - cnp.ndarray[object, ndim=1] nd_arrays_of_nd_arrays = np.empty(n, dtype=np.ndarray) - - for i in range(n): - nd_arrays_of_nd_arrays[i] = vector_to_nd_array(&(deref(vecs)[i])) - - return nd_arrays_of_nd_arrays - -##################### -# Dispatchers - -class PairwiseDistancesReduction: - """Abstract base dispatcher for pairwise distance computation & reduction. - - Each dispatcher extending the base :class:`PairwiseDistancesReduction` - dispatcher must implement the :meth:`compute` classmethod. - """ - - @classmethod - def valid_metrics(cls) -> List[str]: - excluded = { - "pyfunc", # is relatively slow because we need to coerce data as np arrays - "mahalanobis", # is numerically unstable - # TODO: In order to support discrete distance metrics, we need to have a - # stable simultaneous sort which preserves the order of the input. - # The best might be using std::stable_sort and a Comparator taking an - # Arrays of Structures instead of Structure of Arrays (currently used). - "hamming", - *BOOL_METRICS, - } - return sorted(set(METRIC_MAPPING.keys()) - excluded) - - @classmethod - def is_usable_for(cls, X, Y, metric) -> bool: - """Return True if the PairwiseDistancesReduction can be used for the - given parameters. - - Parameters - ---------- - X : {ndarray, sparse matrix} of shape (n_samples_X, n_features) - Input data. - - Y : {ndarray, sparse matrix} of shape (n_samples_Y, n_features) - Input data. - - metric : str, default='euclidean' - The distance metric to use. - For a list of available metrics, see the documentation of - :class:`~sklearn.metrics.DistanceMetric`. - - Returns - ------- - True if the PairwiseDistancesReduction can be used, else False. - """ - try: - Y = X if Y is None else Y - dtypes_validity = X.dtype == Y.dtype and Y.dtype == np.float64 - return (get_config().get("enable_cython_pairwise_dist", True) and - not issparse(X) and not issparse(Y) and dtypes_validity and - metric in cls.valid_metrics()) - except Exception: - return False - - @classmethod - @abstractmethod - def compute( - cls, - X, - Y, - **kwargs, - ): - """Compute the reduction. - - Parameters - ---------- - X : ndarray or CSR matrix of shape (n_samples_X, n_features) - Input data. - - Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) - Input data. - - **kwargs : additional parameters for the reduction - - Notes - ----- - This method is an abstract class method: it has to be implemented - for all subclasses. - """ - -class PairwiseDistancesArgKmin(PairwiseDistancesReduction): - """Compute the argkmin of row vectors of X on the ones of Y. - - For each row vector of X, computes the indices of k first the rows - vectors of Y with the smallest distances. - - PairwiseDistancesArgKmin is typically used to perform - bruteforce k-nearest neighbors queries. - - This class is not meant to be instanciated, one should only use - its :meth:`compute` classmethod which handles allocation and - deallocation consistently. - """ - - @classmethod - def compute( - cls, - X, - Y, - k, - metric="euclidean", - chunk_size=None, - metric_kwargs=None, - strategy=None, - return_distance=False, - ): - """Compute the argkmin reduction. - - Parameters - ---------- - X : ndarray or CSR matrix of shape (n_samples_X, n_features) - Input data. - - Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) - Input data. - - k : int - The k for the argkmin reduction. - - metric : str, default='euclidean' - The distance metric to use for argkmin. - For a list of available metrics, see the documentation of - :class:`~sklearn.metrics.DistanceMetric`. - - chunk_size : int, default=None, - The number of vectors per chunk. If None (default) looks-up in - scikit-learn configuration for `pairwise_dist_chunk_size`, - and use 256 if it is not set. - - metric_kwargs : dict, default=None - Keyword arguments to pass to specified metric function. - - strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None - The chunking strategy defining which dataset parallelization are made on. - - For both strategies the computations happens with two nested loops, - respectively on chunks of X and chunks of Y. - Strategies differs on which loop (outer or inner) is made to run - in parallel with the Cython `prange` construct: - - - 'parallel_on_X' dispatches chunks of X uniformly on threads. - Each thread then iterates on all the chunks of Y. This strategy is - embarrassingly parallel and comes with no datastructures synchronisation. - - - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. - Each thread processes all the chunks of X in turn. This strategy is - a sequence of embarrassingly parallel subtasks (the inner loop on Y - chunks) with intermediate datastructures synchronisation at each - iteration of the sequential outer loop on X chunks. - - - 'auto' relies on a simple heuristic to choose between - 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, - 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` - is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity - for parallelism and is therefore more efficient despite the synchronization - step at each iteration of the outer loop on chunks of `X`. - - - None (default) looks-up in scikit-learn configuration for - `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. - - return_distance : boolean, default=False - Return distances between each X vector and its - argkmin if set to True. - - Returns - ------- - If return_distance=False: - - argkmin_indices : ndarray of shape (n_samples_X, k) - Indices of the argkmin for each vector in X. - - If return_distance=True: - - argkmin_distances : ndarray of shape (n_samples_X, k) - Distances to the argkmin for each vector in X. - - argkmin_indices : ndarray of shape (n_samples_X, k) - Indices of the argkmin for each vector in X. - - Notes - ----- - This classmethod is responsible for introspecting the arguments - values to dispatch to the most appropriate implementation of - :class:`PairwiseDistancesArgKmin`. - - This allows decoupling the API entirely from the implementation details - whilst maintaining RAII: all temporarily allocated datastructures necessary - for the concrete implementation are therefore freed when this classmethod - returns. - """ - # Note (jjerphan): Some design thoughts for future extensions. - # This factory comes to handle specialisations for the given arguments. - # For future work, this might can be an entrypoint to specialise operations - # for various backend and/or hardware and/or datatypes, and/or fused - # {sparse, dense}-datasetspair etc. - if X.dtype == Y.dtype == np.float64: - return PairwiseDistancesArgKmin64.compute( - X=X, - Y=Y, - k=k, - metric=metric, - chunk_size=chunk_size, - metric_kwargs=metric_kwargs, - strategy=strategy, - return_distance=return_distance, - ) - raise ValueError( - f"Only 64bit float datasets are supported at this time, " - f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." - ) - - -class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): - """Compute radius-based neighbors for two sets of vectors. - - For each row-vector X[i] of the queries X, find all the indices j of - row-vectors in Y such that: - - dist(X[i], Y[j]) <= radius - - The distance function `dist` depends on the values of the `metric` - and `metric_kwargs` parameters. - - This class is not meant to be instanciated, one should only use - its :meth:`compute` classmethod which handles allocation and - deallocation consistently. - """ - - @classmethod - def compute( - cls, - X, - Y, - radius, - metric="euclidean", - chunk_size=None, - metric_kwargs=None, - strategy=None, - return_distance=False, - sort_results=False, - ): - """Return the results of the reduction for the given arguments. - - Parameters - ---------- - X : ndarray or CSR matrix of shape (n_samples_X, n_features) - Input data. - - Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) - Input data. - - radius : float - The radius defining the neighborhood. - - metric : str, default='euclidean' - The distance metric to use. - For a list of available metrics, see the documentation of - :class:`~sklearn.metrics.DistanceMetric`. - - chunk_size : int, default=None, - The number of vectors per chunk. If None (default) looks-up in - scikit-learn configuration for `pairwise_dist_chunk_size`, - and use 256 if it is not set. - - metric_kwargs : dict, default=None - Keyword arguments to pass to specified metric function. - - strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None - The chunking strategy defining which dataset parallelization are made on. - - For both strategies the computations happens with two nested loops, - respectively on chunks of X and chunks of Y. - Strategies differs on which loop (outer or inner) is made to run - in parallel with the Cython `prange` construct: - - - 'parallel_on_X' dispatches chunks of X uniformly on threads. - Each thread then iterates on all the chunks of Y. This strategy is - embarrassingly parallel and comes with no datastructures synchronisation. - - - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. - Each thread processes all the chunks of X in turn. This strategy is - a sequence of embarrassingly parallel subtasks (the inner loop on Y - chunks) with intermediate datastructures synchronisation at each - iteration of the sequential outer loop on X chunks. - - - 'auto' relies on a simple heuristic to choose between - 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, - 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` - is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity - for parallelism and is therefore more efficient despite the synchronization - step at each iteration of the outer loop on chunks of `X`. - - - None (default) looks-up in scikit-learn configuration for - `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. - - return_distance : boolean, default=False - Return distances between each X vector and its neighbors if set to True. - - sort_results : boolean, default=False - Sort results with respect to distances between each X vector and its - neighbors if set to True. - - Returns - ------- - If return_distance=False: - - neighbors_indices : ndarray of n_samples_X ndarray - Indices of the neighbors for each vector in X. - - If return_distance=True: - - neighbors_indices : ndarray of n_samples_X ndarray - Indices of the neighbors for each vector in X. - - neighbors_distances : ndarray of n_samples_X ndarray - Distances to the neighbors for each vector in X. - - Notes - ----- - This public classmethod is responsible for introspecting the arguments - values to dispatch to the private dtype-specialized implementation of - :class:`PairwiseDistancesRadiusNeighborhood`. - - All temporarily allocated datastructures necessary for the concrete - implementation are therefore freed when this classmethod returns. - - This allows entirely decoupling the API entirely from the - implementation details whilst maintaining RAII. - """ - # Note (jjerphan): Some design thoughts for future extensions. - # This factory comes to handle specialisations for the given arguments. - # For future work, this might can be an entrypoint to specialise operations - # for various backend and/or hardware and/or datatypes, and/or fused - # {sparse, dense}-datasetspair etc. - if X.dtype == Y.dtype == np.float64: - return PairwiseDistancesRadiusNeighborhood64.compute( - X=X, - Y=Y, - radius=radius, - metric=metric, - chunk_size=chunk_size, - metric_kwargs=metric_kwargs, - strategy=strategy, - sort_results=sort_results, - return_distance=return_distance, - ) - raise ValueError( - f"Only 64bit float datasets are supported at this time, " - f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." - ) - -class PairwiseDistancesArgKmin(PairwiseDistancesReduction): - """Compute the argkmin of row vectors of X on the ones of Y. - - For each row vector of X, computes the indices of k first the rows - vectors of Y with the smallest distances. - - PairwiseDistancesArgKmin is typically used to perform - bruteforce k-nearest neighbors queries. - - This class is not meant to be instanciated, one should only use - its :meth:`compute` classmethod which handles allocation and - deallocation consistently. - """ - - @classmethod - def compute( - cls, - X, - Y, - k, - metric="euclidean", - chunk_size=None, - metric_kwargs=None, - strategy=None, - return_distance=False, - ): - """Compute the argkmin reduction. - - Parameters - ---------- - X : ndarray or CSR matrix of shape (n_samples_X, n_features) - Input data. - - Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) - Input data. - - k : int - The k for the argkmin reduction. - - metric : str, default='euclidean' - The distance metric to use for argkmin. - For a list of available metrics, see the documentation of - :class:`~sklearn.metrics.DistanceMetric`. - - chunk_size : int, default=None, - The number of vectors per chunk. If None (default) looks-up in - scikit-learn configuration for `pairwise_dist_chunk_size`, - and use 256 if it is not set. - - metric_kwargs : dict, default=None - Keyword arguments to pass to specified metric function. - - strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None - The chunking strategy defining which dataset parallelization are made on. - - For both strategies the computations happens with two nested loops, - respectively on chunks of X and chunks of Y. - Strategies differs on which loop (outer or inner) is made to run - in parallel with the Cython `prange` construct: - - - 'parallel_on_X' dispatches chunks of X uniformly on threads. - Each thread then iterates on all the chunks of Y. This strategy is - embarrassingly parallel and comes with no datastructures synchronisation. - - - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. - Each thread processes all the chunks of X in turn. This strategy is - a sequence of embarrassingly parallel subtasks (the inner loop on Y - chunks) with intermediate datastructures synchronisation at each - iteration of the sequential outer loop on X chunks. - - - 'auto' relies on a simple heuristic to choose between - 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, - 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` - is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity - for parallelism and is therefore more efficient despite the synchronization - step at each iteration of the outer loop on chunks of `X`. - - - None (default) looks-up in scikit-learn configuration for - `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. - - return_distance : boolean, default=False - Return distances between each X vector and its - argkmin if set to True. - - Returns - ------- - If return_distance=False: - - argkmin_indices : ndarray of shape (n_samples_X, k) - Indices of the argkmin for each vector in X. - - If return_distance=True: - - argkmin_distances : ndarray of shape (n_samples_X, k) - Distances to the argkmin for each vector in X. - - argkmin_indices : ndarray of shape (n_samples_X, k) - Indices of the argkmin for each vector in X. - - Notes - ----- - This classmethod is responsible for introspecting the arguments - values to dispatch to the most appropriate implementation of - :class:`PairwiseDistancesArgKmin`. - - This allows decoupling the API entirely from the implementation details - whilst maintaining RAII: all temporarily allocated datastructures necessary - for the concrete implementation are therefore freed when this classmethod - returns. - """ - # Note (jjerphan): Some design thoughts for future extensions. - # This factory comes to handle specialisations for the given arguments. - # For future work, this might can be an entrypoint to specialise operations - # for various backend and/or hardware and/or datatypes, and/or fused - # {sparse, dense}-datasetspair etc. - if X.dtype == Y.dtype == np.float64: - return PairwiseDistancesArgKmin64.compute( - X=X, - Y=Y, - k=k, - metric=metric, - chunk_size=chunk_size, - metric_kwargs=metric_kwargs, - strategy=strategy, - return_distance=return_distance, - ) - raise ValueError( - f"Only 64bit float datasets are supported at this time, " - f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." - ) - - -##################### -# dtype-specialized implementations - -cpdef DTYPE_t[::1] _sqeuclidean_row_norms64( - const DTYPE_t[:, ::1] X, - ITYPE_t num_threads, -): - """Compute the squared euclidean norm of the rows of X in parallel. - - This is faster than using np.einsum("ij, ij->i") even when using a single thread. - """ - cdef: - # Casting for X to remove the const qualifier is needed because APIs - # exposed via scipy.linalg.cython_blas aren't reflecting the arguments' - # const qualifier. - # See: https://github.com/scipy/scipy/issues/14262 - DTYPE_t * X_ptr = &X[0, 0] - ITYPE_t i = 0 - ITYPE_t n = X.shape[0] - ITYPE_t d = X.shape[1] - DTYPE_t[::1] squared_row_norms = np.empty(n, dtype=DTYPE) - - for i in prange(n, schedule='static', nogil=True, num_threads=num_threads): - squared_row_norms[i] = _dot(d, X_ptr + i * d, 1, X_ptr + i * d, 1) - - return squared_row_norms - -cdef class GEMMTermComputer64: - """Component for `FastEuclidean*` variant wrapping the logic for the call to GEMM. - - `FastEuclidean*` classes internally compute the squared Euclidean distances between - chunks of vectors X_c and Y_c using the following decomposition: - - - ||X_c_i - Y_c_j||² = ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² - - - This helper class is in charge of wrapping the common logic to compute - the middle term `- 2 X_c_i.Y_c_j^T` with a call to GEMM, which has a high - arithmetic intensity. - """ - cdef: - const DTYPE_t[:, ::1] X - const DTYPE_t[:, ::1] Y - - ITYPE_t effective_n_threads - ITYPE_t chunks_n_threads - ITYPE_t dist_middle_terms_chunks_size - ITYPE_t n_features - ITYPE_t chunk_size - - # Buffers for the `-2 * X_c @ Y_c.T` term computed via GEMM - vector[vector[DTYPE_t]] dist_middle_terms_chunks - - def __init__(self, - DTYPE_t[:, ::1] X, - DTYPE_t[:, ::1] Y, - ITYPE_t effective_n_threads, - ITYPE_t chunks_n_threads, - ITYPE_t dist_middle_terms_chunks_size, - ITYPE_t n_features, - ITYPE_t chunk_size, - ): - self.X = X - self.Y = Y - self.effective_n_threads = effective_n_threads - self.chunks_n_threads = chunks_n_threads - self.dist_middle_terms_chunks_size = dist_middle_terms_chunks_size - self.n_features = n_features - self.chunk_size = chunk_size - - self.dist_middle_terms_chunks = vector[vector[DTYPE_t]](self.effective_n_threads) - - - cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - return - - cdef void _parallel_on_X_parallel_init(self, ITYPE_t thread_num) nogil: - self.dist_middle_terms_chunks[thread_num].resize(self.dist_middle_terms_chunks_size) - - cdef void _parallel_on_X_init_chunk( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - return - - cdef void _parallel_on_Y_init(self) nogil: - for thread_num in range(self.chunks_n_threads): - self.dist_middle_terms_chunks[thread_num].resize( - self.dist_middle_terms_chunks_size - ) - - cdef void _parallel_on_Y_parallel_init( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - return - - cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num - ) nogil: - return - - cdef DTYPE_t * _compute_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - cdef: - ITYPE_t i, j - DTYPE_t squared_dist_i_j - const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :] - const DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :] - DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data() - - # Careful: LDA, LDB and LDC are given for F-ordered arrays - # in BLAS documentations, for instance: - # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html #noqa - # - # Here, we use their counterpart values to work with C-ordered arrays. - BLAS_Order order = RowMajor - BLAS_Trans ta = NoTrans - BLAS_Trans tb = Trans - ITYPE_t m = X_c.shape[0] - ITYPE_t n = Y_c.shape[0] - ITYPE_t K = X_c.shape[1] - DTYPE_t alpha = - 2. - # Casting for A and B to remove the const is needed because APIs exposed via - # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier. - # See: https://github.com/scipy/scipy/issues/14262 - DTYPE_t * A = &X_c[0, 0] - DTYPE_t * B = &Y_c[0, 0] - ITYPE_t lda = X_c.shape[1] - ITYPE_t ldb = X_c.shape[1] - DTYPE_t beta = 0. - ITYPE_t ldc = Y_c.shape[0] - - # dist_middle_terms = `-2 * X_c @ Y_c.T` - _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc) - - return dist_middle_terms - -cdef class PairwiseDistancesReduction64: - """Base 64bit implementation of PairwiseDistancesReduction.""" - - cdef: - readonly DatasetsPair datasets_pair - - # The number of threads that can be used is stored in effective_n_threads. - # - # The number of threads to use in the parallelization strategy - # (i.e. parallel_on_X or parallel_on_Y) can be smaller than effective_n_threads: - # for small datasets, fewer threads might be needed to loop over pair of chunks. - # - # Hence, the number of threads that _will_ be used for looping over chunks - # is stored in chunks_n_threads, allowing solely using what we need. - # - # Thus, an invariant is: - # - # chunks_n_threads <= effective_n_threads - # - ITYPE_t effective_n_threads - ITYPE_t chunks_n_threads - - ITYPE_t n_samples_chunk, chunk_size - - ITYPE_t n_samples_X, X_n_samples_chunk, X_n_chunks, X_n_samples_last_chunk - ITYPE_t n_samples_Y, Y_n_samples_chunk, Y_n_chunks, Y_n_samples_last_chunk - - bint execute_in_parallel_on_Y - - def __init__( - self, - DatasetsPair datasets_pair, - chunk_size=None, - strategy=None, - ): - cdef: - ITYPE_t n_samples_chunk, X_n_full_chunks, Y_n_full_chunks - - if chunk_size is None: - chunk_size = get_config().get("pairwise_dist_chunk_size", 256) - - self.chunk_size = check_scalar(chunk_size, "chunk_size", Integral, min_val=20) - - self.effective_n_threads = _openmp_effective_n_threads() - - self.datasets_pair = datasets_pair - - self.n_samples_X = datasets_pair.n_samples_X() - self.X_n_samples_chunk = min(self.n_samples_X, self.chunk_size) - X_n_full_chunks = self.n_samples_X // self.X_n_samples_chunk - X_n_samples_remainder = self.n_samples_X % self.X_n_samples_chunk - self.X_n_chunks = X_n_full_chunks + (X_n_samples_remainder != 0) - - if X_n_samples_remainder != 0: - self.X_n_samples_last_chunk = X_n_samples_remainder - else: - self.X_n_samples_last_chunk = self.X_n_samples_chunk - - self.n_samples_Y = datasets_pair.n_samples_Y() - self.Y_n_samples_chunk = min(self.n_samples_Y, self.chunk_size) - Y_n_full_chunks = self.n_samples_Y // self.Y_n_samples_chunk - Y_n_samples_remainder = self.n_samples_Y % self.Y_n_samples_chunk - self.Y_n_chunks = Y_n_full_chunks + (Y_n_samples_remainder != 0) - - if Y_n_samples_remainder != 0: - self.Y_n_samples_last_chunk = Y_n_samples_remainder - else: - self.Y_n_samples_last_chunk = self.Y_n_samples_chunk - - if strategy is None: - strategy = get_config().get("pairwise_dist_parallel_strategy", 'auto') - - if strategy not in ('parallel_on_X', 'parallel_on_Y', 'auto'): - raise RuntimeError(f"strategy must be 'parallel_on_X, 'parallel_on_Y', " - f"or 'auto', but currently strategy='{self.strategy}'.") - - if strategy == 'auto': - # This is a simple heuristic whose constant for the - # comparison has been chosen based on experiments. - if 4 * self.chunk_size * self.effective_n_threads < self.n_samples_X: - strategy = 'parallel_on_X' - else: - strategy = 'parallel_on_Y' - - self.execute_in_parallel_on_Y = strategy == "parallel_on_Y" - - # Not using less, not using more. - self.chunks_n_threads = min( - self.Y_n_chunks if self.execute_in_parallel_on_Y else self.X_n_chunks, - self.effective_n_threads, - ) - - @final - cdef void _parallel_on_X(self) nogil: - """Compute the pairwise distances of each row vector of X on Y - by parallelizing computation on the outer loop on chunks of X - and reduce them. - - This strategy dispatches chunks of Y uniformly on threads. - Each thread processes all the chunks of X in turn. This strategy is - a sequence of embarrassingly parallel subtasks (the inner loop on Y - chunks) with intermediate datastructures synchronisation at each - iteration of the sequential outer loop on X chunks. - - Private datastructures are modified internally by threads. - - Private template methods can be implemented on subclasses to - interact with those datastructures at various stages. - """ - cdef: - ITYPE_t Y_start, Y_end, X_start, X_end, X_chunk_idx, Y_chunk_idx - ITYPE_t thread_num - - with nogil, parallel(num_threads=self.chunks_n_threads): - thread_num = _openmp_thread_num() - - # Allocating thread datastructures - self._parallel_on_X_parallel_init(thread_num) - - for X_chunk_idx in prange(self.X_n_chunks, schedule='static'): - X_start = X_chunk_idx * self.X_n_samples_chunk - if X_chunk_idx == self.X_n_chunks - 1: - X_end = X_start + self.X_n_samples_last_chunk - else: - X_end = X_start + self.X_n_samples_chunk - - # Reinitializing thread datastructures for the new X chunk - # If necessary, upcast X[X_start:X_end] to 64bit - self._parallel_on_X_init_chunk(thread_num, X_start, X_end) - - for Y_chunk_idx in range(self.Y_n_chunks): - Y_start = Y_chunk_idx * self.Y_n_samples_chunk - if Y_chunk_idx == self.Y_n_chunks - 1: - Y_end = Y_start + self.Y_n_samples_last_chunk - else: - Y_end = Y_start + self.Y_n_samples_chunk - - # If necessary, upcast Y[Y_start:Y_end] to 64bit - self._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - - self._compute_and_reduce_distances_on_chunks( - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - - # Adjusting thread datastructures on the full pass on Y - self._parallel_on_X_prange_iter_finalize(thread_num, X_start, X_end) - - # end: for X_chunk_idx - - # Deallocating thread datastructures - self._parallel_on_X_parallel_finalize(thread_num) - - # end: with nogil, parallel - return - - @final - cdef void _parallel_on_Y(self) nogil: - """Compute the pairwise distances of each row vector of X on Y - by parallelizing computation on the inner loop on chunks of Y - and reduce them. - - This strategy dispatches chunks of Y uniformly on threads. - Each thread processes all the chunks of X in turn. This strategy is - a sequence of embarrassingly parallel subtasks (the inner loop on Y - chunks) with intermediate datastructures synchronisation at each - iteration of the sequential outer loop on X chunks. - - Private datastructures are modified internally by threads. - - Private template methods can be implemented on subclasses to - interact with those datastructures at various stages. - """ - cdef: - ITYPE_t Y_start, Y_end, X_start, X_end, X_chunk_idx, Y_chunk_idx - ITYPE_t thread_num - - # Allocating datastructures shared by all threads - self._parallel_on_Y_init() - - for X_chunk_idx in range(self.X_n_chunks): - X_start = X_chunk_idx * self.X_n_samples_chunk - if X_chunk_idx == self.X_n_chunks - 1: - X_end = X_start + self.X_n_samples_last_chunk - else: - X_end = X_start + self.X_n_samples_chunk - - with nogil, parallel(num_threads=self.chunks_n_threads): - thread_num = _openmp_thread_num() - - # Initializing datastructures used in this thread - # If necessary, upcast X[X_start:X_end] to 64bit - self._parallel_on_Y_parallel_init(thread_num, X_start, X_end) - - for Y_chunk_idx in prange(self.Y_n_chunks, schedule='static'): - Y_start = Y_chunk_idx * self.Y_n_samples_chunk - if Y_chunk_idx == self.Y_n_chunks - 1: - Y_end = Y_start + self.Y_n_samples_last_chunk - else: - Y_end = Y_start + self.Y_n_samples_chunk - - # If necessary, upcast Y[Y_start:Y_end] to 64bit - self._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - - self._compute_and_reduce_distances_on_chunks( - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - # end: prange - - # Note: we don't need a _parallel_on_Y_finalize similarly. - # This can be introduced if needed. - - # end: with nogil, parallel - - # Synchronizing the thread datastructures with the main ones - self._parallel_on_Y_synchronize(X_start, X_end) - - # end: for X_chunk_idx - # Deallocating temporary datastructures and adjusting main datastructures - self._parallel_on_Y_finalize() - return - - # Placeholder methods which have to be implemented - - cdef void _compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - """Compute the pairwise distances on two chunks of X and Y and reduce them. - - This is THE core computational method of PairwiseDistanceReductions64. - This must be implemented in subclasses agnostically from the parallelization - strategies. - """ - return - - def _finalize_results(self, bint return_distance): - """Callback adapting datastructures before returning results. - - This must be implemented in subclasses. - """ - return None - - # Placeholder methods which can be implemented - - cdef void compute_exact_distances(self) nogil: - """Convert rank-preserving distances to exact distances or recompute them.""" - return - - cdef void _parallel_on_X_parallel_init( - self, - ITYPE_t thread_num, - ) nogil: - """Allocate datastructures used in a thread given its number.""" - return - - cdef void _parallel_on_X_init_chunk( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - """Initialise datastructures used in a thread given its number.""" - return - - cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - """Initialise datastructures just before the _compute_and_reduce_distances_on_chunks. - - This is eventually used to upcast X[X_start:X_end] to 64bit. - """ - return - - cdef void _parallel_on_X_prange_iter_finalize( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - """Interact with datastructures after a reduction on chunks.""" - return - - cdef void _parallel_on_X_parallel_finalize( - self, - ITYPE_t thread_num - ) nogil: - """Interact with datastructures after executing all the reductions.""" - return - - cdef void _parallel_on_Y_init( - self, - ) nogil: - """Allocate datastructures used in all threads.""" - return - - cdef void _parallel_on_Y_parallel_init( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - """Initialise datastructures used in a thread given its number.""" - return - - cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - """Initialise datastructures just before the _compute_and_reduce_distances_on_chunks. - - This is eventually used to upcast Y[Y_start:Y_end] to 64bit. - """ - return - - cdef void _parallel_on_Y_synchronize( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - """Update thread datastructures before leaving a parallel region.""" - return - - cdef void _parallel_on_Y_finalize( - self, - ) nogil: - """Update datastructures after executing all the reductions.""" - return - -cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): - """64bit implementation of PairwiseDistancesArgKmin.""" - - cdef: - ITYPE_t k - - ITYPE_t[:, ::1] argkmin_indices - DTYPE_t[:, ::1] argkmin_distances - - # Used as array of pointers to private datastructures used in threads. - DTYPE_t ** heaps_r_distances_chunks - ITYPE_t ** heaps_indices_chunks - - @classmethod - def compute( - cls, - X, - Y, - ITYPE_t k, - str metric="euclidean", - chunk_size=None, - dict metric_kwargs=None, - str strategy=None, - bint return_distance=False, - ): - """Compute the argkmin reduction. - - This classmethod is responsible for introspecting the arguments - values to dispatch to the most appropriate implementation of - :class:`PairwiseDistancesArgKmin64`. - - This allows decoupling the API entirely from the implementation details - whilst maintaining RAII: all temporarily allocated datastructures necessary - for the concrete implementation are therefore freed when this classmethod - returns. - - No instance should directly be created outside of this class method. - """ - if ( - metric in ("euclidean", "sqeuclidean") - and not issparse(X) - and not issparse(Y) - ): - # Specialized implementation with improved arithmetic intensity - # and vector instructions (SIMD) by processing several vectors - # at time to leverage a call to the BLAS GEMM routine as explained - # in more details in the docstring. - use_squared_distances = metric == "sqeuclidean" - pda = FastEuclideanPairwiseDistancesArgKmin64( - X=X, Y=Y, k=k, - use_squared_distances=use_squared_distances, - chunk_size=chunk_size, - strategy=strategy, - metric_kwargs=metric_kwargs, - ) - else: - # Fall back on a generic implementation that handles most scipy - # metrics by computing the distances between 2 vectors at a time. - pda = PairwiseDistancesArgKmin64( - datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), - k=k, - chunk_size=chunk_size, - strategy=strategy, - ) - - # Limit the number of threads in second level of nested parallelism for BLAS - # to avoid threads over-subscription (in GEMM for instance). - with threadpool_limits(limits=1, user_api="blas"): - if pda.execute_in_parallel_on_Y: - pda._parallel_on_Y() - else: - pda._parallel_on_X() - - return pda._finalize_results(return_distance) - - def __init__( - self, - DatasetsPair datasets_pair, - chunk_size=None, - strategy=None, - ITYPE_t k=1, - ): - super().__init__( - datasets_pair=datasets_pair, - chunk_size=chunk_size, - strategy=strategy, - ) - self.k = check_scalar(k, "k", Integral, min_val=1) - - # Allocating pointers to datastructures but not the datastructures themselves. - # There are as many pointers as effective threads. - # - # For the sake of explicitness: - # - when parallelizing on X, the pointers of those heaps are referencing - # (with proper offsets) addresses of the two main heaps (see below) - # - when parallelizing on Y, the pointers of those heaps are referencing - # small heaps which are thread-wise-allocated and whose content will be - # merged with the main heaps'. - self.heaps_r_distances_chunks = malloc( - sizeof(DTYPE_t *) * self.chunks_n_threads - ) - self.heaps_indices_chunks = malloc( - sizeof(ITYPE_t *) * self.chunks_n_threads - ) - - # Main heaps which will be returned as results by `PairwiseDistancesArgKmin64.compute`. - self.argkmin_indices = np.full((self.n_samples_X, self.k), 0, dtype=ITYPE) - self.argkmin_distances = np.full((self.n_samples_X, self.k), DBL_MAX, dtype=DTYPE) - - def __dealloc__(self): - if self.heaps_indices_chunks is not NULL: - free(self.heaps_indices_chunks) - - if self.heaps_r_distances_chunks is not NULL: - free(self.heaps_r_distances_chunks) - - cdef void _compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - cdef: - ITYPE_t i, j - ITYPE_t n_samples_X = X_end - X_start - ITYPE_t n_samples_Y = Y_end - Y_start - DTYPE_t *heaps_r_distances = self.heaps_r_distances_chunks[thread_num] - ITYPE_t *heaps_indices = self.heaps_indices_chunks[thread_num] - - # Pushing the distances and their associated indices on a heap - # which by construction will keep track of the argkmin. - for i in range(n_samples_X): - for j in range(n_samples_Y): - heap_push( - heaps_r_distances + i * self.k, - heaps_indices + i * self.k, - self.k, - self.datasets_pair.surrogate_dist(X_start + i, Y_start + j), - Y_start + j, - ) - - cdef void _parallel_on_X_init_chunk( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - # As this strategy is embarrassingly parallel, we can set each - # thread's heaps pointer to the proper position on the main heaps. - self.heaps_r_distances_chunks[thread_num] = &self.argkmin_distances[X_start, 0] - self.heaps_indices_chunks[thread_num] = &self.argkmin_indices[X_start, 0] - - @final - cdef void _parallel_on_X_prange_iter_finalize( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - cdef: - ITYPE_t idx, jdx - - # Sorting the main heaps portion associated to `X[X_start:X_end]` - # in ascending order w.r.t the distances. - for idx in range(X_end - X_start): - simultaneous_sort( - self.heaps_r_distances_chunks[thread_num] + idx * self.k, - self.heaps_indices_chunks[thread_num] + idx * self.k, - self.k - ) - - cdef void _parallel_on_Y_init( - self, - ) nogil: - cdef: - # Maximum number of scalar elements (the last chunks can be smaller) - ITYPE_t heaps_size = self.X_n_samples_chunk * self.k - ITYPE_t thread_num - - # The allocation is done in parallel for data locality purposes: this way - # the heaps used in each threads are allocated in pages which are closer - # to the CPU core used by the thread. - # See comments about First Touch Placement Policy: - # https://www.openmp.org/wp-content/uploads/openmp-webinar-vanderPas-20210318.pdf #noqa - for thread_num in prange(self.chunks_n_threads, schedule='static', nogil=True, - num_threads=self.chunks_n_threads): - # As chunks of X are shared across threads, so must their - # heaps. To solve this, each thread has its own heaps - # which are then synchronised back in the main ones. - self.heaps_r_distances_chunks[thread_num] = malloc( - heaps_size * sizeof(DTYPE_t) - ) - self.heaps_indices_chunks[thread_num] = malloc( - heaps_size * sizeof(ITYPE_t) - ) - - cdef void _parallel_on_Y_parallel_init( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - # Initialising heaps (memset can't be used here) - for idx in range(self.X_n_samples_chunk * self.k): - self.heaps_r_distances_chunks[thread_num][idx] = DBL_MAX - self.heaps_indices_chunks[thread_num][idx] = -1 - - @final - cdef void _parallel_on_Y_synchronize( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - cdef: - ITYPE_t idx, jdx, thread_num - with nogil, parallel(num_threads=self.effective_n_threads): - # Synchronising the thread heaps with the main heaps. - # This is done in parallel sample-wise (no need for locks). - # - # This might break each thread's data locality as each heap which - # was allocated in a thread is being now being used in several threads. - # - # Still, this parallel pattern has shown to be efficient in practice. - for idx in prange(X_end - X_start, schedule="static"): - for thread_num in range(self.chunks_n_threads): - for jdx in range(self.k): - heap_push( - &self.argkmin_distances[X_start + idx, 0], - &self.argkmin_indices[X_start + idx, 0], - self.k, - self.heaps_r_distances_chunks[thread_num][idx * self.k + jdx], - self.heaps_indices_chunks[thread_num][idx * self.k + jdx], - ) - - cdef void _parallel_on_Y_finalize( - self, - ) nogil: - cdef: - ITYPE_t idx, thread_num - - with nogil, parallel(num_threads=self.chunks_n_threads): - # Deallocating temporary datastructures - for thread_num in prange(self.chunks_n_threads, schedule='static'): - free(self.heaps_r_distances_chunks[thread_num]) - free(self.heaps_indices_chunks[thread_num]) - - # Sorting the main in ascending order w.r.t the distances. - # This is done in parallel sample-wise (no need for locks). - for idx in prange(self.n_samples_X, schedule='static'): - simultaneous_sort( - &self.argkmin_distances[idx, 0], - &self.argkmin_indices[idx, 0], - self.k, - ) - return - - cdef void compute_exact_distances(self) nogil: - cdef: - ITYPE_t i, j - ITYPE_t[:, ::1] Y_indices = self.argkmin_indices - DTYPE_t[:, ::1] distances = self.argkmin_distances - for i in prange(self.n_samples_X, schedule='static', nogil=True, - num_threads=self.effective_n_threads): - for j in range(self.k): - distances[i, j] = self.datasets_pair.distance_metric._rdist_to_dist( - # Guard against eventual -0., causing nan production. - max(distances[i, j], 0.) - ) - - def _finalize_results(self, bint return_distance=False): - if return_distance: - # We need to recompute distances because we relied on - # surrogate distances for the reduction. - self.compute_exact_distances() - - # Values are returned identically to the way `KNeighborsMixin.kneighbors` - # returns values. This is counter-intuitive but this allows not using - # complex adaptations where `PairwiseDistancesArgKmin64.compute` is called. - return np.asarray(self.argkmin_distances), np.asarray(self.argkmin_indices) - - return np.asarray(self.argkmin_indices) - - -cdef class FastEuclideanPairwiseDistancesArgKmin64(PairwiseDistancesArgKmin64): - """EuclideanDistance-specialized 64bit implementation for PairwiseDistancesArgKmin.""" - cdef: - GEMMTermComputer64 gemm_term_computer - const DTYPE_t[::1] X_norm_squared - const DTYPE_t[::1] Y_norm_squared - - bint use_squared_distances - - @classmethod - def is_usable_for(cls, X, Y, metric) -> bool: - return (PairwiseDistancesArgKmin64.is_usable_for(X, Y, metric) and - not _in_unstable_openblas_configuration()) - - def __init__( - self, - X, - Y, - ITYPE_t k, - bint use_squared_distances=False, - chunk_size=None, - strategy=None, - metric_kwargs=None, - ): - if ( - metric_kwargs is not None and - len(metric_kwargs) > 0 and - "Y_norm_squared" not in metric_kwargs - ): - warnings.warn( - f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " - f"usable for this case (FastEuclideanPairwiseDistancesArgKmin) and will be ignored.", - UserWarning, - stacklevel=3, - ) - - super().__init__( - # The datasets pair here is used for exact distances computations - datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), - chunk_size=chunk_size, - strategy=strategy, - k=k, - ) - # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair - cdef: - DenseDenseDatasetsPair datasets_pair = ( - self.datasets_pair - ) - ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk - - self.gemm_term_computer = GEMMTermComputer64( - datasets_pair.X, - datasets_pair.Y, - self.effective_n_threads, - self.chunks_n_threads, - dist_middle_terms_chunks_size, - n_features=datasets_pair.X.shape[1], - chunk_size=self.chunk_size, - ) - - if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: - self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") - else: - self.Y_norm_squared = _sqeuclidean_row_norms64(datasets_pair.Y, self.effective_n_threads) - - # Do not recompute norms if datasets are identical. - self.X_norm_squared = ( - self.Y_norm_squared if X is Y else - _sqeuclidean_row_norms64(datasets_pair.X, self.effective_n_threads) - ) - self.use_squared_distances = use_squared_distances - - @final - cdef void compute_exact_distances(self) nogil: - if not self.use_squared_distances: - PairwiseDistancesArgKmin64.compute_exact_distances(self) - - @final - cdef void _parallel_on_X_parallel_init( - self, - ITYPE_t thread_num, - ) nogil: - PairwiseDistancesArgKmin64._parallel_on_X_parallel_init(self, thread_num) - self.gemm_term_computer._parallel_on_X_parallel_init(thread_num) - - - @final - cdef void _parallel_on_X_init_chunk( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - PairwiseDistancesArgKmin64._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) - self.gemm_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end) - - - @final - cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - PairwiseDistancesArgKmin64._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - self, - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - self.gemm_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - X_start, X_end, Y_start, Y_end, thread_num, - ) - - - @final - cdef void _parallel_on_Y_init( - self, - ) nogil: - cdef ITYPE_t thread_num - PairwiseDistancesArgKmin64._parallel_on_Y_init(self) - self.gemm_term_computer._parallel_on_Y_init() - - - @final - cdef void _parallel_on_Y_parallel_init( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - PairwiseDistancesArgKmin64._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) - self.gemm_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end) - - - @final - cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - PairwiseDistancesArgKmin64._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - self, - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - self.gemm_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - X_start, X_end, Y_start, Y_end, thread_num - ) - - - @final - cdef void _compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - cdef: - ITYPE_t i, j - DTYPE_t squared_dist_i_j - ITYPE_t n_X = X_end - X_start - ITYPE_t n_Y = Y_end - Y_start - DTYPE_t * dist_middle_terms = self.gemm_term_computer._compute_distances_on_chunks( - X_start, X_end, Y_start, Y_end, thread_num - ) - DTYPE_t * heaps_r_distances = self.heaps_r_distances_chunks[thread_num] - ITYPE_t * heaps_indices = self.heaps_indices_chunks[thread_num] - - - # Pushing the distance and their associated indices on heaps - # which keep tracks of the argkmin. - for i in range(n_X): - for j in range(n_Y): - heap_push( - heaps_r_distances + i * self.k, - heaps_indices + i * self.k, - self.k, - # Using the squared euclidean distance as the rank-preserving distance: - # - # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² - # - ( - self.X_norm_squared[i + X_start] + - dist_middle_terms[i * n_Y + j] + - self.Y_norm_squared[j + Y_start] - ), - j + Y_start, - ) - - -cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): - """64bit implementation of PairwiseDistancesArgKmin.""" - - cdef: - DTYPE_t radius - - # DistanceMetric compute rank-preserving surrogate distance via rdist - # which are proxies necessitating less computations. - # We get the equivalent for the radius to be able to compare it against - # vectors' rank-preserving surrogate distances. - DTYPE_t r_radius - - # Neighbors indices and distances are returned as np.ndarrays of np.ndarrays. - # - # For this implementation, we want resizable buffers which we will wrap - # into numpy arrays at the end. std::vector comes as a handy container - # for interacting efficiently with resizable buffers. - # - # Though it is possible to access their buffer address with - # std::vector::data, they can't be stolen: buffers lifetime - # is tied to their std::vector and are deallocated when - # std::vectors are. - # - # To solve this, we dynamically allocate std::vectors and then - # encapsulate them in a StdVectorSentinel responsible for - # freeing them when the associated np.ndarray is freed. - # - # Shared pointers (defined via shared_ptr) are use for safer memory management. - # Unique pointers (defined via unique_ptr) can't be used as datastructures - # are shared across threads for parallel_on_X; see _parallel_on_X_init_chunk. - shared_ptr[vector[vector[ITYPE_t]]] neigh_indices - shared_ptr[vector[vector[DTYPE_t]]] neigh_distances - - # Used as array of pointers to private datastructures used in threads. - vector[shared_ptr[vector[vector[ITYPE_t]]]] neigh_indices_chunks - vector[shared_ptr[vector[vector[DTYPE_t]]]] neigh_distances_chunks - - bint sort_results - - @classmethod - def compute( - cls, - X, - Y, - DTYPE_t radius, - str metric="euclidean", - chunk_size=None, - dict metric_kwargs=None, - str strategy=None, - bint return_distance=False, - bint sort_results=False, - ): - """Compute the radius-neighbors reduction. - - This classmethod is responsible for introspecting the arguments - values to dispatch to the most appropriate implementation of - :class:`PairwiseDistancesRadiusNeighborhood64`. - - This allows decoupling the API entirely from the implementation details - whilst maintaining RAII: all temporarily allocated datastructures necessary - for the concrete implementation are therefore freed when this classmethod - returns. - - No instance should directly be created outside of this class method. - """ - if ( - metric in ("euclidean", "sqeuclidean") - and not issparse(X) - and not issparse(Y) - ): - # Specialized implementation with improved arithmetic intensity - # and vector instructions (SIMD) by processing several vectors - # at time to leverage a call to the BLAS GEMM routine as explained - # in more details in the docstring. - use_squared_distances = metric == "sqeuclidean" - pda = FastEuclideanPairwiseDistancesRadiusNeighborhood64( - X=X, Y=Y, radius=radius, - use_squared_distances=use_squared_distances, - chunk_size=chunk_size, - metric_kwargs=metric_kwargs, - strategy=strategy, - sort_results=sort_results, - ) - else: - # Fall back on a generic implementation that handles most scipy - # metrics by computing the distances between 2 vectors at a time. - pda = PairwiseDistancesRadiusNeighborhood64( - datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), - radius=radius, - chunk_size=chunk_size, - metric_kwargs=metric_kwargs, - strategy=strategy, - sort_results=sort_results, - ) - - # Limit the number of threads in second level of nested parallelism for BLAS - # to avoid threads over-subscription (in GEMM for instance). - with threadpool_limits(limits=1, user_api="blas"): - if pda.execute_in_parallel_on_Y: - pda._parallel_on_Y() - else: - pda._parallel_on_X() - - return pda._finalize_results(return_distance) - - - def __init__( - self, - DatasetsPair datasets_pair, - DTYPE_t radius, - chunk_size=None, - strategy=None, - sort_results=False, - metric_kwargs=None, - ): - super().__init__( - datasets_pair=datasets_pair, - chunk_size=chunk_size, - strategy=strategy, - ) - - self.radius = check_scalar(radius, "radius", Real, min_val=0) - self.r_radius = self.datasets_pair.distance_metric._dist_to_rdist(radius) - self.sort_results = sort_results - - # Allocating pointers to datastructures but not the datastructures themselves. - # There are as many pointers as effective threads. - # - # For the sake of explicitness: - # - when parallelizing on X, the pointers of those heaps are referencing - # self.neigh_distances and self.neigh_indices - # - when parallelizing on Y, the pointers of those heaps are referencing - # std::vectors of std::vectors which are thread-wise-allocated and whose - # content will be merged into self.neigh_distances and self.neigh_indices. - self.neigh_distances_chunks = vector[shared_ptr[vector[vector[DTYPE_t]]]]( - self.chunks_n_threads - ) - self.neigh_indices_chunks = vector[shared_ptr[vector[vector[ITYPE_t]]]]( - self.chunks_n_threads - ) - - # Temporary datastructures which will be coerced to numpy arrays on before - # PairwiseDistancesRadiusNeighborhood.compute "return" and will be then freed. - self.neigh_distances = make_shared[vector[vector[DTYPE_t]]](self.n_samples_X) - self.neigh_indices = make_shared[vector[vector[ITYPE_t]]](self.n_samples_X) - - cdef void _compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - cdef: - ITYPE_t i, j - DTYPE_t r_dist_i_j - - for i in range(X_start, X_end): - for j in range(Y_start, Y_end): - r_dist_i_j = self.datasets_pair.surrogate_dist(i, j) - if r_dist_i_j <= self.r_radius: - deref(self.neigh_distances_chunks[thread_num])[i].push_back(r_dist_i_j) - deref(self.neigh_indices_chunks[thread_num])[i].push_back(j) - - def _finalize_results(self, bint return_distance=False): - if return_distance: - # We need to recompute distances because we relied on - # surrogate distances for the reduction. - self.compute_exact_distances() - return ( - coerce_vectors_to_nd_arrays(self.neigh_distances), - coerce_vectors_to_nd_arrays(self.neigh_indices), - ) - - return coerce_vectors_to_nd_arrays(self.neigh_indices) - - cdef void _parallel_on_X_init_chunk( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - - # As this strategy is embarrassingly parallel, we can set the - # thread vectors' pointers to the main vectors'. - self.neigh_distances_chunks[thread_num] = self.neigh_distances - self.neigh_indices_chunks[thread_num] = self.neigh_indices - - @final - cdef void _parallel_on_X_prange_iter_finalize( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - cdef: - ITYPE_t idx, jdx - - # Sorting neighbors for each query vector of X - if self.sort_results: - for idx in range(X_start, X_end): - simultaneous_sort( - deref(self.neigh_distances)[idx].data(), - deref(self.neigh_indices)[idx].data(), - deref(self.neigh_indices)[idx].size() - ) - - cdef void _parallel_on_Y_init( - self, - ) nogil: - cdef: - ITYPE_t thread_num - # As chunks of X are shared across threads, so must datastructures to avoid race - # conditions: each thread has its own vectors of n_samples_X vectors which are - # then merged back in the main n_samples_X vectors. - for thread_num in range(self.chunks_n_threads): - self.neigh_distances_chunks[thread_num] = make_shared[vector[vector[DTYPE_t]]](self.n_samples_X) - self.neigh_indices_chunks[thread_num] = make_shared[vector[vector[ITYPE_t]]](self.n_samples_X) - - @final - cdef void _merge_vectors( - self, - ITYPE_t idx, - ITYPE_t num_threads, - ) nogil: - cdef: - ITYPE_t thread_num - ITYPE_t idx_n_elements = 0 - ITYPE_t last_element_idx = deref(self.neigh_indices)[idx].size() - - # Resizing buffers only once for the given number of elements. - for thread_num in range(num_threads): - idx_n_elements += deref(self.neigh_distances_chunks[thread_num])[idx].size() - - deref(self.neigh_distances)[idx].resize(last_element_idx + idx_n_elements) - deref(self.neigh_indices)[idx].resize(last_element_idx + idx_n_elements) - - # Moving the elements by range using the range first element - # as the reference for the insertion. - for thread_num in range(num_threads): - move( - deref(self.neigh_distances_chunks[thread_num])[idx].begin(), - deref(self.neigh_distances_chunks[thread_num])[idx].end(), - deref(self.neigh_distances)[idx].begin() + last_element_idx - ) - move( - deref(self.neigh_indices_chunks[thread_num])[idx].begin(), - deref(self.neigh_indices_chunks[thread_num])[idx].end(), - deref(self.neigh_indices)[idx].begin() + last_element_idx - ) - last_element_idx += deref(self.neigh_distances_chunks[thread_num])[idx].size() - - - cdef void _parallel_on_Y_finalize( - self, - ) nogil: - cdef: - ITYPE_t idx, jdx, thread_num, idx_n_element, idx_current - - with nogil, parallel(num_threads=self.effective_n_threads): - # Merge vectors used in threads into the main ones. - # This is done in parallel sample-wise (no need for locks) - # using dynamic scheduling because we might not have - # the same number of neighbors for each query vector. - for idx in prange(self.n_samples_X, schedule='static'): - self._merge_vectors(idx, self.chunks_n_threads) - - # The content of the vector have been std::moved. - # Hence they can't be used anymore and can be deleted. - # Their deletion is carried out automatically as the - # implementation relies on shared pointers. - - # Sort in parallel in ascending order w.r.t the distances if requested. - if self.sort_results: - for idx in prange(self.n_samples_X, schedule='static'): - simultaneous_sort( - deref(self.neigh_distances)[idx].data(), - deref(self.neigh_indices)[idx].data(), - deref(self.neigh_indices)[idx].size() - ) - - return - - cdef void compute_exact_distances(self) nogil: - """Convert rank-preserving distances to pairwise distances in parallel.""" - cdef: - ITYPE_t i, j - - for i in prange(self.n_samples_X, nogil=True, schedule='static', - num_threads=self.effective_n_threads): - for j in range(deref(self.neigh_indices)[i].size()): - deref(self.neigh_distances)[i][j] = ( - self.datasets_pair.distance_metric._rdist_to_dist( - # Guard against eventual -0., causing nan production. - max(deref(self.neigh_distances)[i][j], 0.) - ) - ) - - -cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesRadiusNeighborhood64): - """EuclideanDistance-specialized 64bit implementation for PairwiseDistancesRadiusNeighborhood.""" - cdef: - GEMMTermComputer64 gemm_term_computer - const DTYPE_t[::1] X_norm_squared - const DTYPE_t[::1] Y_norm_squared - - bint use_squared_distances - - @classmethod - def is_usable_for(cls, X, Y, metric) -> bool: - return (PairwiseDistancesRadiusNeighborhood64.is_usable_for(X, Y, metric) - and not _in_unstable_openblas_configuration()) - - def __init__( - self, - X, - Y, - DTYPE_t radius, - bint use_squared_distances=False, - chunk_size=None, - strategy=None, - sort_results=False, - metric_kwargs=None, - ): - if ( - metric_kwargs is not None and - len(metric_kwargs) > 0 and - "Y_norm_squared" not in metric_kwargs - ): - warnings.warn( - f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " - f"usable for this case (FastEuclideanPairwiseDistancesRadiusNeighborhood) and will be ignored.", - UserWarning, - stacklevel=3, - ) - - super().__init__( - # The datasets pair here is used for exact distances computations - datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), - radius=radius, - chunk_size=chunk_size, - strategy=strategy, - sort_results=sort_results, - metric_kwargs=metric_kwargs, - ) - # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair - cdef: - DenseDenseDatasetsPair datasets_pair = self.datasets_pair - ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk - - self.gemm_term_computer = GEMMTermComputer64( - datasets_pair.X, - datasets_pair.Y, - self.effective_n_threads, - self.chunks_n_threads, - dist_middle_terms_chunks_size, - n_features=datasets_pair.X.shape[1], - chunk_size=self.chunk_size, - ) - - if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: - self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") - else: - self.Y_norm_squared = _sqeuclidean_row_norms64(datasets_pair.Y, self.effective_n_threads) - - # Do not recompute norms if datasets are identical. - self.X_norm_squared = ( - self.Y_norm_squared if X is Y else - _sqeuclidean_row_norms64(datasets_pair.X, self.effective_n_threads) - ) - self.use_squared_distances = use_squared_distances - - if use_squared_distances: - # In this specialisation and this setup, the value passed to the radius is - # already considered to be the adapted radius, so we overwrite it. - self.r_radius = radius - - @final - cdef void _parallel_on_X_parallel_init( - self, - ITYPE_t thread_num, - ) nogil: - PairwiseDistancesRadiusNeighborhood64._parallel_on_X_parallel_init(self, thread_num) - self.gemm_term_computer._parallel_on_X_parallel_init(thread_num) - - @final - cdef void _parallel_on_X_init_chunk( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - PairwiseDistancesRadiusNeighborhood64._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) - self.gemm_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end) - - @final - cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - PairwiseDistancesRadiusNeighborhood64._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - self, - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - self.gemm_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - X_start, X_end, Y_start, Y_end, thread_num, - ) - - @final - cdef void _parallel_on_Y_init( - self, - ) nogil: - cdef ITYPE_t thread_num - PairwiseDistancesRadiusNeighborhood64._parallel_on_Y_init(self) - self.gemm_term_computer._parallel_on_Y_init() - - @final - cdef void _parallel_on_Y_parallel_init( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - PairwiseDistancesRadiusNeighborhood64._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) - self.gemm_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end) - - @final - cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - PairwiseDistancesRadiusNeighborhood64._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - self, - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - self.gemm_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - X_start, X_end, Y_start, Y_end, thread_num - ) - - @final - cdef void compute_exact_distances(self) nogil: - if not self.use_squared_distances: - PairwiseDistancesRadiusNeighborhood64.compute_exact_distances(self) - - @final - cdef void _compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - cdef: - ITYPE_t i, j - DTYPE_t squared_dist_i_j - ITYPE_t n_X = X_end - X_start - ITYPE_t n_Y = Y_end - Y_start - DTYPE_t *dist_middle_terms = self.gemm_term_computer._compute_distances_on_chunks( - X_start, X_end, Y_start, Y_end, thread_num - ) - - # Pushing the distance and their associated indices in vectors. - for i in range(n_X): - for j in range(n_Y): - # Using the squared euclidean distance as the rank-preserving distance: - # - # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² - # - squared_dist_i_j = ( - self.X_norm_squared[i + X_start] - + dist_middle_terms[i * n_Y + j] - + self.Y_norm_squared[j + Y_start] - ) - if squared_dist_i_j <= self.r_radius: - deref(self.neigh_distances_chunks[thread_num])[i + X_start].push_back(squared_dist_i_j) - deref(self.neigh_indices_chunks[thread_num])[i + X_start].push_back(j + Y_start) diff --git a/sklearn/metrics/_pairwise_distances_reduction/__init__.py b/sklearn/metrics/_pairwise_distances_reduction/__init__.py index 943afc720d1fe..78fe01cea5948 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/__init__.py +++ b/sklearn/metrics/_pairwise_distances_reduction/__init__.py @@ -85,6 +85,7 @@ # to optimally handle the Euclidean distance case using the Generalized Matrix # Multiplication (see the docstring of :class:`GEMMTermComputer64` for details). + from ._dispatcher import ( PairwiseDistancesReduction, PairwiseDistances, From f573a596fe9ab0d85e7186c945228c078ef58ff5 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 26 Aug 2022 13:30:20 +0200 Subject: [PATCH 05/36] fixup! Merge branch 'main' into feat/pairwise_distances-pdr-backend Switch condition for safety --- sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 6f4e46288413d..0c67089806bb6 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -148,7 +148,7 @@ class PairwiseDistances(PairwiseDistancesReduction): @classmethod def is_usable_for(cls, X, Y, metric) -> bool: # TODO: support float32 - return X.dtype == np.float64 and super().is_usable_for(X, Y, metric) + return super().is_usable_for(X, Y, metric) and X.dtype == np.float64 @classmethod def compute( From 8e17871357137284cace94902476d8d5b0a4f6b2 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 2 Sep 2022 22:50:24 +0200 Subject: [PATCH 06/36] Do not offset by X_start and Y_start --- .../_pairwise_distances.pyx | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx index b6cd8affcc763..8105c9adac1ef 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx @@ -141,7 +141,7 @@ cdef class PairwiseDistances64(BaseDistanceReducer64): ) def _finalize_results(self): - # If Y is X, then catastrophic cancellation might + # If X is Y, then catastrophic cancellation might # have occurred for computations of term on the diagonal # which must be null. We enforce nullity of those term # by zeroing the diagonal. @@ -166,7 +166,7 @@ cdef class PairwiseDistances64(BaseDistanceReducer64): for i in range(X_start, X_end): for j in range(Y_start, Y_end): dist_i_j = self.datasets_pair.dist(i, j) - self.pairwise_distances_matrix[X_start + i, Y_start + j] = dist_i_j + self.pairwise_distances_matrix[i, j] = dist_i_j cdef class EuclideanPairwiseDistances64(PairwiseDistances64): @@ -332,10 +332,10 @@ cdef class EuclideanPairwiseDistances64(PairwiseDistances64): # # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² # - self.pairwise_distances_matrix[i + X_start, j + Y_start] = ( - self.X_norm_squared[i + X_start] + self.pairwise_distances_matrix[i, j] = ( + self.X_norm_squared[i] + dist_middle_terms[i * n_Y + j] - + self.Y_norm_squared[j + Y_start] + + self.Y_norm_squared[j] ) def _finalize_results(self): From 1f1a3ce27119efcb8df987af3b6295c651501eb4 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 2 Sep 2022 23:03:18 +0200 Subject: [PATCH 07/36] Do not progated metric_kwargs unneedlessly --- .../_pairwise_distances_reduction/_pairwise_distances.pyx | 3 --- 1 file changed, 3 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx index 8105c9adac1ef..557e7d7a89ae9 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx @@ -106,7 +106,6 @@ cdef class PairwiseDistances64(BaseDistanceReducer64): pdr = PairwiseDistances64( datasets_pair=DatasetsPair64.get_for(X, Y, metric, metric_kwargs), chunk_size=chunk_size, - metric_kwargs=metric_kwargs, strategy=strategy, ) @@ -127,7 +126,6 @@ cdef class PairwiseDistances64(BaseDistanceReducer64): chunk_size=None, strategy=None, sort_results=False, - metric_kwargs=None, ): super().__init__( datasets_pair=datasets_pair, @@ -203,7 +201,6 @@ cdef class EuclideanPairwiseDistances64(PairwiseDistances64): datasets_pair=DatasetsPair64.get_for(X, Y, metric="euclidean"), chunk_size=chunk_size, strategy=strategy, - metric_kwargs=metric_kwargs, ) # X and Y are checked by the DatasetsPair64 implemented as a DenseDenseDatasetsPair64 cdef: From 045a7b2579fdec69cfc5821940e9569ffb5180e8 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 19 Sep 2022 09:36:57 +0200 Subject: [PATCH 08/36] Use the proper vectors' indices --- .../_pairwise_distances_reduction/_pairwise_distances.pyx | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx index 557e7d7a89ae9..5048ba92080ac 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx @@ -316,7 +316,6 @@ cdef class EuclideanPairwiseDistances64(PairwiseDistances64): ) nogil: cdef: ITYPE_t i, j - DTYPE_t squared_dist_i_j ITYPE_t n_X = X_end - X_start ITYPE_t n_Y = Y_end - Y_start DTYPE_t *dist_middle_terms = self.gemm_term_computer._compute_distances_on_chunks( @@ -329,10 +328,10 @@ cdef class EuclideanPairwiseDistances64(PairwiseDistances64): # # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² # - self.pairwise_distances_matrix[i, j] = ( - self.X_norm_squared[i] + self.pairwise_distances_matrix[X_start + i, Y_start + j] = ( + self.X_norm_squared[X_start + i] + dist_middle_terms[i * n_Y + j] - + self.Y_norm_squared[j] + + self.Y_norm_squared[Y_start + j] ) def _finalize_results(self): From 2b70db8b87531ce5463f5d5eabbbeed32a297002 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 19 Sep 2022 10:49:54 +0200 Subject: [PATCH 09/36] Port PairwiseDistances to support float32 datasets --- .gitignore | 2 + setup.cfg | 2 + .../_dispatcher.py | 19 ++-- .../_pairwise_distances.pxd | 24 ----- .../_pairwise_distances.pxd.tp | 25 +++++ ...stances.pyx => _pairwise_distances.pyx.tp} | 97 +++++++++++-------- .../_pairwise_distances_reduction/setup.py | 3 + 7 files changed, 100 insertions(+), 72 deletions(-) delete mode 100644 sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd create mode 100644 sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp rename sklearn/metrics/_pairwise_distances_reduction/{_pairwise_distances.pyx => _pairwise_distances.pyx.tp} (77%) diff --git a/.gitignore b/.gitignore index 24f562af3df15..d093aa71a8265 100644 --- a/.gitignore +++ b/.gitignore @@ -95,5 +95,7 @@ sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx +sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd +sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx diff --git a/setup.cfg b/setup.cfg index 81fbbffadb233..2759f8adaf437 100644 --- a/setup.cfg +++ b/setup.cfg @@ -79,6 +79,8 @@ ignore = sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx + sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd + sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 49c6808670bbd..81f63615adb05 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -15,6 +15,7 @@ ) from ._pairwise_distances import ( PairwiseDistances64, + PairwiseDistances32, ) from ._radius_neighborhood import ( RadiusNeighbors64, @@ -145,11 +146,6 @@ class PairwiseDistances(BaseDistanceReductionDispatcher): deallocation consistently. """ - @classmethod - def is_usable_for(cls, X, Y, metric) -> bool: - # TODO: support float32 - return super().is_usable_for(X, Y, metric) and X.dtype == np.float64 - @classmethod def compute( cls, @@ -243,8 +239,19 @@ def compute( metric_kwargs=metric_kwargs, strategy=strategy, ) + + if X.dtype == Y.dtype == np.float32: + return PairwiseDistances32.compute( + X=X, + Y=Y, + metric=metric, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + ) + raise ValueError( - "Only 64bit float datasets are supported at this time, " + "Only float64 or float32 datasets pairs are supported at this time, " f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." ) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd deleted file mode 100644 index 632c3cd533c40..0000000000000 --- a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd +++ /dev/null @@ -1,24 +0,0 @@ -cimport numpy as cnp - -from ._base cimport BaseDistanceReducer64 -from ._gemm_term_computer cimport GEMMTermComputer64 - -from ...utils._typedefs cimport DTYPE_t - -cnp.import_array() - -cdef class PairwiseDistances64(BaseDistanceReducer64): - """64bit implementation of PairwiseDistances.""" - - cdef: - DTYPE_t[:, ::1] pairwise_distances_matrix - - -cdef class EuclideanPairwiseDistances64(PairwiseDistances64): - """EuclideanDistance-specialized 64bit implementation for PairwiseDistances.""" - cdef: - GEMMTermComputer64 gemm_term_computer - const DTYPE_t[::1] X_norm_squared - const DTYPE_t[::1] Y_norm_squared - - bint use_squared_distances diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp new file mode 100644 index 0000000000000..9e457ecb218cc --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp @@ -0,0 +1,25 @@ +from ...utils._typedefs cimport DTYPE_t + +{{for name_suffix in ['32', '64']}} + +from ._base cimport BaseDistanceReducer{{name_suffix}} +from ._gemm_term_computer cimport GEMMTermComputer{{name_suffix}} + + +cdef class PairwiseDistances{{name_suffix}}(BaseDistanceReducer{{name_suffix}}): + """{{name_suffix}}bit implementation of PairwiseDistances.""" + + cdef: + DTYPE_t[:, ::1] pairwise_distances_matrix + + +cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suffix}}): + """EuclideanDistance-specialized {{name_suffix}}bit implementation for PairwiseDistances.""" + cdef: + GEMMTermComputer{{name_suffix}} gemm_term_computer + const DTYPE_t[::1] X_norm_squared + const DTYPE_t[::1] Y_norm_squared + + bint use_squared_distances + +{{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp similarity index 77% rename from sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx rename to sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp index 5048ba92080ac..587096323f663 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp @@ -1,27 +1,13 @@ cimport numpy as cnp - from cython cimport final - -from ._base cimport ( - BaseDistanceReducer64, - _sqeuclidean_row_norms64, -) - -from ._datasets_pair cimport ( - DatasetsPair64, - DenseDenseDatasetsPair64, -) - -from ._gemm_term_computer cimport GEMMTermComputer64 - from ...utils._typedefs cimport ITYPE_t, DTYPE_t import numpy as np import warnings from scipy.sparse import issparse -from sklearn.utils import _in_unstable_openblas_configuration -from sklearn.utils.fixes import threadpool_limits, sp_version, parse_version +from ...utils import _in_unstable_openblas_configuration +from ...utils.fixes import threadpool_limits, sp_version, parse_version from ...utils._typedefs import ITYPE, DTYPE cnp.import_array() @@ -52,9 +38,23 @@ def _precompute_metric_params(X, Y, metric=None, **kwds): return {"VI": VI} return {} +{{for name_suffix in ['64', '32']}} + +from ._base cimport ( + BaseDistanceReducer{{name_suffix}}, + _sqeuclidean_row_norms{{name_suffix}}, +) + +from ._datasets_pair cimport ( + DatasetsPair{{name_suffix}}, + DenseDenseDatasetsPair{{name_suffix}}, +) + +from ._gemm_term_computer cimport GEMMTermComputer{{name_suffix}} + -cdef class PairwiseDistances64(BaseDistanceReducer64): - """64bit implementation of PairwiseDistances.""" +cdef class PairwiseDistances{{name_suffix}}(BaseDistanceReducer{{name_suffix}}): + """{{name_suffix}}bit implementation of PairwiseDistances.""" @classmethod def compute( @@ -70,7 +70,7 @@ cdef class PairwiseDistances64(BaseDistanceReducer64): This classmethod is responsible for introspecting the arguments values to dispatch to the most appropriate implementation of - :class:`PairwiseDistances64`. + :class:`PairwiseDistances{{name_suffix}}`. This allows decoupling the API entirely from the implementation details whilst maintaining RAII: all temporarily allocated datastructures necessary @@ -89,7 +89,7 @@ cdef class PairwiseDistances64(BaseDistanceReducer64): # at time to leverage a call to the BLAS GEMM routine as explained # in more details in the docstring. use_squared_distances = metric == "sqeuclidean" - pdr = EuclideanPairwiseDistances64( + pdr = EuclideanPairwiseDistances{{name_suffix}}( X=X, Y=Y, use_squared_distances=use_squared_distances, chunk_size=chunk_size, @@ -103,8 +103,8 @@ cdef class PairwiseDistances64(BaseDistanceReducer64): # Fall back on a generic implementation that handles most scipy # metrics by computing the distances between 2 vectors at a time. - pdr = PairwiseDistances64( - datasets_pair=DatasetsPair64.get_for(X, Y, metric, metric_kwargs), + pdr = PairwiseDistances{{name_suffix}}( + datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs), chunk_size=chunk_size, strategy=strategy, ) @@ -122,7 +122,7 @@ cdef class PairwiseDistances64(BaseDistanceReducer64): def __init__( self, - DatasetsPair64 datasets_pair, + DatasetsPair{{name_suffix}} datasets_pair, chunk_size=None, strategy=None, sort_results=False, @@ -167,12 +167,12 @@ cdef class PairwiseDistances64(BaseDistanceReducer64): self.pairwise_distances_matrix[i, j] = dist_i_j -cdef class EuclideanPairwiseDistances64(PairwiseDistances64): - """EuclideanDistance-specialized 64bit implementation for PairwiseDistances.""" +cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suffix}}): + """EuclideanDistance-specialized {{name_suffix}}bit implementation for PairwiseDistances.""" @classmethod def is_usable_for(cls, X, Y, metric) -> bool: - return (PairwiseDistances64.is_usable_for(X, Y, metric) + return (PairwiseDistances{{name_suffix}}.is_usable_for(X, Y, metric) and not _in_unstable_openblas_configuration()) def __init__( @@ -191,23 +191,28 @@ cdef class EuclideanPairwiseDistances64(PairwiseDistances64): ): warnings.warn( f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " - f"usable for this case (EuclideanPairwiseDistances64) and will be ignored.", + f"usable for this case (EuclideanPairwiseDistances{{name_suffix}}) and will be ignored.", UserWarning, stacklevel=3, ) super().__init__( # The datasets pair here is used for exact distances computations - datasets_pair=DatasetsPair64.get_for(X, Y, metric="euclidean"), + datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric="euclidean"), chunk_size=chunk_size, strategy=strategy, ) - # X and Y are checked by the DatasetsPair64 implemented as a DenseDenseDatasetsPair64 + # X and Y are checked by the DatasetsPair{{name_suffix}} implemented as + # a DenseDenseDatasetsPair{{name_suffix}} cdef: - DenseDenseDatasetsPair64 datasets_pair = self.datasets_pair - ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk + DenseDenseDatasetsPair{{name_suffix}} datasets_pair = ( + self.datasets_pair + ) + ITYPE_t dist_middle_terms_chunks_size = ( + self.Y_n_samples_chunk * self.X_n_samples_chunk + ) - self.gemm_term_computer = GEMMTermComputer64( + self.gemm_term_computer = GEMMTermComputer{{name_suffix}}( datasets_pair.X, datasets_pair.Y, self.effective_n_threads, @@ -220,12 +225,18 @@ cdef class EuclideanPairwiseDistances64(PairwiseDistances64): if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") else: - self.Y_norm_squared = _sqeuclidean_row_norms64(datasets_pair.Y, self.effective_n_threads) + self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}( + datasets_pair.Y, + self.effective_n_threads, + ) # Do not recompute norms if datasets are identical. self.X_norm_squared = ( self.Y_norm_squared if self.datasets_pair.X_is_Y else - _sqeuclidean_row_norms64(datasets_pair.X, self.effective_n_threads) + _sqeuclidean_row_norms{{name_suffix}}( + datasets_pair.X, + self.effective_n_threads, + ) ) self.use_squared_distances = use_squared_distances @@ -235,7 +246,7 @@ cdef class EuclideanPairwiseDistances64(PairwiseDistances64): self, ITYPE_t thread_num, ) nogil: - PairwiseDistances64._parallel_on_X_parallel_init(self, thread_num) + PairwiseDistances{{name_suffix}}._parallel_on_X_parallel_init(self, thread_num) self.gemm_term_computer._parallel_on_X_parallel_init(thread_num) @final @@ -245,7 +256,7 @@ cdef class EuclideanPairwiseDistances64(PairwiseDistances64): ITYPE_t X_start, ITYPE_t X_end, ) nogil: - PairwiseDistances64._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) + PairwiseDistances{{name_suffix}}._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) self.gemm_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end) @final @@ -257,7 +268,7 @@ cdef class EuclideanPairwiseDistances64(PairwiseDistances64): ITYPE_t Y_end, ITYPE_t thread_num, ) nogil: - PairwiseDistances64._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + PairwiseDistances{{name_suffix}}._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( self, X_start, X_end, Y_start, Y_end, @@ -272,7 +283,7 @@ cdef class EuclideanPairwiseDistances64(PairwiseDistances64): self, ) nogil: cdef ITYPE_t thread_num - PairwiseDistances64._parallel_on_Y_init(self) + PairwiseDistances{{name_suffix}}._parallel_on_Y_init(self) self.gemm_term_computer._parallel_on_Y_init() @final @@ -282,7 +293,7 @@ cdef class EuclideanPairwiseDistances64(PairwiseDistances64): ITYPE_t X_start, ITYPE_t X_end, ) nogil: - PairwiseDistances64._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) + PairwiseDistances{{name_suffix}}._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) self.gemm_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end) @final @@ -294,7 +305,7 @@ cdef class EuclideanPairwiseDistances64(PairwiseDistances64): ITYPE_t Y_end, ITYPE_t thread_num, ) nogil: - PairwiseDistances64._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + PairwiseDistances{{name_suffix}}._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( self, X_start, X_end, Y_start, Y_end, @@ -335,10 +346,12 @@ cdef class EuclideanPairwiseDistances64(PairwiseDistances64): ) def _finalize_results(self): - distance_matrix = PairwiseDistances64._finalize_results(self) + distance_matrix = PairwiseDistances{{name_suffix}}._finalize_results(self) # Squared Euclidean distances have been used for efficiency. # We remap them to Euclidean distances here before finalizing # results. if not self.use_squared_distances: return np.sqrt(distance_matrix) - return PairwiseDistances64._finalize_results(self) + return PairwiseDistances{{name_suffix}}._finalize_results(self) + +{{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/setup.py b/sklearn/metrics/_pairwise_distances_reduction/setup.py index c56ca7e7364f2..d3ce860776b27 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/setup.py +++ b/sklearn/metrics/_pairwise_distances_reduction/setup.py @@ -21,6 +21,8 @@ def configuration(parent_package="", top_path=None): "sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp", "sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp", "sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd.tp", + "sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp", + "sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp", "sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx.tp", "sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd.tp", ] @@ -33,6 +35,7 @@ def configuration(parent_package="", top_path=None): "_base.pyx", "_pairwise_distances.pyx", "_argkmin.pyx", + "_pairwise_distances.pyx", "_radius_neighborhood.pyx", ] From f9431cc0da9733b2e6299efe9557b14c1082ff9f Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 19 Sep 2022 11:16:47 +0200 Subject: [PATCH 10/36] Simplify indices --- .../_pairwise_distances.pyx.tp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp index 587096323f663..45d60f96e3156 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp @@ -327,23 +327,23 @@ cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suf ) nogil: cdef: ITYPE_t i, j - ITYPE_t n_X = X_end - X_start - ITYPE_t n_Y = Y_end - Y_start + ITYPE_t pair_index = 0 DTYPE_t *dist_middle_terms = self.gemm_term_computer._compute_distances_on_chunks( X_start, X_end, Y_start, Y_end, thread_num ) - for i in range(n_X): - for j in range(n_Y): + for i in range(X_start, X_end): + for j in range(Y_start, Y_end): # Using the squared euclidean distance as the rank-preserving distance: # # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² # - self.pairwise_distances_matrix[X_start + i, Y_start + j] = ( - self.X_norm_squared[X_start + i] - + dist_middle_terms[i * n_Y + j] - + self.Y_norm_squared[Y_start + j] + self.pairwise_distances_matrix[i, j] = ( + self.X_norm_squared[i] + + dist_middle_terms[pair_index] + + self.Y_norm_squared[j] ) + pair_index += 1 def _finalize_results(self): distance_matrix = PairwiseDistances{{name_suffix}}._finalize_results(self) From 24cfd04cc75ed90ad2a5dbeac5d3f148031ff4a8 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 19 Sep 2022 11:17:52 +0200 Subject: [PATCH 11/36] Adapt implementations for a previous 'sqeuclidean' specification --- sklearn/metrics/pairwise.py | 6 ++++++ sklearn/neighbors/_nca.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 302bee3808818..859b4e53ad71c 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -1958,6 +1958,12 @@ def pairwise_distances( ) if PairwiseDistances.is_usable_for(X, X if Y is None else Y, metric=metric): + # This is an adaptor for one "sqeuclidean" specification. + # For this backend, we can directly use "sqeuclidean". + if kwds.get("squared", False) and metric == "euclidean": + metric = "sqeuclidean" + kwds = {} + return PairwiseDistances.compute( X, X if Y is None else Y, metric=metric, metric_kwargs=kwds ) diff --git a/sklearn/neighbors/_nca.py b/sklearn/neighbors/_nca.py index cf6e520767718..872e54be6d5bb 100644 --- a/sklearn/neighbors/_nca.py +++ b/sklearn/neighbors/_nca.py @@ -486,7 +486,7 @@ def _loss_grad_lbfgs(self, transformation, X, same_class_mask, sign=1.0): X_embedded = np.dot(X, transformation.T) # (n_samples, n_components) # Compute softmax distances - p_ij = pairwise_distances(X_embedded, squared=True) + p_ij = pairwise_distances(X_embedded, metric="sqeuclidean") np.fill_diagonal(p_ij, np.inf) p_ij = softmax(-p_ij) # (n_samples, n_samples) From 527a43c6ff6f8e97cd9175fab60e9dd049342d40 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 19 Sep 2022 11:18:45 +0200 Subject: [PATCH 12/36] TST Downcast distance matrix to float32 --- sklearn/manifold/tests/test_t_sne.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/manifold/tests/test_t_sne.py b/sklearn/manifold/tests/test_t_sne.py index 51239d7ffd2cc..2fc45b04c5ddd 100644 --- a/sklearn/manifold/tests/test_t_sne.py +++ b/sklearn/manifold/tests/test_t_sne.py @@ -164,7 +164,7 @@ def test_binary_search_neighbors(): desired_perplexity = 25.0 random_state = check_random_state(0) data = random_state.randn(n_samples, 2).astype(np.float32, copy=False) - distances = pairwise_distances(data) + distances = pairwise_distances(data).astype(np.float32, copy=False) P1 = _binary_search_perplexity(distances, desired_perplexity, verbose=0) # Test that when we use all the neighbors the results are identical From d26fafb3cd7cc67b97ad5bbb134f51fea641c78c Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 19 Sep 2022 11:49:02 +0200 Subject: [PATCH 13/36] Adapt instanciation --- .../metrics/_pairwise_distances_reduction/_dispatcher.py | 7 ++++++- .../_pairwise_distances.pyx.tp | 9 ++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 81f63615adb05..837a0c78d341a 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -141,11 +141,16 @@ class PairwiseDistances(BaseDistanceReductionDispatcher): The distance function `dist` depends on the values of the `metric` and `metric_kwargs` parameters. - This class is not meant to be instanciated, one should only use + This class is not meant to be instantiated, one should only use its :meth:`compute` classmethod which handles allocation and deallocation consistently. """ + @classmethod + def is_usable_for(cls, X, Y, metric) -> bool: + Y = X if Y is None else Y + return super().is_usable_for(X, Y, metric) + @classmethod def compute( cls, diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp index 45d60f96e3156..049a68554be3b 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp @@ -98,7 +98,14 @@ cdef class PairwiseDistances{{name_suffix}}(BaseDistanceReducer{{name_suffix}}): ) else: # Precompute data-derived distance metric parameters - params = _precompute_metric_params(X, Y, metric=metric, **metric_kwargs) + metric_kwargs = {} if metric_kwargs is None else metric_kwargs + + params = _precompute_metric_params( + X, + Y, + metric=metric, + **metric_kwargs, + ) metric_kwargs.update(**params) # Fall back on a generic implementation that handles most scipy From 81b74e1e71c42f1dcda6e48c6d02a75d59b61ce0 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 19 Sep 2022 11:49:53 +0200 Subject: [PATCH 14/36] Use PairwiseDistances as a back-end for haversine_distances --- sklearn/metrics/pairwise.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 859b4e53ad71c..b2b4c0ab345d5 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -869,6 +869,10 @@ def haversine_distances(X, Y=None): array([[ 0. , 11099.54035582], [11099.54035582, 0. ]]) """ + + if PairwiseDistances.is_usable_for(X, Y, metric="haversine"): + return PairwiseDistances.compute(X, Y, metric="haversine") + from ..metrics import DistanceMetric return DistanceMetric.get_metric("haversine").pairwise(X, Y) From 0aa688e7e05501cc15bed3c81a31e8398787e433 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 19 Sep 2022 11:50:07 +0200 Subject: [PATCH 15/36] Use PairwiseDistances as a back-end for manhattan_distances --- sklearn/metrics/pairwise.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index b2b4c0ab345d5..ccf498413ff05 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -938,6 +938,9 @@ def manhattan_distances(X, Y=None, *, sum_over_features=True): """ X, Y = check_pairwise_arrays(X, Y) + if sum_over_features and PairwiseDistances.is_usable_for(X, Y, metric="manhattan"): + return PairwiseDistances.compute(X, Y, metric="manhattan") + if issparse(X) or issparse(Y): if not sum_over_features: raise TypeError( From 0006764475e3bcdf3ca726e168e446eeb4a31948 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 19 Sep 2022 12:04:02 +0200 Subject: [PATCH 16/36] Use PairwiseDistances as a back-end for euclidean_distances Comes with minor adaptations --- .../_pairwise_distances.pyx.tp | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp index 049a68554be3b..dd68f620fd282 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp @@ -152,7 +152,7 @@ cdef class PairwiseDistances{{name_suffix}}(BaseDistanceReducer{{name_suffix}}): # by zeroing the diagonal. distance_matrix = np.asarray(self.pairwise_distances_matrix) if self.datasets_pair.X_is_Y: - np.fill_diagonal(distance_matrix, 0) + np.fill_diagonal(distance_matrix, 0.) return distance_matrix @@ -166,12 +166,10 @@ cdef class PairwiseDistances{{name_suffix}}(BaseDistanceReducer{{name_suffix}}): ) nogil: cdef: ITYPE_t i, j - DTYPE_t dist_i_j for i in range(X_start, X_end): for j in range(Y_start, Y_end): - dist_i_j = self.datasets_pair.dist(i, j) - self.pairwise_distances_matrix[i, j] = dist_i_j + self.pairwise_distances_matrix[i, j] = self.datasets_pair.dist(i, j) cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suffix}}): @@ -229,7 +227,7 @@ cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suf chunk_size=self.chunk_size, ) - if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: + if metric_kwargs is not None and metric_kwargs.get("Y_norm_squared", None) is not None: self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") else: self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}( @@ -237,14 +235,18 @@ cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suf self.effective_n_threads, ) - # Do not recompute norms if datasets are identical. - self.X_norm_squared = ( - self.Y_norm_squared if self.datasets_pair.X_is_Y else - _sqeuclidean_row_norms{{name_suffix}}( - datasets_pair.X, - self.effective_n_threads, + if metric_kwargs is not None and metric_kwargs.get("X_norm_squared", None) is not None: + self.X_norm_squared = metric_kwargs.pop("X_norm_squared") + else: + # Do not recompute norms if datasets are identical. + self.X_norm_squared = ( + self.Y_norm_squared if self.datasets_pair.X_is_Y else + _sqeuclidean_row_norms{{name_suffix}}( + datasets_pair.X, + self.effective_n_threads, + ) ) - ) + self.use_squared_distances = use_squared_distances @@ -335,6 +337,8 @@ cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suf cdef: ITYPE_t i, j ITYPE_t pair_index = 0 + DTYPE_t sq_dist_i_j = 0. + DTYPE_t *dist_middle_terms = self.gemm_term_computer._compute_distances_on_chunks( X_start, X_end, Y_start, Y_end, thread_num ) @@ -345,11 +349,16 @@ cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suf # # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² # - self.pairwise_distances_matrix[i, j] = ( + sq_dist_i_j = ( self.X_norm_squared[i] + dist_middle_terms[pair_index] + self.Y_norm_squared[j] ) + # Guard against eventual -0. and NaN caused by catastrophic + # cancellation (e.g. if X is Y) + self.pairwise_distances_matrix[i, j] = ( + sq_dist_i_j if 0. < sq_dist_i_j == sq_dist_i_j else 0. + ) pair_index += 1 def _finalize_results(self): From 05b5ae7774425c4e07004fd0967a85c71c124d4e Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 19 Sep 2022 17:03:11 +0200 Subject: [PATCH 17/36] Remove duplicated line --- sklearn/metrics/_pairwise_distances_reduction/setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/setup.py b/sklearn/metrics/_pairwise_distances_reduction/setup.py index d3ce860776b27..f2161017ae652 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/setup.py +++ b/sklearn/metrics/_pairwise_distances_reduction/setup.py @@ -33,7 +33,6 @@ def configuration(parent_package="", top_path=None): "_datasets_pair.pyx", "_gemm_term_computer.pyx", "_base.pyx", - "_pairwise_distances.pyx", "_argkmin.pyx", "_pairwise_distances.pyx", "_radius_neighborhood.pyx", From 5442dddd86a2e99c49487350556585f53caac13a Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 21 Sep 2022 15:20:53 +0200 Subject: [PATCH 18/36] TST Remove checks on errors now that minkowski with sparse data is supported --- sklearn/metrics/tests/test_pairwise.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index d2ade2650832a..03395fd5d9cd5 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -156,12 +156,6 @@ def test_pairwise_distances(): S2 = pairwise_distances(X, metric=minkowski, **kwds) assert_array_almost_equal(S, S2) - # Test that scipy distance metrics throw an error if sparse matrix given - with pytest.raises(TypeError): - pairwise_distances(X_sparse, metric="minkowski") - with pytest.raises(TypeError): - pairwise_distances(X, Y_sparse, metric="minkowski") - # Test that a value error is raised if the metric is unknown with pytest.raises(ValueError): pairwise_distances(X, Y, metric="blah") From 7087ad16abcef8d21c522a21c065d55f008c24bf Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 22 Sep 2022 14:58:48 +0200 Subject: [PATCH 19/36] TST PairwiseDistances factory methods --- .../test_pairwise_distances_reduction.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 49768b4e80364..9b422c2d4f87c 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -13,6 +13,7 @@ BaseDistanceReductionDispatcher, ArgKmin, RadiusNeighbors, + PairwiseDistances, sqeuclidean_row_norms, ) @@ -691,6 +692,56 @@ def test_radius_neighborhood_factory_method_wrong_usages(): ) +def test_pairwise_distances_factory_method_wrong_usages(): + rng = np.random.RandomState(1) + X = rng.rand(100, 10) + Y = rng.rand(100, 10) + metric = "euclidean" + + msg = ( + "Only float64 or float32 datasets pairs are supported at this time, " + "got: X.dtype=float32 and Y.dtype=float64" + ) + with pytest.raises( + ValueError, + match=msg, + ): + PairwiseDistances.compute(X=X.astype(np.float32), Y=Y, metric=metric) + + msg = ( + "Only float64 or float32 datasets pairs are supported at this time, " + "got: X.dtype=float64 and Y.dtype=int32" + ) + with pytest.raises( + ValueError, + match=msg, + ): + PairwiseDistances.compute(X=X, Y=Y.astype(np.int32), metric=metric) + + with pytest.raises(ValueError, match="Unrecognized metric"): + PairwiseDistances.compute(X=X, Y=Y, metric="wrong metric") + + with pytest.raises( + ValueError, match=r"Buffer has wrong number of dimensions \(expected 2, got 1\)" + ): + PairwiseDistances.compute(X=np.array([1.0, 2.0]), Y=Y, metric=metric) + + with pytest.raises(ValueError, match="ndarray is not C-contiguous"): + PairwiseDistances.compute(X=np.asfortranarray(X), Y=Y, metric=metric) + + unused_metric_kwargs = {"p": 3} + + message = ( + r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this" + r" case \(EuclideanPairwiseDistances64" + ) + + with pytest.warns(UserWarning, match=message): + PairwiseDistances.compute( + X=X, Y=Y, metric=metric, metric_kwargs=unused_metric_kwargs + ) + + @pytest.mark.parametrize("n_samples", [100, 1000]) @pytest.mark.parametrize("chunk_size", [50, 512, 1024]) @pytest.mark.parametrize( From 33560b2d82356a71beda08d80aaf9321872087df Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 27 Sep 2022 09:37:14 +0200 Subject: [PATCH 20/36] DOC Remove comment regarding X_is_Y MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger --- .../_pairwise_distances_reduction/_datasets_pair.pxd.tp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp index 03e642be6945c..62a6d8b4179b8 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp @@ -25,11 +25,6 @@ cdef class DatasetsPair{{name_suffix}}: {{DistanceMetric}} distance_metric ITYPE_t n_features - # Note regarding semantic: X and Y can be different PyObjects (in this - # case, X_is_Y == False), yet they can have the exact same numerical - # values. Inferring identity based on numerical values is costly - # (especially for sparse-dense and dense-sparse datasets pairs), so we - # solely infer the value of X_is_Y based the PyObjects' identity. readonly bint X_is_Y cdef ITYPE_t n_samples_X(self) nogil From cfc145ab51fcfd7203fd22f8ede08529331c1125 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 27 Sep 2022 10:01:37 +0200 Subject: [PATCH 21/36] Preserve dtype for PairwiseDistances Following discussions in: https://github.com/scikit-learn/scikit-learn/issues/24502 --- sklearn/manifold/tests/test_t_sne.py | 2 +- .../_pairwise_distances.pxd.tp | 20 ++++++++++++++++--- .../_pairwise_distances.pyx.tp | 19 +++++++++++++++--- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/sklearn/manifold/tests/test_t_sne.py b/sklearn/manifold/tests/test_t_sne.py index 3d229ee09d9d4..4b00c7b228969 100644 --- a/sklearn/manifold/tests/test_t_sne.py +++ b/sklearn/manifold/tests/test_t_sne.py @@ -164,7 +164,7 @@ def test_binary_search_neighbors(): desired_perplexity = 25.0 random_state = check_random_state(0) data = random_state.randn(n_samples, 2).astype(np.float32, copy=False) - distances = pairwise_distances(data).astype(np.float32, copy=False) + distances = pairwise_distances(data) P1 = _binary_search_perplexity(distances, desired_perplexity, verbose=0) # Test that when we use all the neighbors the results are identical diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp index 9e457ecb218cc..340399edd3a50 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp @@ -1,6 +1,20 @@ -from ...utils._typedefs cimport DTYPE_t +{{py: + +implementation_specific_values = [ + # Values are the following ones: + # + # name_suffix, INPUT_DTYPE_t + # + # + ('64', 'cnp.float64_t'), + ('32', 'cnp.float32_t') +] -{{for name_suffix in ['32', '64']}} +}} +cimport numpy as cnp + +from ...utils._typedefs cimport DTYPE_t +{{for name_suffix, INPUT_DTYPE_t in implementation_specific_values}} from ._base cimport BaseDistanceReducer{{name_suffix}} from ._gemm_term_computer cimport GEMMTermComputer{{name_suffix}} @@ -10,7 +24,7 @@ cdef class PairwiseDistances{{name_suffix}}(BaseDistanceReducer{{name_suffix}}): """{{name_suffix}}bit implementation of PairwiseDistances.""" cdef: - DTYPE_t[:, ::1] pairwise_distances_matrix + {{INPUT_DTYPE_t}}[:, ::1] pairwise_distances_matrix cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suffix}}): diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp index dd68f620fd282..a10584e58101b 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp @@ -1,3 +1,16 @@ +{{py: + +implementation_specific_values = [ + # Values are the following ones: + # + # name_suffix, INPUT_DTYPE + # + # + ('64', 'DTYPE'), + ('32', 'np.float32') +] + +}} cimport numpy as cnp from cython cimport final from ...utils._typedefs cimport ITYPE_t, DTYPE_t @@ -8,7 +21,7 @@ import warnings from scipy.sparse import issparse from ...utils import _in_unstable_openblas_configuration from ...utils.fixes import threadpool_limits, sp_version, parse_version -from ...utils._typedefs import ITYPE, DTYPE +from ...utils._typedefs import DTYPE cnp.import_array() @@ -38,7 +51,7 @@ def _precompute_metric_params(X, Y, metric=None, **kwds): return {"VI": VI} return {} -{{for name_suffix in ['64', '32']}} +{{for name_suffix, INPUT_DTYPE in implementation_specific_values}} from ._base cimport ( BaseDistanceReducer{{name_suffix}}, @@ -142,7 +155,7 @@ cdef class PairwiseDistances{{name_suffix}}(BaseDistanceReducer{{name_suffix}}): # Distance matrix which will be complete and returned to the caller. self.pairwise_distances_matrix = np.empty( - (self.n_samples_X, self.n_samples_Y), dtype=DTYPE, + (self.n_samples_X, self.n_samples_Y), dtype={{INPUT_DTYPE}}, ) def _finalize_results(self): From 8604ec6c8951db7ae7b0a73f78991d0b92e3cc3d Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 27 Sep 2022 13:57:18 +0200 Subject: [PATCH 22/36] TST Remove TODO now that dtypes are preserved --- sklearn/metrics/tests/test_pairwise.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index cd316c5d7b426..c91f98dd816c9 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -153,25 +153,14 @@ def test_pairwise_distances(global_dtype): S = pairwise_distances(X_sparse, Y_sparse.tocsc(), metric="manhattan") S2 = manhattan_distances(X_sparse.tobsr(), Y_sparse.tocoo()) assert_allclose(S, S2) - if global_dtype == np.float64: - assert S.dtype == S2.dtype == global_dtype - else: - # TODO Fix manhattan_distances to preserve dtype. - # currently pairwise_distances uses manhattan_distances but converts the result - # back to the input dtype - with pytest.raises(AssertionError): - assert S.dtype == S2.dtype == global_dtype + + # pairwise_distances must preserves dtypes for the manhattan distance metric + assert S.dtype == S2.dtype == global_dtype S2 = manhattan_distances(X, Y) assert_allclose(S, S2) - if global_dtype == np.float64: - assert S.dtype == S2.dtype == global_dtype - else: - # TODO Fix manhattan_distances to preserve dtype. - # currently pairwise_distances uses manhattan_distances but converts the result - # back to the input dtype - with pytest.raises(AssertionError): - assert S.dtype == S2.dtype == global_dtype + # manhattan_distances must preserves dtypes + assert S.dtype == S2.dtype == global_dtype # Test with scipy.spatial.distance metric, with a kwd kwds = {"p": 2.0} From 62a751a210da4b6cc7620a229b6325502c9f615a Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 27 Sep 2022 14:33:31 +0200 Subject: [PATCH 23/36] DOC Remove and adapt some comments --- .../_pairwise_distances_reduction/_datasets_pair.pyx.tp | 2 -- .../metrics/_pairwise_distances_reduction/_dispatcher.py | 5 ----- .../_pairwise_distances.pyx.tp | 7 +++---- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp index 57f5c682f8790..91ef6307b643b 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp @@ -103,8 +103,6 @@ cdef class DatasetsPair{{name_suffix}}: metric, **(metric_kwargs or {}) ) - # TODO: potentially we could infer identity of X and Y based on numerical - # values and not (uniquely) on the PyObject's identity bint X_is_Y = X is Y # Metric-specific checks that do not replace nor duplicate `check_array`. diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 8b9323d5d97fb..70ef7783cb155 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -263,11 +263,6 @@ def compute( This allows entirely decoupling the API entirely from the implementation details whilst maintaining RAII. """ - # Note (jjerphan): Some design thoughts for future extensions. - # This factory comes to handle specialisations for the given arguments. - # For future work, this might can be an entrypoint to specialise operations - # for various backend and/or hardware and/or datatypes, and/or fused - # {sparse, dense}-datasetspair etc. Y = X if Y is None else Y if X.dtype == Y.dtype == np.float64: return PairwiseDistances64.compute( diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp index a10584e58101b..239085b63e9ba 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp @@ -159,10 +159,9 @@ cdef class PairwiseDistances{{name_suffix}}(BaseDistanceReducer{{name_suffix}}): ) def _finalize_results(self): - # If X is Y, then catastrophic cancellation might - # have occurred for computations of term on the diagonal - # which must be null. We enforce nullity of those term - # by zeroing the diagonal. + # If X is Y, then catastrophic cancellation might have occurred for + # computations of terms on the diagonal which must equal zero. + # We enforce it by zeroing the diagonal. distance_matrix = np.asarray(self.pairwise_distances_matrix) if self.datasets_pair.X_is_Y: np.fill_diagonal(distance_matrix, 0.) From 259edc1d64660262ebfa54101d64a474fc52666d Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 27 Sep 2022 09:37:14 +0200 Subject: [PATCH 24/36] MAINT Remove _sparse_manhattan MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger --- sklearn/metrics/_pairwise_fast.pyx | 74 ------------------------------ sklearn/metrics/pairwise.py | 16 +++---- 2 files changed, 7 insertions(+), 83 deletions(-) diff --git a/sklearn/metrics/_pairwise_fast.pyx b/sklearn/metrics/_pairwise_fast.pyx index b9006773c015d..b6941dba290db 100644 --- a/sklearn/metrics/_pairwise_fast.pyx +++ b/sklearn/metrics/_pairwise_fast.pyx @@ -6,10 +6,6 @@ cimport numpy as cnp from cython cimport floating -from cython.parallel cimport prange -from libc.math cimport fabs - -from ..utils._openmp_helpers import _openmp_effective_n_threads cnp.import_array() @@ -33,73 +29,3 @@ def _chi2_kernel_fast(floating[:, :] X, if nom != 0: res += denom * denom / nom result[i, j] = -res - - -def _sparse_manhattan(floating[::1] X_data, int[:] X_indices, int[:] X_indptr, - floating[::1] Y_data, int[:] Y_indices, int[:] Y_indptr, - double[:, ::1] D): - """Pairwise L1 distances for CSR matrices. - - Usage: - >>> D = np.zeros(X.shape[0], Y.shape[0]) - >>> _sparse_manhattan(X.data, X.indices, X.indptr, - ... Y.data, Y.indices, Y.indptr, - ... D) - """ - cdef cnp.npy_intp px, py, i, j, ix, iy - cdef double d = 0.0 - - cdef int m = D.shape[0] - cdef int n = D.shape[1] - - cdef int X_indptr_end = 0 - cdef int Y_indptr_end = 0 - - cdef int num_threads = _openmp_effective_n_threads() - - # We scan the matrices row by row. - # Given row px in X and row py in Y, we find the positions (i and j - # respectively), in .indices where the indices for the two rows start. - # If the indices (ix and iy) are the same, the corresponding data values - # are processed and the cursors i and j are advanced. - # If not, the lowest index is considered. Its associated data value is - # processed and its cursor is advanced. - # We proceed like this until one of the cursors hits the end for its row. - # Then we process all remaining data values in the other row. - - # Below the avoidance of inplace operators is intentional. - # When prange is used, the inplace operator has a special meaning, i.e. it - # signals a "reduction" - - for px in prange(m, nogil=True, num_threads=num_threads): - X_indptr_end = X_indptr[px + 1] - for py in range(n): - Y_indptr_end = Y_indptr[py + 1] - i = X_indptr[px] - j = Y_indptr[py] - d = 0.0 - while i < X_indptr_end and j < Y_indptr_end: - ix = X_indices[i] - iy = Y_indices[j] - - if ix == iy: - d = d + fabs(X_data[i] - Y_data[j]) - i = i + 1 - j = j + 1 - elif ix < iy: - d = d + fabs(X_data[i]) - i = i + 1 - else: - d = d + fabs(Y_data[j]) - j = j + 1 - - if i == X_indptr_end: - while j < Y_indptr_end: - d = d + fabs(Y_data[j]) - j = j + 1 - else: - while i < X_indptr_end: - d = d + fabs(X_data[i]) - i = i + 1 - - D[px, py] = d diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 33ea20759e5c8..176064159df67 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -34,7 +34,7 @@ PairwiseDistances, _precompute_metric_params, ) -from ._pairwise_fast import _chi2_kernel_fast, _sparse_manhattan +from ._pairwise_fast import _chi2_kernel_fast from ..exceptions import DataConversionWarning @@ -940,6 +940,12 @@ def manhattan_distances(X, Y=None, *, sum_over_features=True): """ X, Y = check_pairwise_arrays(X, Y) + if issparse(X) or issparse(Y): + X = csr_matrix(X, copy=False) + Y = csr_matrix(Y, copy=False) + X.sum_duplicates() # this also sorts indices in-place + Y.sum_duplicates() + if sum_over_features and PairwiseDistances.is_usable_for(X, Y, metric="manhattan"): return PairwiseDistances.compute(X, Y, metric="manhattan") @@ -950,14 +956,6 @@ def manhattan_distances(X, Y=None, *, sum_over_features=True): % sum_over_features ) - X = csr_matrix(X, copy=False) - Y = csr_matrix(Y, copy=False) - X.sum_duplicates() # this also sorts indices in-place - Y.sum_duplicates() - D = np.zeros((X.shape[0], Y.shape[0])) - _sparse_manhattan(X.data, X.indices, X.indptr, Y.data, Y.indices, Y.indptr, D) - return D - if sum_over_features: return distance.cdist(X, Y, "cityblock") From 50f032e82aa2b6086389f9123fcad307305a2098 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 27 Sep 2022 15:20:38 +0200 Subject: [PATCH 25/36] Use PairwiseDistance as a back-end for _euclidean_distances --- sklearn/metrics/pairwise.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 176064159df67..cf4bf8212e241 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -340,6 +340,14 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared float32, norms needs to be recomputed on upcast chunks. TODO: use a float64 accumulator in row_norms to avoid the latter. """ + metric = "sqeuclidean" if squared else "euclidean" + if PairwiseDistances.is_usable_for(X, Y, metric): + metric_kwargs = { + "X_norm_squared": X_norm_squared, + "Y_norm_squared": Y_norm_squared, + } + return PairwiseDistances.compute(X, Y, metric, metric_kwargs=metric_kwargs) + if X_norm_squared is not None: if X_norm_squared.dtype == np.float32: XX = None From 40d0a0bd7d8bc5c22d0bba93460b975b723821ef Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 27 Sep 2022 15:28:48 +0200 Subject: [PATCH 26/36] MAINT Keep classmethods at the top --- .../_datasets_pair.pyx.tp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp index 91ef6307b643b..799a83bef5e41 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp @@ -123,11 +123,6 @@ cdef class DatasetsPair{{name_suffix}}: return DenseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y) - def __init__(self, {{DistanceMetric}} distance_metric, bint X_is_Y, ITYPE_t n_features): - self.distance_metric = distance_metric - self.X_is_Y = X_is_Y - self.n_features = n_features - @classmethod def unpack_csr_matrix(cls, X: csr_matrix): """Ensure that the CSR matrix is indexed with SPARSE_INDEX_TYPE.""" @@ -136,6 +131,11 @@ cdef class DatasetsPair{{name_suffix}}: X_indptr = np.asarray(X.indptr, dtype=SPARSE_INDEX_TYPE) return X_data, X_indices, X_indptr + def __init__(self, {{DistanceMetric}} distance_metric, ITYPE_t n_features, bint X_is_Y): + self.distance_metric = distance_metric + self.X_is_Y = X_is_Y + self.n_features = n_features + cdef ITYPE_t n_samples_X(self) nogil: """Number of samples in X.""" # This is a abstract method. From 9017aa709a01093d2de8906fe2b346a56174e043 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 28 Sep 2022 17:58:19 +0200 Subject: [PATCH 27/36] Safely pack {X,Y}_squared_norms --- sklearn/metrics/pairwise.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index cf4bf8212e241..f1773c5846168 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -343,8 +343,12 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared metric = "sqeuclidean" if squared else "euclidean" if PairwiseDistances.is_usable_for(X, Y, metric): metric_kwargs = { - "X_norm_squared": X_norm_squared, - "Y_norm_squared": Y_norm_squared, + "X_norm_squared": np.ravel(X_norm_squared) + if X_norm_squared is not None + else X_norm_squared, + "Y_norm_squared": np.ravel(Y_norm_squared) + if Y_norm_squared is not None + else Y_norm_squared, } return PairwiseDistances.compute(X, Y, metric, metric_kwargs=metric_kwargs) From dd136a7d12e891627ae206897b8becf7a7275ec7 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 28 Sep 2022 17:58:25 +0200 Subject: [PATCH 28/36] TST Adapt test_euclidean_distances_extreme_values This: - decreases the number of features by an order to magnetude because in the case of float32, the vectors gets entirely copied for the upcast to float64. This might use too much memory and crash the program - this now accepts the previously xfail parametrisation case by setting on absolute error (seen we are comparing small values) --- sklearn/metrics/tests/test_pairwise.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index c91f98dd816c9..4c1f7c1b19f5b 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -908,15 +908,10 @@ def test_euclidean_distances_upcast_sym(batch_size, x_array_constr): "dtype, eps, rtol", [ (np.float32, 1e-4, 1e-5), - pytest.param( - np.float64, - 1e-8, - 0.99, - marks=pytest.mark.xfail(reason="failing due to lack of precision"), - ), + (np.float64, 1e-8, 0.99), ], ) -@pytest.mark.parametrize("dim", [1, 1000000]) +@pytest.mark.parametrize("dim", [1, 100000]) def test_euclidean_distances_extreme_values(dtype, eps, rtol, dim): # check that euclidean distances is correct with float32 input thanks to # upcasting. On float64 there are still precision issues. @@ -926,7 +921,7 @@ def test_euclidean_distances_extreme_values(dtype, eps, rtol, dim): distances = euclidean_distances(X, Y) expected = cdist(X, Y) - assert_allclose(distances, expected, rtol=1e-5) + assert_allclose(distances, expected, rtol=1e-5, atol=4e-6) @pytest.mark.parametrize("squared", [True, False]) From 91b32055ba0be4171b0606dcb881f44e6bd94b07 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 29 Sep 2022 09:06:20 +0200 Subject: [PATCH 29/36] DOC Improve docstrings Co-authored-by: Olivier Grisel --- sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py | 5 +++++ .../_pairwise_distances_reduction/_pairwise_distances.pyx.tp | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 70ef7783cb155..aa595953894e8 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -175,6 +175,11 @@ class PairwiseDistances(BaseDistanceReductionDispatcher): The distance function `dist` depends on the values of the `metric` and `metric_kwargs` parameters. + This class only computes the pairwise distances matrix without + applying any reduction on it. It shares most of the underlying + code infrastructure with reducing variants to leverage + cache-aware chunking and multi-thread parallelism. + This class is not meant to be instantiated, one should only use its :meth:`compute` classmethod which handles allocation and deallocation consistently. diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp index 239085b63e9ba..5366ac1bc15f0 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp @@ -67,7 +67,7 @@ from ._gemm_term_computer cimport GEMMTermComputer{{name_suffix}} cdef class PairwiseDistances{{name_suffix}}(BaseDistanceReducer{{name_suffix}}): - """{{name_suffix}}bit implementation of PairwiseDistances.""" + """float{{name_suffix}} implementation of PairwiseDistances.""" @classmethod def compute( @@ -185,7 +185,7 @@ cdef class PairwiseDistances{{name_suffix}}(BaseDistanceReducer{{name_suffix}}): cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suffix}}): - """EuclideanDistance-specialized {{name_suffix}}bit implementation for PairwiseDistances.""" + """EuclideanDistance-specialized float{{name_suffix}} implementation for PairwiseDistances.""" @classmethod def is_usable_for(cls, X, Y, metric) -> bool: From 0abe5603689c58e0cdf3ea309406a6ec82ebe939 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 29 Sep 2022 09:07:10 +0200 Subject: [PATCH 30/36] Simplify manhattan_distances --- sklearn/metrics/pairwise.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index f1773c5846168..62d6f92d6a82e 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -952,10 +952,14 @@ def manhattan_distances(X, Y=None, *, sum_over_features=True): """ X, Y = check_pairwise_arrays(X, Y) - if issparse(X) or issparse(Y): + if issparse(X): X = csr_matrix(X, copy=False) + # This also sorts indices in-place. + X.sum_duplicates() + + if issparse(Y): Y = csr_matrix(Y, copy=False) - X.sum_duplicates() # this also sorts indices in-place + # This also sorts indices in-place. Y.sum_duplicates() if sum_over_features and PairwiseDistances.is_usable_for(X, Y, metric="manhattan"): @@ -963,10 +967,7 @@ def manhattan_distances(X, Y=None, *, sum_over_features=True): if issparse(X) or issparse(Y): if not sum_over_features: - raise TypeError( - "sum_over_features=%r not supported for sparse matrices" - % sum_over_features - ) + raise TypeError("sum_over_features=False not supported for sparse matrices") if sum_over_features: return distance.cdist(X, Y, "cityblock") From b5003f923c679d43230a642cc762eb67cba41e90 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 29 Sep 2022 09:53:04 +0200 Subject: [PATCH 31/36] fixup! DOC Improve docstrings --- .../_pairwise_distances_reduction/_pairwise_distances.pxd.tp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp index 340399edd3a50..4a6b5b3804776 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp @@ -21,14 +21,14 @@ from ._gemm_term_computer cimport GEMMTermComputer{{name_suffix}} cdef class PairwiseDistances{{name_suffix}}(BaseDistanceReducer{{name_suffix}}): - """{{name_suffix}}bit implementation of PairwiseDistances.""" + """float{{name_suffix}} implementation of PairwiseDistances.""" cdef: {{INPUT_DTYPE_t}}[:, ::1] pairwise_distances_matrix cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suffix}}): - """EuclideanDistance-specialized {{name_suffix}}bit implementation for PairwiseDistances.""" + """EuclideanDistance-specialized float{{name_suffix}} implementation for PairwiseDistances.""" cdef: GEMMTermComputer{{name_suffix}} gemm_term_computer const DTYPE_t[::1] X_norm_squared From f38b010944d1f54387bb2b4e34bd40d9a04061f3 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 29 Sep 2022 13:15:11 +0200 Subject: [PATCH 32/36] fixup! TST Adapt test_euclidean_distances_extreme_values --- sklearn/metrics/tests/test_pairwise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index 4c1f7c1b19f5b..88d170337629f 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -921,7 +921,7 @@ def test_euclidean_distances_extreme_values(dtype, eps, rtol, dim): distances = euclidean_distances(X, Y) expected = cdist(X, Y) - assert_allclose(distances, expected, rtol=1e-5, atol=4e-6) + assert_allclose(distances, expected, rtol=1e-5, atol=4e-4) @pytest.mark.parametrize("squared", [True, False]) From f2d8cbeac4d1cb81f1f58b7bdf15e8649d0b20a2 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 7 Oct 2022 10:50:32 +0200 Subject: [PATCH 33/36] Apply review comments Co-authored-by: Olivier Grisel --- .../_dispatcher.py | 2 +- .../_pairwise_distances.pyx.tp | 29 ++++--------------- sklearn/metrics/pairwise.py | 22 ++++++++++++-- .../test_pairwise_distances_reduction.py | 16 ++-------- 4 files changed, 29 insertions(+), 40 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index aa595953894e8..c639d448f8545 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -290,7 +290,7 @@ def compute( ) raise ValueError( - "Only float64 or float32 datasets pairs are supported at this time, " + "Only float64 or float32 datasets pairs are supported, but " f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." ) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp index 5366ac1bc15f0..6ff1cdbc4ebf8 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp @@ -201,41 +201,24 @@ cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suf strategy=None, metric_kwargs=None, ): - if ( - metric_kwargs is not None and - len(metric_kwargs) > 0 and - "Y_norm_squared" not in metric_kwargs - ): - warnings.warn( - f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " - f"usable for this case (EuclideanPairwiseDistances{{name_suffix}}) and will be ignored.", - UserWarning, - stacklevel=3, - ) - super().__init__( # The datasets pair here is used for exact distances computations datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric="euclidean"), chunk_size=chunk_size, strategy=strategy, ) - # X and Y are checked by the DatasetsPair{{name_suffix}} implemented as - # a DenseDenseDatasetsPair{{name_suffix}} cdef: - DenseDenseDatasetsPair{{name_suffix}} datasets_pair = ( - self.datasets_pair - ) ITYPE_t dist_middle_terms_chunks_size = ( self.Y_n_samples_chunk * self.X_n_samples_chunk ) self.gemm_term_computer = GEMMTermComputer{{name_suffix}}( - datasets_pair.X, - datasets_pair.Y, + X, + Y, self.effective_n_threads, self.chunks_n_threads, dist_middle_terms_chunks_size, - n_features=datasets_pair.X.shape[1], + n_features=X.shape[1], chunk_size=self.chunk_size, ) @@ -243,7 +226,7 @@ cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suf self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") else: self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}( - datasets_pair.Y, + Y, self.effective_n_threads, ) @@ -252,9 +235,9 @@ cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suf else: # Do not recompute norms if datasets are identical. self.X_norm_squared = ( - self.Y_norm_squared if self.datasets_pair.X_is_Y else + self.Y_norm_squared if X is Y else _sqeuclidean_row_norms{{name_suffix}}( - datasets_pair.X, + X, self.effective_n_threads, ) ) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 62d6f92d6a82e..35f22d69cc66a 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -345,13 +345,16 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared metric_kwargs = { "X_norm_squared": np.ravel(X_norm_squared) if X_norm_squared is not None - else X_norm_squared, + else None, "Y_norm_squared": np.ravel(Y_norm_squared) if Y_norm_squared is not None - else Y_norm_squared, + else None, } return PairwiseDistances.compute(X, Y, metric, metric_kwargs=metric_kwargs) + # XXX: the following code is still used for list-of-lists of numbers which + # aren't converted to numpy arrays in validation steps done in `check_array`. + # TODO: convert list-of-lists to numpy arrays in `check_array`. if X_norm_squared is not None: if X_norm_squared.dtype == np.float32: XX = None @@ -887,6 +890,10 @@ def haversine_distances(X, Y=None): if PairwiseDistances.is_usable_for(X, Y, metric="haversine"): return PairwiseDistances.compute(X, Y, metric="haversine") + # XXX: the following code is still used for list-of-lists of numbers which + # aren't converted to numpy arrays in validation steps done in `check_array`. + # TODO: convert list-of-lists to numpy arrays in `check_array`. + from ..metrics import DistanceMetric return DistanceMetric.get_metric("haversine").pairwise(X, Y) @@ -965,6 +972,13 @@ def manhattan_distances(X, Y=None, *, sum_over_features=True): if sum_over_features and PairwiseDistances.is_usable_for(X, Y, metric="manhattan"): return PairwiseDistances.compute(X, Y, metric="manhattan") + # XXX: the following code is still used for list-of-lists of numbers which + # aren't converted to numpy arrays in validation steps done in `check_array` + # and for supporting `sum_over_features` which we should probably remove. + # TODO: convert list-of-lists to numpy arrays in `check_array`. + # TODO: remove `sum_over_features`, see: + # https://github.com/scikit-learn/scikit-learn/issues/24597 + if issparse(X) or issparse(Y): if not sum_over_features: raise TypeError("sum_over_features=False not supported for sparse matrices") @@ -2006,6 +2020,10 @@ def pairwise_distances( _pairwise_callable, metric=metric, force_all_finite=force_all_finite, **kwds ) else: + # XXX: the following code is still used for list-of-lists of numbers which + # aren't converted to numpy arrays in validation steps done in `check_array`. + # TODO: convert list-of-lists to numpy arrays in `check_array`. + if issparse(X) or issparse(Y): raise TypeError("scipy distance metrics do not support sparse matrices.") diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index cdb8c496cae41..b13444fb0f01a 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -699,7 +699,7 @@ def test_pairwise_distances_factory_method_wrong_usages(): metric = "euclidean" msg = ( - "Only float64 or float32 datasets pairs are supported at this time, " + "Only float64 or float32 datasets pairs are supported, but " "got: X.dtype=float32 and Y.dtype=float64" ) with pytest.raises( @@ -709,7 +709,7 @@ def test_pairwise_distances_factory_method_wrong_usages(): PairwiseDistances.compute(X=X.astype(np.float32), Y=Y, metric=metric) msg = ( - "Only float64 or float32 datasets pairs are supported at this time, " + "Only float64 or float32 datasets pairs are supported, but " "got: X.dtype=float64 and Y.dtype=int32" ) with pytest.raises( @@ -729,18 +729,6 @@ def test_pairwise_distances_factory_method_wrong_usages(): with pytest.raises(ValueError, match="ndarray is not C-contiguous"): PairwiseDistances.compute(X=np.asfortranarray(X), Y=Y, metric=metric) - unused_metric_kwargs = {"p": 3} - - message = ( - r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this" - r" case \(EuclideanPairwiseDistances64" - ) - - with pytest.warns(UserWarning, match=message): - PairwiseDistances.compute( - X=X, Y=Y, metric=metric, metric_kwargs=unused_metric_kwargs - ) - @pytest.mark.parametrize( "n_samples_X, n_samples_Y", [(100, 100), (500, 100), (100, 500)] From 3673479c207a55d7aa34065afb30f8cce46f9aa6 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 7 Oct 2022 14:07:05 +0200 Subject: [PATCH 34/36] Rework poping {X,Y}_norm_squared in DatasetsPair.get_for --- .../_datasets_pair.pyx.tp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp index 799a83bef5e41..e267f7c683415 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp @@ -1,3 +1,5 @@ +import copy + {{py: implementation_specific_values = [ @@ -91,12 +93,17 @@ cdef class DatasetsPair{{name_suffix}}: datasets_pair: DatasetsPair{{name_suffix}} The suited DatasetsPair{{name_suffix}} implementation. """ - # Y_norm_squared might be propagated down to DatasetsPairs - # via metrics_kwargs when the Euclidean specialisations - # can't be used. To prevent Y_norm_squared to be passed + # X_norm_squared and Y_norm_squared might be propagated + # down to DatasetsPairs via metrics_kwargs when the Euclidean + # specialisations can't be used. + # To prevent X_norm_squared and Y_norm_squared to be passed # down to DistanceMetrics (whose constructors would raise - # a RuntimeError), we pop it here. + # a RuntimeError), we pop them here. if metric_kwargs is not None: + # Copying metric_kwargs not to pop "X_norm_squared" + # and "Y_norm_squared" where they are used + metric_kwargs = copy.copy(metric_kwargs) + metric_kwargs.pop("X_norm_squared", None) metric_kwargs.pop("Y_norm_squared", None) cdef: {{DistanceMetric}} distance_metric = {{DistanceMetric}}.get_metric( From 0c748568cf6de6114125f41e662e79382cd9ac91 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 7 Oct 2022 14:07:05 +0200 Subject: [PATCH 35/36] Rework poping {X,Y}_norm_squared in EuclideanPairwiseDistances --- .../_pairwise_distances.pyx.tp | 36 +++++++++---------- sklearn/metrics/pairwise.py | 15 ++++---- 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp index 6ff1cdbc4ebf8..2b473353a7585 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp @@ -1,25 +1,11 @@ -{{py: - -implementation_specific_values = [ - # Values are the following ones: - # - # name_suffix, INPUT_DTYPE - # - # - ('64', 'DTYPE'), - ('32', 'np.float32') -] - -}} cimport numpy as cnp from cython cimport final from ...utils._typedefs cimport ITYPE_t, DTYPE_t import numpy as np -import warnings from scipy.sparse import issparse -from ...utils import _in_unstable_openblas_configuration +from ...utils import check_array, _in_unstable_openblas_configuration from ...utils.fixes import threadpool_limits, sp_version, parse_version from ...utils._typedefs import DTYPE @@ -51,7 +37,7 @@ def _precompute_metric_params(X, Y, metric=None, **kwds): return {"VI": VI} return {} -{{for name_suffix, INPUT_DTYPE in implementation_specific_values}} +{{for name_suffix, INPUT_DTYPE in (('64', 'DTYPE'),('32', 'np.float32'))}} from ._base cimport ( BaseDistanceReducer{{name_suffix}}, @@ -222,16 +208,26 @@ cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suf chunk_size=self.chunk_size, ) - if metric_kwargs is not None and metric_kwargs.get("Y_norm_squared", None) is not None: - self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") + if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: + self.Y_norm_squared = check_array( + metric_kwargs.pop("Y_norm_squared"), + ensure_2d=False, + input_name="Y_norm_squared", + dtype=np.float64 + ) else: self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}( Y, self.effective_n_threads, ) - if metric_kwargs is not None and metric_kwargs.get("X_norm_squared", None) is not None: - self.X_norm_squared = metric_kwargs.pop("X_norm_squared") + if metric_kwargs is not None and "X_norm_squared" in metric_kwargs: + self.X_norm_squared = check_array( + metric_kwargs.pop("X_norm_squared"), + ensure_2d=False, + input_name="X_norm_squared", + dtype=np.float64 + ) else: # Do not recompute norms if datasets are identical. self.X_norm_squared = ( diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 2e4db70e3b34d..3105e20e4376c 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -342,14 +342,13 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared """ metric = "sqeuclidean" if squared else "euclidean" if PairwiseDistances.is_usable_for(X, Y, metric): - metric_kwargs = { - "X_norm_squared": np.ravel(X_norm_squared) - if X_norm_squared is not None - else None, - "Y_norm_squared": np.ravel(Y_norm_squared) - if Y_norm_squared is not None - else None, - } + metric_kwargs = {} + if X_norm_squared is not None: + metric_kwargs["X_norm_squared"] = np.ravel(X_norm_squared) + + if Y_norm_squared is not None: + metric_kwargs["Y_norm_squared"] = np.ravel(Y_norm_squared) + return PairwiseDistances.compute(X, Y, metric, metric_kwargs=metric_kwargs) # XXX: the following code is still used for list-of-lists of numbers which From 4fccb5f2810bee40fa7090e7b77973ff8833ef13 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 24 Oct 2022 17:37:16 +0200 Subject: [PATCH 36/36] DOC Add references to #24745 --- sklearn/metrics/pairwise.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 3105e20e4376c..330123a396466 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -354,6 +354,7 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared # XXX: the following code is still used for list-of-lists of numbers which # aren't converted to numpy arrays in validation steps done in `check_array`. # TODO: convert list-of-lists to numpy arrays in `check_array`. + # See: https://github.com/scikit-learn/scikit-learn/issues/24745 if X_norm_squared is not None: if X_norm_squared.dtype == np.float32: XX = None @@ -892,6 +893,7 @@ def haversine_distances(X, Y=None): # XXX: the following code is still used for list-of-lists of numbers which # aren't converted to numpy arrays in validation steps done in `check_array`. # TODO: convert list-of-lists to numpy arrays in `check_array`. + # See: https://github.com/scikit-learn/scikit-learn/issues/24745 from ..metrics import DistanceMetric @@ -975,6 +977,7 @@ def manhattan_distances(X, Y=None, *, sum_over_features=True): # aren't converted to numpy arrays in validation steps done in `check_array` # and for supporting `sum_over_features` which we should probably remove. # TODO: convert list-of-lists to numpy arrays in `check_array`. + # See: https://github.com/scikit-learn/scikit-learn/issues/24745 # TODO: remove `sum_over_features`, see: # https://github.com/scikit-learn/scikit-learn/issues/24597 @@ -2023,6 +2026,7 @@ def pairwise_distances( # XXX: the following code is still used for list-of-lists of numbers which # aren't converted to numpy arrays in validation steps done in `check_array`. # TODO: convert list-of-lists to numpy arrays in `check_array`. + # See: https://github.com/scikit-learn/scikit-learn/issues/24745 if issparse(X) or issparse(Y): raise TypeError("scipy distance metrics do not support sparse matrices.")