diff --git a/.gitignore b/.gitignore index 47ec8fa2faf79..7754773320c6e 100644 --- a/.gitignore +++ b/.gitignore @@ -95,5 +95,7 @@ sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx +sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd +sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx diff --git a/setup.cfg b/setup.cfg index c6a2f37c2ed58..b7d6acdb62832 100644 --- a/setup.cfg +++ b/setup.cfg @@ -85,6 +85,8 @@ ignore = sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx + sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd + sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx diff --git a/setup.py b/setup.py index 17d90a45f35ea..53ea2a68e529a 100755 --- a/setup.py +++ b/setup.py @@ -90,6 +90,7 @@ "sklearn.metrics._pairwise_distances_reduction._datasets_pair", "sklearn.metrics._pairwise_distances_reduction._middle_term_computer", "sklearn.metrics._pairwise_distances_reduction._base", + "sklearn.metrics._pairwise_distances_reduction._pairwise_distances", "sklearn.metrics._pairwise_distances_reduction._argkmin", "sklearn.metrics._pairwise_distances_reduction._radius_neighbors", "sklearn.metrics._pairwise_fast", @@ -327,6 +328,12 @@ def check_package_status(package, min_version): "include_np": True, "extra_compile_args": ["-std=c++11"], }, + { + "sources": ["_pairwise_distances.pyx.tp", "_pairwise_distances.pxd.tp"], + "language": "c++", + "include_np": True, + "extra_compile_args": ["-std=c++11"], + }, { "sources": ["_argkmin.pyx.tp", "_argkmin.pxd.tp"], "language": "c++", diff --git a/sklearn/metrics/_pairwise_distances_reduction/__init__.py b/sklearn/metrics/_pairwise_distances_reduction/__init__.py index 133c854682f0c..f4caf911eb898 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/__init__.py +++ b/sklearn/metrics/_pairwise_distances_reduction/__init__.py @@ -89,13 +89,18 @@ from ._dispatcher import ( BaseDistancesReductionDispatcher, ArgKmin, + PairwiseDistances, RadiusNeighbors, sqeuclidean_row_norms, ) +from ._pairwise_distances import _precompute_metric_params + __all__ = [ "BaseDistancesReductionDispatcher", "ArgKmin", + "PairwiseDistances", "RadiusNeighbors", "sqeuclidean_row_norms", + "_precompute_metric_params", ] diff --git a/sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp index 35c8184d25a6c..2a9cb92a1fe9b 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp @@ -50,6 +50,7 @@ cdef class BaseDistancesReduction{{name_suffix}}: ITYPE_t n_samples_X, X_n_samples_chunk, X_n_chunks, X_n_samples_last_chunk ITYPE_t n_samples_Y, Y_n_samples_chunk, Y_n_chunks, Y_n_samples_last_chunk + bint X_is_Y bint execute_in_parallel_on_Y @final diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp index e220f730e7529..16521561a58dc 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp @@ -25,6 +25,8 @@ cdef class DatasetsPair{{name_suffix}}: {{DistanceMetric}} distance_metric ITYPE_t n_features + readonly bint X_is_Y + cdef ITYPE_t n_samples_X(self) nogil cdef ITYPE_t n_samples_Y(self) nogil diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp index 78857341f9c97..18238d781f73a 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp @@ -1,3 +1,5 @@ +import copy + {{py: implementation_specific_values = [ @@ -91,18 +93,24 @@ cdef class DatasetsPair{{name_suffix}}: datasets_pair: DatasetsPair{{name_suffix}} The suited DatasetsPair{{name_suffix}} implementation. """ - # Y_norm_squared might be propagated down to DatasetsPairs - # via metrics_kwargs when the Euclidean specialisations - # can't be used. To prevent Y_norm_squared to be passed + # X_norm_squared and Y_norm_squared might be propagated + # down to DatasetsPairs via metrics_kwargs when the Euclidean + # specialisations can't be used. + # To prevent X_norm_squared and Y_norm_squared to be passed # down to DistanceMetrics (whose constructors would raise - # a RuntimeError), we pop it here. + # a RuntimeError), we pop them here. if metric_kwargs is not None: + # Copying metric_kwargs not to pop "X_norm_squared" + # and "Y_norm_squared" where they are used + metric_kwargs = copy.copy(metric_kwargs) + metric_kwargs.pop("X_norm_squared", None) metric_kwargs.pop("Y_norm_squared", None) cdef: {{DistanceMetric}} distance_metric = {{DistanceMetric}}.get_metric( metric, **(metric_kwargs or {}) ) + bint X_is_Y = X is Y # Metric-specific checks that do not replace nor duplicate `check_array`. distance_metric._validate_data(X) @@ -112,15 +120,15 @@ cdef class DatasetsPair{{name_suffix}}: Y_is_sparse = issparse(Y) if not X_is_sparse and not Y_is_sparse: - return DenseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric) + return DenseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y) if X_is_sparse and Y_is_sparse: - return SparseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric) + return SparseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y) if X_is_sparse and not Y_is_sparse: - return SparseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric) + return SparseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y) - return DenseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric) + return DenseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y) @classmethod def unpack_csr_matrix(cls, X: csr_matrix): @@ -130,8 +138,9 @@ cdef class DatasetsPair{{name_suffix}}: X_indptr = np.asarray(X.indptr, dtype=SPARSE_INDEX_TYPE) return X_data, X_indices, X_indptr - def __init__(self, {{DistanceMetric}} distance_metric, ITYPE_t n_features): + def __init__(self, {{DistanceMetric}} distance_metric, ITYPE_t n_features, bint X_is_Y): self.distance_metric = distance_metric + self.X_is_Y = X_is_Y self.n_features = n_features cdef ITYPE_t n_samples_X(self) nogil: @@ -179,8 +188,9 @@ cdef class DenseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): const {{INPUT_DTYPE_t}}[:, ::1] X, const {{INPUT_DTYPE_t}}[:, ::1] Y, {{DistanceMetric}} distance_metric, + bint X_is_Y, ): - super().__init__(distance_metric, n_features=X.shape[1]) + super().__init__(distance_metric, n_features=X.shape[1], X_is_Y=X_is_Y) # Arrays have already been checked self.X = X self.Y = Y @@ -219,8 +229,8 @@ cdef class SparseSparseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): between two vectors of (X, Y). """ - def __init__(self, X, Y, {{DistanceMetric}} distance_metric): - super().__init__(distance_metric, n_features=X.shape[1]) + def __init__(self, X, Y, {{DistanceMetric}} distance_metric, bint X_is_Y): + super().__init__(distance_metric, n_features=X.shape[1], X_is_Y=X_is_Y) self.X_data, self.X_indices, self.X_indptr = self.unpack_csr_matrix(X) self.Y_data, self.Y_indices, self.Y_indptr = self.unpack_csr_matrix(Y) @@ -279,8 +289,8 @@ cdef class SparseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): between two vectors of (X, Y). """ - def __init__(self, X, Y, {{DistanceMetric}} distance_metric): - super().__init__(distance_metric, n_features=X.shape[1]) + def __init__(self, X, Y, {{DistanceMetric}} distance_metric, bint X_is_Y): + super().__init__(distance_metric, n_features=X.shape[1], X_is_Y=X_is_Y) self.X_data, self.X_indices, self.X_indptr = self.unpack_csr_matrix(X) @@ -377,10 +387,10 @@ cdef class DenseSparseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): between two vectors of (X, Y). """ - def __init__(self, X, Y, {{DistanceMetric}} distance_metric): - super().__init__(distance_metric, n_features=X.shape[1]) + def __init__(self, X, Y, {{DistanceMetric}} distance_metric, bint X_is_Y): + super().__init__(distance_metric, n_features=X.shape[1], X_is_Y=X_is_Y) # Swapping arguments on the constructor - self.datasets_pair = SparseDenseDatasetsPair{{name_suffix}}(Y, X, distance_metric) + self.datasets_pair = SparseDenseDatasetsPair{{name_suffix}}(Y, X, distance_metric, X_is_Y) @final cdef ITYPE_t n_samples_X(self) nogil: diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 5bde2b063e89f..b9c79abfbac18 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -16,6 +16,12 @@ ArgKmin64, ArgKmin32, ) + +from ._pairwise_distances import ( + PairwiseDistances64, + PairwiseDistances32, +) + from ._radius_neighbors import ( RadiusNeighbors64, RadiusNeighbors32, @@ -168,6 +174,132 @@ def compute( """ +class PairwiseDistances(BaseDistancesReductionDispatcher): + """Compute the pairwise distances matrix for two sets of vectors. + + The distance function `dist` depends on the values of the `metric` + and `metric_kwargs` parameters. + + This class only computes the pairwise distances matrix without + applying any reduction on it. It shares most of the underlying + code infrastructure with reducing variants to leverage + cache-aware chunking and multi-thread parallelism. + + This class is not meant to be instantiated, one should only use + its :meth:`compute` classmethod which handles allocation and + deallocation consistently. + """ + + @classmethod + def is_usable_for(cls, X, Y, metric) -> bool: + Y = X if Y is None else Y + return super().is_usable_for(X, Y, metric) + + @classmethod + def compute( + cls, + X, + Y, + metric="euclidean", + chunk_size=None, + metric_kwargs=None, + strategy=None, + ): + """Return pairwise distances matrix for the given arguments. + + Parameters + ---------- + X : ndarray or CSR matrix of shape (n_samples_X, n_features) + Input data. + + Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) + Input data. + + metric : str, default='euclidean' + The distance metric to use. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + chunk_size : int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None + The chunking strategy defining which dataset parallelization are made on. + + For both strategies the computations happens with two nested loops, + respectively on chunks of X and chunks of Y. + Strategies differs on which loop (outer or inner) is made to run + in parallel with the Cython `prange` construct: + + - 'parallel_on_X' dispatches chunks of X uniformly on threads. + Each thread then iterates on all the chunks of Y. This strategy is + embarrassingly parallel and comes with no datastructures + synchronisation. + + - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. + + - 'auto' relies on a simple heuristic to choose between + 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, + 'parallel_on_X' is usually the most efficient strategy. + When `X.shape[0]` is small but `Y.shape[0]` is large, 'parallel_on_Y' + brings more opportunity for parallelism and is therefore more efficient. + + - None (default) looks-up in scikit-learn configuration for + `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + + Returns + ------- + pairwise_distances_matrix : ndarray of shape (n_samples_X, n_samples_Y) + The pairwise distances matrix. + + Notes + ----- + This public classmethod is responsible for introspecting the arguments + values to dispatch to the private dtype-specialized implementation of + :class:`PairwiseDistances`. + + All temporarily allocated datastructures necessary for the concrete + implementation are therefore freed when this classmethod returns. + + This allows entirely decoupling the API entirely from the + implementation details whilst maintaining RAII. + """ + Y = X if Y is None else Y + if X.dtype == Y.dtype == np.float64: + return PairwiseDistances64.compute( + X=X, + Y=Y, + metric=metric, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + ) + + if X.dtype == Y.dtype == np.float32: + return PairwiseDistances32.compute( + X=X, + Y=Y, + metric=metric, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + ) + + raise ValueError( + "Only float64 or float32 datasets pairs are supported, but " + f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." + ) + + class ArgKmin(BaseDistancesReductionDispatcher): """Compute the argkmin of row vectors of X on the ones of Y. @@ -243,7 +375,7 @@ def compute( 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` is small but `Y.shape[0]` is large, 'parallel_on_Y' - brings more opportunity for parallelism and is therefore more efficient + brings more opportunity for parallelism and is therefore more efficient. - None (default) looks-up in scikit-learn configuration for `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. @@ -382,9 +514,7 @@ def compute( 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` is small but `Y.shape[0]` is large, 'parallel_on_Y' - brings more opportunity for parallelism and is therefore more efficient - despite the synchronization step at each iteration of the outer loop - on chunks of `X`. + brings more opportunity for parallelism and is therefore more efficient. - None (default) looks-up in scikit-learn configuration for `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp new file mode 100644 index 0000000000000..c0fea721284e3 --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp @@ -0,0 +1,39 @@ +{{py: + +implementation_specific_values = [ + # Values are the following ones: + # + # name_suffix, INPUT_DTYPE_t + # + # + ('64', 'cnp.float64_t'), + ('32', 'cnp.float32_t') +] + +}} +cimport numpy as cnp + +from ...utils._typedefs cimport DTYPE_t +{{for name_suffix, INPUT_DTYPE_t in implementation_specific_values}} + +from ._base cimport BaseDistancesReduction{{name_suffix}} +from ._middle_term_computer cimport MiddleTermComputer{{name_suffix}} + + +cdef class PairwiseDistances{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): + """float{{name_suffix}} implementation of PairwiseDistances.""" + + cdef: + {{INPUT_DTYPE_t}}[:, ::1] pairwise_distances_matrix + + +cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suffix}}): + """EuclideanDistance-specialized float{{name_suffix}} implementation for PairwiseDistances.""" + cdef: + MiddleTermComputer{{name_suffix}} middle_term_computer + const DTYPE_t[::1] X_norm_squared + const DTYPE_t[::1] Y_norm_squared + + bint use_squared_distances + +{{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp new file mode 100644 index 0000000000000..f15a65bb43b89 --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp @@ -0,0 +1,368 @@ +cimport numpy as cnp +from cython cimport final +from ...utils._typedefs cimport ITYPE_t, DTYPE_t + +import numpy as np + +from scipy.sparse import issparse +from ...utils import check_array, _in_unstable_openblas_configuration +from ...utils.fixes import threadpool_limits, sp_version, parse_version +from ...utils._typedefs import DTYPE + +cnp.import_array() + + +def _precompute_metric_params(X, Y, metric=None, **kwds): + """Precompute data-derived metric parameters if not provided.""" + if metric == "seuclidean" and "V" not in kwds: + # There is a bug in scipy < 1.5 that will cause a crash if + # X.dtype != np.double (float64). See PR #15730 + dtype = np.float64 if sp_version < parse_version("1.5") else None + if X is Y: + V = np.var(X, axis=0, ddof=1, dtype=dtype) + else: + raise ValueError( + "The 'V' parameter is required for the seuclidean metric " + "when Y is passed." + ) + return {"V": V} + if metric == "mahalanobis" and "VI" not in kwds: + if X is Y: + VI = np.linalg.inv(np.cov(X.T)).T + else: + raise ValueError( + "The 'VI' parameter is required for the mahalanobis metric " + "when Y is passed." + ) + return {"VI": VI} + return {} + +{{for name_suffix, INPUT_DTYPE in (('64', 'DTYPE'),('32', 'np.float32'))}} + +from ._base cimport ( + BaseDistancesReduction{{name_suffix}}, + _sqeuclidean_row_norms{{name_suffix}}, +) + +from ._datasets_pair cimport ( + DatasetsPair{{name_suffix}}, + DenseDenseDatasetsPair{{name_suffix}}, +) + +from ._middle_term_computer cimport MiddleTermComputer{{name_suffix}} + + +cdef class PairwiseDistances{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): + """float{{name_suffix}} implementation of PairwiseDistances.""" + + @classmethod + def compute( + cls, + X, + Y, + str metric="euclidean", + chunk_size=None, + dict metric_kwargs=None, + str strategy=None, + ): + """Compute the pairwise-distances matrix. + + This classmethod is responsible for introspecting the arguments + values to dispatch to the most appropriate implementation of + :class:`PairwiseDistances{{name_suffix}}`. + + This allows decoupling the API entirely from the implementation details + whilst maintaining RAII: all temporarily allocated datastructures necessary + for the concrete implementation are therefore freed when this classmethod + returns. + + No instance should directly be created outside of this class method. + """ + if ( + metric in ("euclidean", "l2", "sqeuclidean") + and not issparse(X) + and not issparse(Y) + ): + # Specialized implementation of PairwiseDistances for the Euclidean distance. + # This implementation computes the distances by chunk using + # a decomposition of the Squared Euclidean distance. + # This specialisation has an improved arithmetic intensity for both + # the dense and sparse settings, allowing in most case speed-ups of + # several orders of magnitude compared to the generic RadiusNeighbors + # implementation. + # For more information see MiddleTermComputer. + use_squared_distances = metric == "sqeuclidean" + pdr = EuclideanPairwiseDistances{{name_suffix}}( + X=X, Y=Y, + use_squared_distances=use_squared_distances, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + ) + else: + # Precompute data-derived distance metric parameters + metric_kwargs = {} if metric_kwargs is None else metric_kwargs + + params = _precompute_metric_params( + X, + Y, + metric=metric, + **metric_kwargs, + ) + metric_kwargs.update(**params) + + # Fall back on a generic implementation that handles most scipy + # metrics by computing the distances between 2 vectors at a time. + pdr = PairwiseDistances{{name_suffix}}( + datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs), + chunk_size=chunk_size, + strategy=strategy, + ) + + # Limit the number of threads in second level of nested parallelism for BLAS + # to avoid threads over-subscription (in GEMM for instance). + with threadpool_limits(limits=1, user_api="blas"): + if pdr.execute_in_parallel_on_Y: + pdr._parallel_on_Y() + else: + pdr._parallel_on_X() + + return pdr._finalize_results() + + + def __init__( + self, + DatasetsPair{{name_suffix}} datasets_pair, + chunk_size=None, + strategy=None, + sort_results=False, + ): + super().__init__( + datasets_pair=datasets_pair, + chunk_size=chunk_size, + strategy=strategy, + ) + + # Distance matrix which will be complete and returned to the caller. + self.pairwise_distances_matrix = np.empty( + (self.n_samples_X, self.n_samples_Y), dtype={{INPUT_DTYPE}}, + ) + + def _finalize_results(self): + # If X is Y, then catastrophic cancellation might have occurred for + # computations of terms on the diagonal which must equal zero. + # We enforce it by zeroing the diagonal. + distance_matrix = np.asarray(self.pairwise_distances_matrix) + if self.datasets_pair.X_is_Y: + np.fill_diagonal(distance_matrix, 0.) + + return distance_matrix + + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + + for i in range(X_start, X_end): + for j in range(Y_start, Y_end): + self.pairwise_distances_matrix[i, j] = self.datasets_pair.dist(i, j) + + +cdef class EuclideanPairwiseDistances{{name_suffix}}(PairwiseDistances{{name_suffix}}): + """EuclideanDistance-specialized float{{name_suffix}} implementation for PairwiseDistances.""" + + @classmethod + def is_usable_for(cls, X, Y, metric) -> bool: + return (PairwiseDistances{{name_suffix}}.is_usable_for(X, Y, metric) + and not _in_unstable_openblas_configuration()) + + def __init__( + self, + X, + Y, + bint use_squared_distances=False, + chunk_size=None, + strategy=None, + metric_kwargs=None, + ): + super().__init__( + # The datasets pair here is used for exact distances computations + datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric="euclidean"), + chunk_size=chunk_size, + strategy=strategy, + ) + cdef: + ITYPE_t dist_middle_terms_chunks_size = ( + self.Y_n_samples_chunk * self.X_n_samples_chunk + ) + + self.middle_term_computer = MiddleTermComputer{{name_suffix}}.get_for( + X, + Y, + self.effective_n_threads, + self.chunks_n_threads, + dist_middle_terms_chunks_size, + n_features=X.shape[1], + chunk_size=self.chunk_size, + ) + + if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: + self.Y_norm_squared = check_array( + metric_kwargs.pop("Y_norm_squared"), + ensure_2d=False, + input_name="Y_norm_squared", + dtype=np.float64 + ) + else: + self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}( + Y, + self.effective_n_threads, + ) + + if metric_kwargs is not None and "X_norm_squared" in metric_kwargs: + self.X_norm_squared = check_array( + metric_kwargs.pop("X_norm_squared"), + ensure_2d=False, + input_name="X_norm_squared", + dtype=np.float64 + ) + else: + # Do not recompute norms if datasets are identical. + self.X_norm_squared = ( + self.Y_norm_squared if X is Y else + _sqeuclidean_row_norms{{name_suffix}}( + X, + self.effective_n_threads, + ) + ) + + self.use_squared_distances = use_squared_distances + + + @final + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num, + ) nogil: + PairwiseDistances{{name_suffix}}._parallel_on_X_parallel_init(self, thread_num) + self.middle_term_computer._parallel_on_X_parallel_init(thread_num) + + @final + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + PairwiseDistances{{name_suffix}}._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) + self.middle_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end) + + @final + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + PairwiseDistances{{name_suffix}}._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + self.middle_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num, + ) + + @final + cdef void _parallel_on_Y_init( + self, + ) nogil: + cdef ITYPE_t thread_num + PairwiseDistances{{name_suffix}}._parallel_on_Y_init(self) + self.middle_term_computer._parallel_on_Y_init() + + @final + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + PairwiseDistances{{name_suffix}}._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) + self.middle_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end) + + @final + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + PairwiseDistances{{name_suffix}}._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + self.middle_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num + ) + + + @final + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + ITYPE_t pair_index = 0 + DTYPE_t sq_dist_i_j = 0. + + DTYPE_t *dist_middle_terms = self.middle_term_computer._compute_dist_middle_terms( + X_start, X_end, Y_start, Y_end, thread_num + ) + + for i in range(X_start, X_end): + for j in range(Y_start, Y_end): + # Using the squared euclidean distance as the rank-preserving distance: + # + # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² + # + sq_dist_i_j = ( + self.X_norm_squared[i] + + dist_middle_terms[pair_index] + + self.Y_norm_squared[j] + ) + # Guard against eventual -0. and NaN caused by catastrophic + # cancellation (e.g. if X is Y) + self.pairwise_distances_matrix[i, j] = ( + sq_dist_i_j if 0. < sq_dist_i_j == sq_dist_i_j else 0. + ) + pair_index += 1 + + def _finalize_results(self): + distance_matrix = PairwiseDistances{{name_suffix}}._finalize_results(self) + # Squared Euclidean distances have been used for efficiency. + # We remap them to Euclidean distances here before finalizing + # results. + if not self.use_squared_distances: + return np.sqrt(distance_matrix) + return PairwiseDistances{{name_suffix}}._finalize_results(self) + +{{endfor}} diff --git a/sklearn/metrics/_pairwise_fast.pyx b/sklearn/metrics/_pairwise_fast.pyx index b9006773c015d..b6941dba290db 100644 --- a/sklearn/metrics/_pairwise_fast.pyx +++ b/sklearn/metrics/_pairwise_fast.pyx @@ -6,10 +6,6 @@ cimport numpy as cnp from cython cimport floating -from cython.parallel cimport prange -from libc.math cimport fabs - -from ..utils._openmp_helpers import _openmp_effective_n_threads cnp.import_array() @@ -33,73 +29,3 @@ def _chi2_kernel_fast(floating[:, :] X, if nom != 0: res += denom * denom / nom result[i, j] = -res - - -def _sparse_manhattan(floating[::1] X_data, int[:] X_indices, int[:] X_indptr, - floating[::1] Y_data, int[:] Y_indices, int[:] Y_indptr, - double[:, ::1] D): - """Pairwise L1 distances for CSR matrices. - - Usage: - >>> D = np.zeros(X.shape[0], Y.shape[0]) - >>> _sparse_manhattan(X.data, X.indices, X.indptr, - ... Y.data, Y.indices, Y.indptr, - ... D) - """ - cdef cnp.npy_intp px, py, i, j, ix, iy - cdef double d = 0.0 - - cdef int m = D.shape[0] - cdef int n = D.shape[1] - - cdef int X_indptr_end = 0 - cdef int Y_indptr_end = 0 - - cdef int num_threads = _openmp_effective_n_threads() - - # We scan the matrices row by row. - # Given row px in X and row py in Y, we find the positions (i and j - # respectively), in .indices where the indices for the two rows start. - # If the indices (ix and iy) are the same, the corresponding data values - # are processed and the cursors i and j are advanced. - # If not, the lowest index is considered. Its associated data value is - # processed and its cursor is advanced. - # We proceed like this until one of the cursors hits the end for its row. - # Then we process all remaining data values in the other row. - - # Below the avoidance of inplace operators is intentional. - # When prange is used, the inplace operator has a special meaning, i.e. it - # signals a "reduction" - - for px in prange(m, nogil=True, num_threads=num_threads): - X_indptr_end = X_indptr[px + 1] - for py in range(n): - Y_indptr_end = Y_indptr[py + 1] - i = X_indptr[px] - j = Y_indptr[py] - d = 0.0 - while i < X_indptr_end and j < Y_indptr_end: - ix = X_indices[i] - iy = Y_indices[j] - - if ix == iy: - d = d + fabs(X_data[i] - Y_data[j]) - i = i + 1 - j = j + 1 - elif ix < iy: - d = d + fabs(X_data[i]) - i = i + 1 - else: - d = d + fabs(Y_data[j]) - j = j + 1 - - if i == X_indptr_end: - while j < Y_indptr_end: - d = d + fabs(Y_data[j]) - j = j + 1 - else: - while i < X_indptr_end: - d = d + fabs(X_data[i]) - i = i + 1 - - D[px, py] = d diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index e6abd596b0000..f28a83aa74ad0 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -28,10 +28,13 @@ from ..preprocessing import normalize from ..utils._mask import _get_mask from ..utils.fixes import delayed -from ..utils.fixes import sp_version, parse_version -from ._pairwise_distances_reduction import ArgKmin -from ._pairwise_fast import _chi2_kernel_fast, _sparse_manhattan +from ._pairwise_distances_reduction import ( + ArgKmin, + PairwiseDistances, + _precompute_metric_params, +) +from ._pairwise_fast import _chi2_kernel_fast from ..exceptions import DataConversionWarning @@ -337,6 +340,21 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared float32, norms needs to be recomputed on upcast chunks. TODO: use a float64 accumulator in row_norms to avoid the latter. """ + metric = "sqeuclidean" if squared else "euclidean" + if PairwiseDistances.is_usable_for(X, Y, metric): + metric_kwargs = {} + if X_norm_squared is not None: + metric_kwargs["X_norm_squared"] = np.ravel(X_norm_squared) + + if Y_norm_squared is not None: + metric_kwargs["Y_norm_squared"] = np.ravel(Y_norm_squared) + + return PairwiseDistances.compute(X, Y, metric, metric_kwargs=metric_kwargs) + + # XXX: the following code is still used for list-of-lists of numbers which + # aren't converted to numpy arrays in validation steps done in `check_array`. + # TODO: convert list-of-lists to numpy arrays in `check_array`. + # See: https://github.com/scikit-learn/scikit-learn/issues/24745 if X_norm_squared is not None: if X_norm_squared.dtype == np.float32: XX = None @@ -868,6 +886,15 @@ def haversine_distances(X, Y=None): array([[ 0. , 11099.54035582], [11099.54035582, 0. ]]) """ + + if PairwiseDistances.is_usable_for(X, Y, metric="haversine"): + return PairwiseDistances.compute(X, Y, metric="haversine") + + # XXX: the following code is still used for list-of-lists of numbers which + # aren't converted to numpy arrays in validation steps done in `check_array`. + # TODO: convert list-of-lists to numpy arrays in `check_array`. + # See: https://github.com/scikit-learn/scikit-learn/issues/24745 + from ..metrics import DistanceMetric return DistanceMetric.get_metric("haversine").pairwise(X, Y) @@ -941,20 +968,30 @@ def manhattan_distances(X, Y=None, *, sum_over_features="deprecated"): X, Y = check_pairwise_arrays(X, Y) - if issparse(X) or issparse(Y): - if not sum_over_features: - raise TypeError( - "sum_over_features=%r not supported for sparse matrices" - % sum_over_features - ) - + if issparse(X): X = csr_matrix(X, copy=False) + # This also sorts indices in-place. + X.sum_duplicates() + + if issparse(Y): Y = csr_matrix(Y, copy=False) - X.sum_duplicates() # this also sorts indices in-place + # This also sorts indices in-place. Y.sum_duplicates() - D = np.zeros((X.shape[0], Y.shape[0])) - _sparse_manhattan(X.data, X.indices, X.indptr, Y.data, Y.indices, Y.indptr, D) - return D + + if sum_over_features and PairwiseDistances.is_usable_for(X, Y, metric="manhattan"): + return PairwiseDistances.compute(X, Y, metric="manhattan") + + # XXX: the following code is still used for list-of-lists of numbers which + # aren't converted to numpy arrays in validation steps done in `check_array` + # and for supporting `sum_over_features` which we should probably remove. + # TODO: convert list-of-lists to numpy arrays in `check_array`. + # See: https://github.com/scikit-learn/scikit-learn/issues/24745 + # TODO: remove `sum_over_features`, see: + # https://github.com/scikit-learn/scikit-learn/issues/24597 + + if issparse(X) or issparse(Y): + if not sum_over_features: + raise TypeError("sum_over_features=False not supported for sparse matrices") if sum_over_features: return distance.cdist(X, Y, "cityblock") @@ -1673,32 +1710,6 @@ def _check_chunk_size(reduced, chunk_size): ) -def _precompute_metric_params(X, Y, metric=None, **kwds): - """Precompute data-derived metric parameters if not provided.""" - if metric == "seuclidean" and "V" not in kwds: - # There is a bug in scipy < 1.5 that will cause a crash if - # X.dtype != np.double (float64). See PR #15730 - dtype = np.float64 if sp_version < parse_version("1.5") else None - if X is Y: - V = np.var(X, axis=0, ddof=1, dtype=dtype) - else: - raise ValueError( - "The 'V' parameter is required for the seuclidean metric " - "when Y is passed." - ) - return {"V": V} - if metric == "mahalanobis" and "VI" not in kwds: - if X is Y: - VI = np.linalg.inv(np.cov(X.T)).T - else: - raise ValueError( - "The 'VI' parameter is required for the mahalanobis metric " - "when Y is passed." - ) - return {"VI": VI} - return {} - - def pairwise_distances_chunked( X, Y=None, @@ -1991,6 +2002,17 @@ def pairwise_distances( % (metric, _VALID_METRICS) ) + if PairwiseDistances.is_usable_for(X, X if Y is None else Y, metric=metric): + # This is an adaptor for one "sqeuclidean" specification. + # For this backend, we can directly use "sqeuclidean". + if kwds.get("squared", False) and metric == "euclidean": + metric = "sqeuclidean" + kwds = {} + + return PairwiseDistances.compute( + X, X if Y is None else Y, metric=metric, metric_kwargs=kwds + ) + if metric == "precomputed": X, _ = check_pairwise_arrays( X, Y, precomputed=True, force_all_finite=force_all_finite @@ -2009,6 +2031,11 @@ def pairwise_distances( _pairwise_callable, metric=metric, force_all_finite=force_all_finite, **kwds ) else: + # XXX: the following code is still used for list-of-lists of numbers which + # aren't converted to numpy arrays in validation steps done in `check_array`. + # TODO: convert list-of-lists to numpy arrays in `check_array`. + # See: https://github.com/scikit-learn/scikit-learn/issues/24745 + if issparse(X) or issparse(Y): raise TypeError("scipy distance metrics do not support sparse matrices.") diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index 3624983c4c481..971c11af1dc23 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -153,25 +153,14 @@ def test_pairwise_distances(global_dtype): S = pairwise_distances(X_sparse, Y_sparse.tocsc(), metric="manhattan") S2 = manhattan_distances(X_sparse.tobsr(), Y_sparse.tocoo()) assert_allclose(S, S2) - if global_dtype == np.float64: - assert S.dtype == S2.dtype == global_dtype - else: - # TODO Fix manhattan_distances to preserve dtype. - # currently pairwise_distances uses manhattan_distances but converts the result - # back to the input dtype - with pytest.raises(AssertionError): - assert S.dtype == S2.dtype == global_dtype + + # pairwise_distances must preserves dtypes for the manhattan distance metric + assert S.dtype == S2.dtype == global_dtype S2 = manhattan_distances(X, Y) assert_allclose(S, S2) - if global_dtype == np.float64: - assert S.dtype == S2.dtype == global_dtype - else: - # TODO Fix manhattan_distances to preserve dtype. - # currently pairwise_distances uses manhattan_distances but converts the result - # back to the input dtype - with pytest.raises(AssertionError): - assert S.dtype == S2.dtype == global_dtype + # manhattan_distances must preserves dtypes + assert S.dtype == S2.dtype == global_dtype # Test with scipy.spatial.distance metric, with a kwd kwds = {"p": 2.0} @@ -185,12 +174,6 @@ def test_pairwise_distances(global_dtype): S2 = pairwise_distances(X, metric=minkowski, **kwds) assert_allclose(S, S2) - # Test that scipy distance metrics throw an error if sparse matrix given - with pytest.raises(TypeError): - pairwise_distances(X_sparse, metric="minkowski") - with pytest.raises(TypeError): - pairwise_distances(X, Y_sparse, metric="minkowski") - # Test that a value error is raised if the metric is unknown with pytest.raises(ValueError): pairwise_distances(X, Y, metric="blah") @@ -942,15 +925,10 @@ def test_euclidean_distances_upcast_sym(batch_size, x_array_constr): "dtype, eps, rtol", [ (np.float32, 1e-4, 1e-5), - pytest.param( - np.float64, - 1e-8, - 0.99, - marks=pytest.mark.xfail(reason="failing due to lack of precision"), - ), + (np.float64, 1e-8, 0.99), ], ) -@pytest.mark.parametrize("dim", [1, 1000000]) +@pytest.mark.parametrize("dim", [1, 100000]) def test_euclidean_distances_extreme_values(dtype, eps, rtol, dim): # check that euclidean distances is correct with float32 input thanks to # upcasting. On float64 there are still precision issues. @@ -960,7 +938,7 @@ def test_euclidean_distances_extreme_values(dtype, eps, rtol, dim): distances = euclidean_distances(X, Y) expected = cdist(X, Y) - assert_allclose(distances, expected, rtol=1e-5) + assert_allclose(distances, expected, rtol=1e-5, atol=4e-4) @pytest.mark.parametrize("squared", [True, False]) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index f929a55105509..a89544bdb0a95 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -13,6 +13,7 @@ BaseDistancesReductionDispatcher, ArgKmin, RadiusNeighbors, + PairwiseDistances, sqeuclidean_row_norms, ) @@ -691,6 +692,44 @@ def test_radius_neighbors_factory_method_wrong_usages(): ) +def test_pairwise_distances_factory_method_wrong_usages(): + rng = np.random.RandomState(1) + X = rng.rand(100, 10) + Y = rng.rand(100, 10) + metric = "euclidean" + + msg = ( + "Only float64 or float32 datasets pairs are supported, but " + "got: X.dtype=float32 and Y.dtype=float64" + ) + with pytest.raises( + ValueError, + match=msg, + ): + PairwiseDistances.compute(X=X.astype(np.float32), Y=Y, metric=metric) + + msg = ( + "Only float64 or float32 datasets pairs are supported, but " + "got: X.dtype=float64 and Y.dtype=int32" + ) + with pytest.raises( + ValueError, + match=msg, + ): + PairwiseDistances.compute(X=X, Y=Y.astype(np.int32), metric=metric) + + with pytest.raises(ValueError, match="Unrecognized metric"): + PairwiseDistances.compute(X=X, Y=Y, metric="wrong metric") + + with pytest.raises( + ValueError, match=r"Buffer has wrong number of dimensions \(expected 2, got 1\)" + ): + PairwiseDistances.compute(X=np.array([1.0, 2.0]), Y=Y, metric=metric) + + with pytest.raises(ValueError, match="ndarray is not C-contiguous"): + PairwiseDistances.compute(X=np.asfortranarray(X), Y=Y, metric=metric) + + @pytest.mark.parametrize( "n_samples_X, n_samples_Y", [(100, 100), (500, 100), (100, 500)] ) diff --git a/sklearn/neighbors/_nca.py b/sklearn/neighbors/_nca.py index 4a83fcc7bc080..66e103810c6fa 100644 --- a/sklearn/neighbors/_nca.py +++ b/sklearn/neighbors/_nca.py @@ -486,7 +486,7 @@ def _loss_grad_lbfgs(self, transformation, X, same_class_mask, sign=1.0): X_embedded = np.dot(X, transformation.T) # (n_samples, n_components) # Compute softmax distances - p_ij = pairwise_distances(X_embedded, squared=True) + p_ij = pairwise_distances(X_embedded, metric="sqeuclidean") np.fill_diagonal(p_ij, np.inf) p_ij = softmax(-p_ij) # (n_samples, n_samples)