diff --git a/.gitignore b/.gitignore index f4601a15655a5..f07a2654e0a8a 100644 --- a/.gitignore +++ b/.gitignore @@ -97,6 +97,8 @@ sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx +sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd +sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 7bdbe7841f0d4..cac0c90501e2b 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -518,6 +518,16 @@ Changelog - |Fix| :func:`metrics.manhattan_distances` now supports readonly sparse datasets. :pr:`25432` by :user:`Julien Jerphanion `. +- |Efficiency| :func:`pairwise.pairwise_distances`' performance has been improved + when providing dense datasets. + :pr:`25561` by :user:`Vincent Maladiere ` and + :user:`Julien Jerphanion `. + +- |Feature| :func:`pairwise.pairwise_distances` now supports combination of + dense arrays and sparse CSR matrices datasets. + :pr:`25561` by :user:`Vincent Maladiere ` and + :user:`Julien Jerphanion `. + - |Fix| Fixed :func:`metrics.classification_report` so that empty input will return `np.nan`. Previously, "macro avg" and `weighted avg` would return e.g. `f1-score=np.nan` and `f1-score=0.0`, being inconsistent. Now, they diff --git a/setup.cfg b/setup.cfg index d91a27344c575..479a850d03bac 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,6 +52,8 @@ ignore = sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx + sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd + sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx diff --git a/setup.py b/setup.py index 5af738f5f841f..0ee1cd71cebd0 100755 --- a/setup.py +++ b/setup.py @@ -277,6 +277,12 @@ def check_package_status(package, min_version): "include_np": True, "extra_compile_args": ["-std=c++11"], }, + { + "sources": ["_pairwise_distances.pyx.tp", "_pairwise_distances.pxd.tp"], + "language": "c++", + "include_np": True, + "extra_compile_args": ["-std=c++11"], + }, { "sources": ["_argkmin.pyx.tp", "_argkmin.pxd.tp"], "language": "c++", diff --git a/sklearn/metrics/_pairwise_distances_reduction/__init__.py b/sklearn/metrics/_pairwise_distances_reduction/__init__.py index 68972de0a1a51..9d9075fcbebd5 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/__init__.py +++ b/sklearn/metrics/_pairwise_distances_reduction/__init__.py @@ -90,14 +90,18 @@ ArgKmin, ArgKminClassMode, BaseDistancesReductionDispatcher, + PairwiseDistances, RadiusNeighbors, sqeuclidean_row_norms, ) +from ._pairwise_distances import _precompute_metric_params __all__ = [ "BaseDistancesReductionDispatcher", "ArgKmin", + "PairwiseDistances", "RadiusNeighbors", "ArgKminClassMode", "sqeuclidean_row_norms", + "_precompute_metric_params", ] diff --git a/sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp index 9578129993c37..4eb76b14f25e4 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp @@ -46,6 +46,7 @@ cdef class BaseDistancesReduction{{name_suffix}}: intp_t n_samples_X, X_n_samples_chunk, X_n_chunks, X_n_samples_last_chunk intp_t n_samples_Y, Y_n_samples_chunk, Y_n_chunks, Y_n_samples_last_chunk + bint X_is_Y bint execute_in_parallel_on_Y @final diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp index fc56a59cab16f..04070f145cbdd 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp @@ -20,6 +20,8 @@ cdef class DatasetsPair{{name_suffix}}: {{DistanceMetric}} distance_metric intp_t n_features + readonly bint X_is_Y + cdef intp_t n_samples_X(self) noexcept nogil cdef intp_t n_samples_Y(self) noexcept nogil diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp index 40a9a45e8b8e1..20889d1d934db 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp @@ -1,3 +1,5 @@ +import copy + {{py: implementation_specific_values = [ @@ -84,12 +86,17 @@ cdef class DatasetsPair{{name_suffix}}: datasets_pair: DatasetsPair{{name_suffix}} The suited DatasetsPair{{name_suffix}} implementation. """ - # Y_norm_squared might be propagated down to DatasetsPairs - # via metrics_kwargs when the Euclidean specialisations - # can't be used. To prevent Y_norm_squared to be passed + # X_norm_squared and Y_norm_squared might be propagated + # down to DatasetsPairs via metrics_kwargs when the Euclidean + # specialisations can't be used. + # To prevent X_norm_squared and Y_norm_squared to be passed # down to DistanceMetrics (whose constructors would raise - # a RuntimeError), we pop it here. + # a RuntimeError), we pop them here. if metric_kwargs is not None: + # Copying metric_kwargs not to pop "X_norm_squared" + # and "Y_norm_squared" where they are used + metric_kwargs = copy.copy(metric_kwargs) + metric_kwargs.pop("X_norm_squared", None) metric_kwargs.pop("Y_norm_squared", None) cdef: {{DistanceMetric}} distance_metric = DistanceMetric.get_metric( @@ -97,6 +104,7 @@ cdef class DatasetsPair{{name_suffix}}: {{INPUT_DTYPE}}, **(metric_kwargs or {}) ) + bint X_is_Y = X is Y # Metric-specific checks that do not replace nor duplicate `check_array`. distance_metric._validate_data(X) @@ -106,15 +114,15 @@ cdef class DatasetsPair{{name_suffix}}: Y_is_sparse = issparse(Y) if not X_is_sparse and not Y_is_sparse: - return DenseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric) + return DenseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y) if X_is_sparse and Y_is_sparse: - return SparseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric) + return SparseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y) if X_is_sparse and not Y_is_sparse: - return SparseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric) + return SparseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y) - return DenseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric) + return DenseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric, X_is_Y) @classmethod def unpack_csr_matrix(cls, X: csr_matrix): @@ -124,8 +132,9 @@ cdef class DatasetsPair{{name_suffix}}: X_indptr = np.asarray(X.indptr, dtype=np.int32) return X_data, X_indices, X_indptr - def __init__(self, {{DistanceMetric}} distance_metric, intp_t n_features): + def __init__(self, {{DistanceMetric}} distance_metric, intp_t n_features, bint X_is_Y): self.distance_metric = distance_metric + self.X_is_Y = X_is_Y self.n_features = n_features cdef intp_t n_samples_X(self) noexcept nogil: @@ -173,8 +182,9 @@ cdef class DenseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): const {{INPUT_DTYPE_t}}[:, ::1] X, const {{INPUT_DTYPE_t}}[:, ::1] Y, {{DistanceMetric}} distance_metric, + bint X_is_Y, ): - super().__init__(distance_metric, n_features=X.shape[1]) + super().__init__(distance_metric, n_features=X.shape[1], X_is_Y=X_is_Y) # Arrays have already been checked self.X = X self.Y = Y @@ -213,8 +223,8 @@ cdef class SparseSparseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): between two vectors of (X, Y). """ - def __init__(self, X, Y, {{DistanceMetric}} distance_metric): - super().__init__(distance_metric, n_features=X.shape[1]) + def __init__(self, X, Y, {{DistanceMetric}} distance_metric, bint X_is_Y): + super().__init__(distance_metric, n_features=X.shape[1], X_is_Y=X_is_Y) self.X_data, self.X_indices, self.X_indptr = self.unpack_csr_matrix(X) self.Y_data, self.Y_indices, self.Y_indptr = self.unpack_csr_matrix(Y) @@ -273,8 +283,8 @@ cdef class SparseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): between two vectors of (X, Y). """ - def __init__(self, X, Y, {{DistanceMetric}} distance_metric): - super().__init__(distance_metric, n_features=X.shape[1]) + def __init__(self, X, Y, {{DistanceMetric}} distance_metric, bint X_is_Y): + super().__init__(distance_metric, n_features=X.shape[1], X_is_Y=X_is_Y) self.X_data, self.X_indices, self.X_indptr = self.unpack_csr_matrix(X) @@ -371,10 +381,10 @@ cdef class DenseSparseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): between two vectors of (X, Y). """ - def __init__(self, X, Y, {{DistanceMetric}} distance_metric): - super().__init__(distance_metric, n_features=X.shape[1]) + def __init__(self, X, Y, {{DistanceMetric}} distance_metric, bint X_is_Y): + super().__init__(distance_metric, n_features=X.shape[1], X_is_Y=X_is_Y) # Swapping arguments on the constructor - self.datasets_pair = SparseDenseDatasetsPair{{name_suffix}}(Y, X, distance_metric) + self.datasets_pair = SparseDenseDatasetsPair{{name_suffix}}(Y, X, distance_metric, X_is_Y) @final cdef intp_t n_samples_X(self) noexcept nogil: diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 42f9e38aa2265..3622c748f57b2 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -5,6 +5,7 @@ from scipy.sparse import issparse, isspmatrix_csr from ... import get_config +from ...utils._openmp_helpers import _openmp_effective_n_threads from .._dist_metrics import BOOL_METRICS, METRIC_MAPPING64 from ._argkmin import ( ArgKmin32, @@ -15,6 +16,10 @@ ArgKminClassMode64, ) from ._base import _sqeuclidean_row_norms32, _sqeuclidean_row_norms64 +from ._pairwise_distances import ( + PairwiseDistances32, + PairwiseDistances64, +) from ._radius_neighbors import ( RadiusNeighbors32, RadiusNeighbors64, @@ -148,6 +153,172 @@ def compute( """ +class PairwiseDistances(BaseDistancesReductionDispatcher): + """Compute the pairwise distances matrix for two sets of vectors. + + The distance function `dist` depends on the values of the `metric` + and `metric_kwargs` parameters. + + This class only computes the pairwise distances matrix without + applying any reduction on it. It shares most of the underlying + code infrastructure with reducing variants to leverage multi-thread + parallelism. However contrary to the reducing variants, no chunking + is applied to allow for contiguous write access to the final distance + array that is not expected to fit in the CPU cache in general. + + This class is not meant to be instantiated, one should only use + its :meth:`compute` classmethod which handles allocation and + deallocation consistently. + """ + + @classmethod + def is_usable_for(cls, X, Y, metric, metric_kwargs=None) -> bool: + """Return True if the dispatcher can be used for the + given parameters. + + Parameters + ---------- + X : {ndarray, sparse matrix} of shape (n_samples_X, n_features) + Input data. + + Y : {ndarray, sparse matrix} of shape (n_samples_Y, n_features) + Input data. + + metric : str, default='euclidean' + The distance metric to use. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + Returns + ------- + True if the dispatcher can be used, else False. + """ + effective_n_threads = _openmp_effective_n_threads() + + def is_euclidean(metric, metric_kwargs): + metric_kwargs = metric_kwargs or dict() + euclidean_metrics = [ + "euclidean", + "sqeuclidean", + "l2", + ] + # TODO: pass `p` as a standalone argument instead of a metric_kwargs. + return metric in euclidean_metrics or ( + metric == "minkowski" and metric_kwargs.get("p", 2) == 2 + ) + + Y = X if Y is None else Y + + # We need to rely on `PairwiseDistances` for manhattan anyway because + # the implementation of manhattan distances on sparse data has been removed. + manhattan_metrics = ["cityblock", "l1", "manhattan"] + + is_usable = super().is_usable_for(X, Y, metric) and ( + (not is_euclidean(metric, metric_kwargs) and effective_n_threads != 1) + or metric in manhattan_metrics + ) + + return is_usable + + @classmethod + def compute( + cls, + X, + Y, + metric="euclidean", + metric_kwargs=None, + strategy=None, + ): + """Return pairwise distances matrix for the given arguments. + + Parameters + ---------- + X : ndarray or CSR matrix of shape (n_samples_X, n_features) + Input data. + + Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) + Input data. + + metric : str, default='euclidean' + The distance metric to use. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None + The strategy defining which dataset parallelization are made on. + + For both strategies the computations happens with two nested loops, + respectively on rows of X and rows of Y. + Strategies differs on which loop (outer or inner) is made to run + in parallel with the Cython `prange` construct: + + - 'parallel_on_X' dispatches rows of X uniformly on threads. + Each thread then iterates on all the rows of Y. This strategy is + embarrassingly parallel and comes with no datastructures + synchronisation. + + - 'parallel_on_Y' dispatches rows of Y uniformly on threads. + Each thread processes all the rows of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with no intermediate datastructures synchronisation. + + - 'auto' relies on a simple heuristic to choose between + 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, + 'parallel_on_X' is usually the most efficient strategy. + When `X.shape[0]` is small but `Y.shape[0]` is large, 'parallel_on_Y' + brings more opportunity for parallelism and is therefore more efficient. + + - None (default) looks-up in scikit-learn configuration for + `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + + Returns + ------- + pairwise_distances_matrix : ndarray of shape (n_samples_X, n_samples_Y) + The pairwise distances matrix. + + Notes + ----- + This public classmethod is responsible for introspecting the arguments + values to dispatch to the private dtype-specialized implementation of + :class:`PairwiseDistances`. + + All temporarily allocated datastructures necessary for the concrete + implementations are therefore freed when this classmethod returns. + + This allows entirely decoupling the API entirely from the + implementation details whilst maintaining RAII. + """ + Y = X if Y is None else Y + if X.dtype == Y.dtype == np.float64: + return PairwiseDistances64.compute( + X=X, + Y=Y, + metric=metric, + metric_kwargs=metric_kwargs, + strategy=strategy, + ) + + if X.dtype == Y.dtype == np.float32: + return PairwiseDistances32.compute( + X=X, + Y=Y, + metric=metric, + metric_kwargs=metric_kwargs, + strategy=strategy, + ) + + raise ValueError( + "Only float64 or float32 datasets pairs are supported, but " + f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." + ) + + class ArgKmin(BaseDistancesReductionDispatcher): """Compute the argkmin of row vectors of X on the ones of Y. diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp new file mode 100644 index 0000000000000..797cbe97874ac --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pxd.tp @@ -0,0 +1,40 @@ +{{py: + +implementation_specific_values = [ + # Values are the following ones: + # + # name_suffix, INPUT_DTYPE_t + # + # + ('64', 'cnp.float64_t'), + ('32', 'cnp.float32_t') +] + +}} +cimport numpy as cnp + +from ...utils._typedefs cimport intp_t +{{for name_suffix, INPUT_DTYPE_t in implementation_specific_values}} + +from ._datasets_pair cimport DatasetsPair{{name_suffix}} + + +cdef class PairwiseDistances{{name_suffix}}: + """float{{name_suffix}} implementation of PairwiseDistances.""" + + cdef: + readonly DatasetsPair{{name_suffix}} datasets_pair + + intp_t n_samples_X, n_samples_Y + intp_t effective_n_threads + bint X_is_Y + bint execute_in_parallel_on_Y + + {{INPUT_DTYPE_t}}[:, ::1] pairwise_distances_matrix + + cdef void _parallel_on_X(self) nogil + + cdef void _parallel_on_Y(self) nogil + + +{{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp new file mode 100644 index 0000000000000..bac3904b85e97 --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_pairwise_distances.pyx.tp @@ -0,0 +1,181 @@ +cimport numpy as cnp +from cython cimport final +from cython.parallel cimport prange +from ...utils._typedefs cimport intp_t + +import numpy as np + +from scipy.sparse import issparse +from sklearn import get_config +from ...utils import check_array, _in_unstable_openblas_configuration +from ...utils._openmp_helpers import _openmp_effective_n_threads +from ...utils.fixes import threadpool_limits, sp_version, parse_version + +cnp.import_array() + + +def _precompute_metric_params(X, Y, metric=None, **kwds): + """Precompute data-derived metric parameters if not provided.""" + if metric == "seuclidean" and "V" not in kwds: + # There is a bug in scipy < 1.5 that will cause a crash if + # X.dtype != np.double (float64). See PR #15730 + dtype = np.float64 if sp_version < parse_version("1.5") else None + if X is Y: + V = np.var(X, axis=0, ddof=1, dtype=dtype) + else: + raise ValueError( + "The 'V' parameter is required for the seuclidean metric " + "when Y is passed." + ) + return {"V": V} + if metric == "mahalanobis" and "VI" not in kwds: + if X is Y: + VI = np.linalg.inv(np.cov(X.T)).T + else: + raise ValueError( + "The 'VI' parameter is required for the mahalanobis metric " + "when Y is passed." + ) + return {"VI": VI} + return {} + +{{for name_suffix, INPUT_DTYPE in (('64', 'np.float64'),('32', 'np.float32'))}} + +from ._datasets_pair cimport DatasetsPair{{name_suffix}} + + +cdef class PairwiseDistances{{name_suffix}}: + """float{{name_suffix}} implementation of PairwiseDistances.""" + + @classmethod + def compute( + cls, + X, + Y, + str metric="euclidean", + dict metric_kwargs=None, + str strategy=None, + ): + """Compute the pairwise-distances matrix. + + This classmethod is responsible for introspecting the arguments + values to dispatch to the most appropriate implementation of + :class:`PairwiseDistances{{name_suffix}}`. + + This allows decoupling the API entirely from the implementation details + whilst maintaining RAII: all temporarily allocated datastructures necessary + for the concrete implementation are therefore freed when this classmethod + returns. + + No instance should directly be created outside of this class method. + """ + # Precompute data-derived distance metric parameters + metric_kwargs = {} if metric_kwargs is None else metric_kwargs + + params = _precompute_metric_params( + X, + Y, + metric=metric, + **metric_kwargs, + ) + metric_kwargs.update(**params) + + # Fall back on a generic implementation that handles most scipy + # metrics by computing the distances between 2 vectors at a time. + pdr = PairwiseDistances{{name_suffix}}( + datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs), + strategy=strategy, + ) + + # Limit the number of threads in second level of nested parallelism for BLAS + # to avoid threads over-subscription (in GEMM for instance). + with threadpool_limits(limits=1, user_api="blas"): + if pdr.execute_in_parallel_on_Y: + pdr._parallel_on_Y() + else: + pdr._parallel_on_X() + + return pdr._finalize_results() + + + def __init__( + self, + DatasetsPair{{name_suffix}} datasets_pair, + strategy=None, + sort_results=False, + ): + self.datasets_pair = datasets_pair + self.n_samples_X = datasets_pair.n_samples_X() + self.n_samples_Y = datasets_pair.n_samples_Y() + + self.effective_n_threads = _openmp_effective_n_threads() + + if strategy is None: + strategy = get_config().get("pairwise_dist_parallel_strategy", 'auto') + + if strategy not in ('parallel_on_X', 'parallel_on_Y', 'auto'): + raise RuntimeError(f"strategy must be 'parallel_on_X, 'parallel_on_Y', " + f"or 'auto', but currently strategy='{self.strategy}'.") + + if strategy == 'auto': + # TODO: inspect if the current heuristic is relevant + # for PairwiseDistances + # This is a simple heuristic whose constant for the + # comparison has been chosen based on experiments. + # parallel_on_X has less synchronization overhead than + # parallel_on_Y and should therefore be used whenever + # n_samples_X is large enough to not starve any of the + # available hardware threads. + if self.n_samples_Y < self.n_samples_X: + # No point to even consider parallelizing on Y in this case. This + # is in particular important to do this on machines with a large + # number of hardware threads. + strategy = 'parallel_on_X' + elif 4 * self.effective_n_threads < self.n_samples_X: + # If Y is larger than X, but X is still large enough to allow for + # parallelism, we might still want to favor parallelizing on X. + strategy = 'parallel_on_X' + else: + strategy = 'parallel_on_Y' + + self.execute_in_parallel_on_Y = strategy == "parallel_on_Y" + + # Distance matrix which will be complete and returned to the caller. + self.pairwise_distances_matrix = np.empty( + (self.n_samples_X, self.n_samples_Y), dtype={{INPUT_DTYPE}}, + ) + + cdef void _parallel_on_X(self) nogil: + + cdef: + intp_t n_X = self.n_samples_X + intp_t n_Y = self.n_samples_Y + intp_t i, j + + for i in prange(n_X, nogil=True, num_threads=self.effective_n_threads): + for j in range(n_Y): + self.pairwise_distances_matrix[i, j] = self.datasets_pair.dist(i, j) + + cdef void _parallel_on_Y(self) nogil: + + cdef: + intp_t n_X = self.n_samples_X + intp_t n_Y = self.n_samples_Y + intp_t i, j + + for i in range(n_X): + for j in prange(n_Y, nogil=True, num_threads=self.effective_n_threads): + self.pairwise_distances_matrix[i, j] = self.datasets_pair.dist(i, j) + + def _finalize_results(self): + # If X is Y, then catastrophic cancellation might have occurred for + # computations of terms on the diagonal which must equal zero. + # We enforce it by zeroing the diagonal. + distance_matrix = np.asarray(self.pairwise_distances_matrix) + if self.datasets_pair.X_is_Y: + np.fill_diagonal(distance_matrix, 0.) + + return distance_matrix + + +{{endfor}} diff --git a/sklearn/metrics/_pairwise_fast.pyx b/sklearn/metrics/_pairwise_fast.pyx index d5290d94679c9..e4074d2035557 100644 --- a/sklearn/metrics/_pairwise_fast.pyx +++ b/sklearn/metrics/_pairwise_fast.pyx @@ -6,10 +6,6 @@ cimport numpy as cnp from cython cimport floating -from cython.parallel cimport prange -from libc.math cimport fabs - -from ..utils._openmp_helpers import _openmp_effective_n_threads cnp.import_array() @@ -33,79 +29,3 @@ def _chi2_kernel_fast(floating[:, :] X, if nom != 0: res += denom * denom / nom result[i, j] = -res - - -def _sparse_manhattan( - const floating[::1] X_data, - const int[:] X_indices, - const int[:] X_indptr, - const floating[::1] Y_data, - const int[:] Y_indices, - const int[:] Y_indptr, - double[:, ::1] D, -): - """Pairwise L1 distances for CSR matrices. - - Usage: - >>> D = np.zeros(X.shape[0], Y.shape[0]) - >>> _sparse_manhattan(X.data, X.indices, X.indptr, - ... Y.data, Y.indices, Y.indptr, - ... D) - """ - cdef cnp.npy_intp px, py, i, j, ix, iy - cdef double d = 0.0 - - cdef int m = D.shape[0] - cdef int n = D.shape[1] - - cdef int X_indptr_end = 0 - cdef int Y_indptr_end = 0 - - cdef int num_threads = _openmp_effective_n_threads() - - # We scan the matrices row by row. - # Given row px in X and row py in Y, we find the positions (i and j - # respectively), in .indices where the indices for the two rows start. - # If the indices (ix and iy) are the same, the corresponding data values - # are processed and the cursors i and j are advanced. - # If not, the lowest index is considered. Its associated data value is - # processed and its cursor is advanced. - # We proceed like this until one of the cursors hits the end for its row. - # Then we process all remaining data values in the other row. - - # Below the avoidance of inplace operators is intentional. - # When prange is used, the inplace operator has a special meaning, i.e. it - # signals a "reduction" - - for px in prange(m, nogil=True, num_threads=num_threads): - X_indptr_end = X_indptr[px + 1] - for py in range(n): - Y_indptr_end = Y_indptr[py + 1] - i = X_indptr[px] - j = Y_indptr[py] - d = 0.0 - while i < X_indptr_end and j < Y_indptr_end: - ix = X_indices[i] - iy = Y_indices[j] - - if ix == iy: - d = d + fabs(X_data[i] - Y_data[j]) - i = i + 1 - j = j + 1 - elif ix < iy: - d = d + fabs(X_data[i]) - i = i + 1 - else: - d = d + fabs(Y_data[j]) - j = j + 1 - - if i == X_indptr_end: - while j < Y_indptr_end: - d = d + fabs(Y_data[j]) - j = j + 1 - else: - while i < X_indptr_end: - d = d + fabs(X_data[i]) - i = i + 1 - - D[px, py] = d diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 3cdd5dd69edf0..60088ca1fa456 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -41,8 +41,12 @@ from ..utils.fixes import parse_version, sp_base_version from ..utils.parallel import Parallel, delayed from ..utils.validation import _num_samples, check_non_negative -from ._pairwise_distances_reduction import ArgKmin -from ._pairwise_fast import _chi2_kernel_fast, _sparse_manhattan +from ._pairwise_distances_reduction import ( + ArgKmin, + PairwiseDistances, + _precompute_metric_params, +) +from ._pairwise_fast import _chi2_kernel_fast # Utility Functions @@ -347,6 +351,21 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared float32, norms needs to be recomputed on upcast chunks. TODO: use a float64 accumulator in row_norms to avoid the latter. """ + metric = "sqeuclidean" if squared else "euclidean" + if PairwiseDistances.is_usable_for(X, Y, metric): + metric_kwargs = {} + if X_norm_squared is not None: + metric_kwargs["X_norm_squared"] = np.ravel(X_norm_squared) + + if Y_norm_squared is not None: + metric_kwargs["Y_norm_squared"] = np.ravel(Y_norm_squared) + + return PairwiseDistances.compute(X, Y, metric, metric_kwargs=metric_kwargs) + + # XXX: the following code is still used for list-of-lists of numbers which + # aren't converted to numpy arrays in validation steps done in `check_array`. + # TODO: convert list-of-lists to numpy arrays in `check_array`. + # See: https://github.com/scikit-learn/scikit-learn/issues/24745 if X_norm_squared is not None: if X_norm_squared.dtype == np.float32: XX = None @@ -959,6 +978,15 @@ def haversine_distances(X, Y=None): array([[ 0. , 11099.54035582], [11099.54035582, 0. ]]) """ + + if PairwiseDistances.is_usable_for(X, Y, metric="haversine"): + return PairwiseDistances.compute(X, Y, metric="haversine") + + # XXX: the following code is still used for list-of-lists of numbers which + # aren't converted to numpy arrays in validation steps done in `check_array`. + # TODO: convert list-of-lists to numpy arrays in `check_array`. + # See: https://github.com/scikit-learn/scikit-learn/issues/24745 + from ..metrics import DistanceMetric return DistanceMetric.get_metric("haversine").pairwise(X, Y) @@ -1041,20 +1069,30 @@ def manhattan_distances(X, Y=None, *, sum_over_features="deprecated"): X, Y = check_pairwise_arrays(X, Y) - if issparse(X) or issparse(Y): - if not sum_over_features: - raise TypeError( - "sum_over_features=%r not supported for sparse matrices" - % sum_over_features - ) - + if issparse(X): X = csr_matrix(X, copy=False) + # This also sorts indices in-place. + X.sum_duplicates() + + if issparse(Y): Y = csr_matrix(Y, copy=False) - X.sum_duplicates() # this also sorts indices in-place + # This also sorts indices in-place. Y.sum_duplicates() - D = np.zeros((X.shape[0], Y.shape[0])) - _sparse_manhattan(X.data, X.indices, X.indptr, Y.data, Y.indices, Y.indptr, D) - return D + + if sum_over_features and PairwiseDistances.is_usable_for(X, Y, metric="manhattan"): + return PairwiseDistances.compute(X, Y, metric="manhattan") + + # XXX: the following code is still used for list-of-lists of numbers which + # aren't converted to numpy arrays in validation steps done in `check_array` + # and for supporting `sum_over_features` which we should probably remove. + # TODO: convert list-of-lists to numpy arrays in `check_array`. + # See: https://github.com/scikit-learn/scikit-learn/issues/24745 + # TODO: remove `sum_over_features`, see: + # https://github.com/scikit-learn/scikit-learn/issues/24597 + + if issparse(X) or issparse(Y): + if not sum_over_features: + raise TypeError("sum_over_features=False not supported for sparse matrices") if sum_over_features: return distance.cdist(X, Y, "cityblock") @@ -1820,29 +1858,6 @@ def _check_chunk_size(reduced, chunk_size): ) -def _precompute_metric_params(X, Y, metric=None, **kwds): - """Precompute data-derived metric parameters if not provided.""" - if metric == "seuclidean" and "V" not in kwds: - if X is Y: - V = np.var(X, axis=0, ddof=1) - else: - raise ValueError( - "The 'V' parameter is required for the seuclidean metric " - "when Y is passed." - ) - return {"V": V} - if metric == "mahalanobis" and "VI" not in kwds: - if X is Y: - VI = np.linalg.inv(np.cov(X.T)).T - else: - raise ValueError( - "The 'VI' parameter is required for the mahalanobis metric " - "when Y is passed." - ) - return {"VI": VI} - return {} - - def pairwise_distances_chunked( X, Y=None, @@ -2141,6 +2156,27 @@ def pairwise_distances( % (metric, _VALID_METRICS) ) + if PairwiseDistances.is_usable_for(X, Y, metric=metric, metric_kwargs=kwds): + # This is an adaptor for one "sqeuclidean" specification. + # For this backend, we can directly use "sqeuclidean". + if kwds.get("squared", False) and metric == "euclidean": + # TODO: use 'sqeuclidean' instead of 'euclidean' + # with EuclideanPairwiseDistances + metric = "sqeuclidean" + kwds = {} + + if issparse(X): + X = csr_matrix(X, copy=False) + # This also sorts indices in-place. + X.sum_duplicates() + + if issparse(Y): + Y = csr_matrix(Y, copy=False) + # This also sorts indices in-place. + Y.sum_duplicates() + + return PairwiseDistances.compute(X, Y, metric=metric, metric_kwargs=kwds) + if metric == "precomputed": X, _ = check_pairwise_arrays( X, Y, precomputed=True, force_all_finite=force_all_finite @@ -2159,6 +2195,11 @@ def pairwise_distances( _pairwise_callable, metric=metric, force_all_finite=force_all_finite, **kwds ) else: + # XXX: the following code is still used for list-of-lists of numbers which + # aren't converted to numpy arrays in validation steps done in `check_array`. + # TODO: convert list-of-lists to numpy arrays in `check_array`. + # See: https://github.com/scikit-learn/scikit-learn/issues/24745 + if issparse(X) or issparse(Y): raise TypeError("scipy distance metrics do not support sparse matrices.") diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index 1574d007bfdfb..2c9a08602309e 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -162,25 +162,12 @@ def test_pairwise_distances(global_dtype): S = pairwise_distances(X_sparse, Y_sparse.tocsc(), metric="manhattan") S2 = manhattan_distances(X_sparse.tobsr(), Y_sparse.tocoo()) assert_allclose(S, S2) - if global_dtype == np.float64: - assert S.dtype == S2.dtype == global_dtype - else: - # TODO Fix manhattan_distances to preserve dtype. - # currently pairwise_distances uses manhattan_distances but converts the result - # back to the input dtype - with pytest.raises(AssertionError): - assert S.dtype == S2.dtype == global_dtype + + assert S.dtype == S2.dtype == global_dtype S2 = manhattan_distances(X, Y) assert_allclose(S, S2) - if global_dtype == np.float64: - assert S.dtype == S2.dtype == global_dtype - else: - # TODO Fix manhattan_distances to preserve dtype. - # currently pairwise_distances uses manhattan_distances but converts the result - # back to the input dtype - with pytest.raises(AssertionError): - assert S.dtype == S2.dtype == global_dtype + assert S.dtype == S2.dtype == global_dtype # Test with scipy.spatial.distance metric, with a kwd kwds = {"p": 2.0} @@ -194,11 +181,11 @@ def test_pairwise_distances(global_dtype): S2 = pairwise_distances(X, metric=minkowski, **kwds) assert_allclose(S, S2) - # Test that scipy distance metrics throw an error if sparse matrix given - with pytest.raises(TypeError): - pairwise_distances(X_sparse, metric="minkowski") - with pytest.raises(TypeError): - pairwise_distances(X, Y_sparse, metric="minkowski") + # Test PairwiseDistance + kwds = {"p": 3.0} + S = pairwise_distances(X, metric="minkowski", **kwds) + S2 = pairwise_distances(X, metric=minkowski, **kwds) + assert_allclose(S, S2) # Test that a value error is raised if the metric is unknown with pytest.raises(ValueError): @@ -948,15 +935,10 @@ def test_euclidean_distances_upcast_sym(batch_size, x_array_constr): "dtype, eps, rtol", [ (np.float32, 1e-4, 1e-5), - pytest.param( - np.float64, - 1e-8, - 0.99, - marks=pytest.mark.xfail(reason="failing due to lack of precision"), - ), + (np.float64, 1e-8, 0.99), ], ) -@pytest.mark.parametrize("dim", [1, 1000000]) +@pytest.mark.parametrize("dim", [1, 100000]) def test_euclidean_distances_extreme_values(dtype, eps, rtol, dim): # check that euclidean distances is correct with float32 input thanks to # upcasting. On float64 there are still precision issues. @@ -966,7 +948,7 @@ def test_euclidean_distances_extreme_values(dtype, eps, rtol, dim): distances = euclidean_distances(X, Y) expected = cdist(X, Y) - assert_allclose(distances, expected, rtol=1e-5) + assert_allclose(distances, expected, rtol=1e-5, atol=6e-4) @pytest.mark.parametrize("squared", [True, False]) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 5fcf980fbe39b..0900f5d3122c0 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -2,6 +2,7 @@ import re import warnings from collections import defaultdict +from functools import partial from math import floor, log10 import numpy as np @@ -15,9 +16,11 @@ ArgKmin, ArgKminClassMode, BaseDistancesReductionDispatcher, + PairwiseDistances, RadiusNeighbors, sqeuclidean_row_norms, ) +from sklearn.utils._openmp_helpers import _openmp_effective_n_threads from sklearn.utils._testing import ( assert_allclose, assert_array_equal, @@ -307,6 +310,11 @@ def assert_radius_neighbors_results_quasi_equality( ): assert_radius_neighbors_results_quasi_equality, } +ASSERT_RESULT_PAIRWISE = { + np.float32: partial(assert_allclose, rtol=1e-4), + np.float64: assert_array_equal, +} + def test_assert_argkmin_results_quasi_equality(): rtol = 1e-7 @@ -851,6 +859,44 @@ def test_radius_neighbors_factory_method_wrong_usages(): ) +def test_pairwise_distances_factory_method_wrong_usages(): + rng = np.random.RandomState(1) + X = rng.rand(100, 10) + Y = rng.rand(100, 10) + metric = "euclidean" + + msg = ( + "Only float64 or float32 datasets pairs are supported, but " + "got: X.dtype=float32 and Y.dtype=float64" + ) + with pytest.raises( + ValueError, + match=msg, + ): + PairwiseDistances.compute(X=X.astype(np.float32), Y=Y, metric=metric) + + msg = ( + "Only float64 or float32 datasets pairs are supported, but " + "got: X.dtype=float64 and Y.dtype=int32" + ) + with pytest.raises( + ValueError, + match=msg, + ): + PairwiseDistances.compute(X=X, Y=Y.astype(np.int32), metric=metric) + + with pytest.raises(ValueError, match="Unrecognized metric"): + PairwiseDistances.compute(X=X, Y=Y, metric="wrong metric") + + with pytest.raises( + ValueError, match=r"Buffer has wrong number of dimensions \(expected 2, got 1\)" + ): + PairwiseDistances.compute(X=np.array([1.0, 2.0]), Y=Y, metric=metric) + + with pytest.raises(ValueError, match="ndarray is not C-contiguous"): + PairwiseDistances.compute(X=np.asfortranarray(X), Y=Y, metric=metric) + + @pytest.mark.parametrize( "n_samples_X, n_samples_Y", [(100, 100), (500, 100), (100, 500)] ) @@ -960,6 +1006,29 @@ def test_n_threads_agnosticism( ) +@pytest.mark.parametrize("dtype", [np.float64, np.float32]) +def test_n_threads_agnosticism_pairwise_distances( + global_random_seed, + dtype, + n_features=100, +): + """Check that results do not depend on the number of threads.""" + # TODO: Parametrize `n_samples_X` and `n_samples_Y` when the + # strategy heuristic has been inspected. + n_samples_X, n_samples_Y = 100, 100 + rng = np.random.RandomState(global_random_seed) + spread = 100 + X = rng.rand(n_samples_X, n_features).astype(dtype) * spread + Y = rng.rand(n_samples_Y, n_features).astype(dtype) * spread + + ref_dist = PairwiseDistances.compute(X, Y) + + with threadpoolctl.threadpool_limits(limits=1, user_api="openmp"): + dist = PairwiseDistances.compute(X, Y) + + ASSERT_RESULT_PAIRWISE[dtype](ref_dist, dist) + + @pytest.mark.parametrize( "Dispatcher, dtype", [ @@ -1025,6 +1094,31 @@ def test_format_agnosticism( ) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_format_agnosticism_pairwise_distances( + global_random_seed, + dtype, +): + """Check that results do not depend on the format (dense, sparse) of the input.""" + rng = np.random.RandomState(global_random_seed) + spread = 100 + n_samples, n_features = 100, 100 + + X = rng.rand(n_samples, n_features).astype(dtype) * spread + Y = rng.rand(n_samples, n_features).astype(dtype) * spread + + X_csr = csr_matrix(X) + Y_csr = csr_matrix(Y) + + dist_dense = PairwiseDistances.compute(X, Y) + + for _X, _Y in itertools.product((X, X_csr), (Y, Y_csr)): + if _X is X and _Y is Y: + continue + dist = PairwiseDistances.compute(_X, _Y) + ASSERT_RESULT_PAIRWISE[dtype](dist, dist_dense) + + @pytest.mark.parametrize( "n_samples_X, n_samples_Y", [(100, 100), (100, 500), (500, 100)] ) @@ -1102,6 +1196,58 @@ def test_strategies_consistency( ) +@pytest.mark.parametrize( + "metric", + ["euclidean", "minkowski", "manhattan", "infinity", "seuclidean", "haversine"], +) +@pytest.mark.parametrize("dtype", [np.float64, np.float32]) +def test_strategies_consistency_pairwise_distances( + global_random_seed, + metric, + dtype, + n_features=10, +): + """Check that the results do not depend on the strategy used.""" + # TODO: Parametrize `n_samples_X` and `n_samples_Y` when the + # strategy heuristic has been inspected. + n_samples_X, n_samples_Y = 100, 100 + rng = np.random.RandomState(global_random_seed) + spread = 100 + X = rng.rand(n_samples_X, n_features).astype(dtype) * spread + Y = rng.rand(n_samples_Y, n_features).astype(dtype) * spread + + # Haversine distance only accepts 2D data + if metric == "haversine": + X = np.ascontiguousarray(X[:, :2]) + Y = np.ascontiguousarray(Y[:, :2]) + + dist_par_X = PairwiseDistances.compute( + X, + Y, + metric=metric, + # Taking the first + metric_kwargs=_get_metric_params_list( + metric, n_features, seed=global_random_seed + )[0], + # To be sure to use parallelization + strategy="parallel_on_X", + ) + + dist_par_Y = PairwiseDistances.compute( + X, + Y, + metric=metric, + # Taking the first + metric_kwargs=_get_metric_params_list( + metric, n_features, seed=global_random_seed + )[0], + # To be sure to use parallelization + strategy="parallel_on_Y", + ) + + ASSERT_RESULT_PAIRWISE[dtype](dist_par_X, dist_par_Y) + + # "Concrete Dispatchers"-specific tests @@ -1294,6 +1440,37 @@ def test_memmap_backed_data( ) +@pytest.mark.parametrize("metric", ["manhattan", "euclidean"]) +@pytest.mark.parametrize("dtype", [np.float64, np.float32]) +def test_memmap_backed_data_pairwise_distances( + metric, + dtype, +): + """Check that the results do not depend on the datasets writability.""" + rng = np.random.RandomState(0) + spread = 100 + n_samples, n_features = 128, 10 + X = rng.rand(n_samples, n_features).astype(dtype) * spread + Y = rng.rand(n_samples, n_features).astype(dtype) * spread + + # Create read only datasets + X_mm, Y_mm = create_memmap_backed_data([X, Y]) + + ref_dist = PairwiseDistances.compute( + X, + Y, + metric=metric, + ) + + dist_mm = PairwiseDistances.compute( + X_mm, + Y_mm, + metric=metric, + ) + + ASSERT_RESULT_PAIRWISE[dtype](ref_dist, dist_mm) + + @pytest.mark.parametrize("n_samples", [100, 1000]) @pytest.mark.parametrize("n_features", [5, 10, 100]) @pytest.mark.parametrize("num_threads", [1, 2, 8]) @@ -1324,6 +1501,38 @@ def test_sqeuclidean_row_norms( sqeuclidean_row_norms(X, num_threads=num_threads) +@pytest.mark.parametrize("dtype", [np.float64, np.float32]) +def test_pairwise_distances_is_usable_for( + global_random_seed, + dtype, + monkeypatch, +): + rng = np.random.RandomState(global_random_seed) + n_samples, n_features = 100, 10 + X = rng.rand(n_samples, n_features).astype(dtype) + + # Equivalent specifications of the Euclidean metric. + # TODO: support Euclidean metric. + assert not PairwiseDistances.is_usable_for(X, X, metric="euclidean") + assert not PairwiseDistances.is_usable_for(X, X, metric="minkowski") + assert not PairwiseDistances.is_usable_for( + X, X, metric="minkowski", metric_kwargs={"p": 2} + ) + + # PairwiseDistances must not be used for sequential execution because + # They are not yet competitive with the previous joblib-based back-end. + # TODO: make PairwiseDistances competitive for sequential execution. + assert PairwiseDistances.is_usable_for( + X, X, metric="minkowski", metric_kwargs={"p": 3} + ) == (_openmp_effective_n_threads() != 1) + + with threadpoolctl.threadpool_limits(limits=1, user_api=None): + assert PairwiseDistances.is_usable_for(X, X, metric="manhattan") + + with threadpoolctl.threadpool_limits(limits=2, user_api=None): + assert PairwiseDistances.is_usable_for(X, X, metric="manhattan") + + def test_argkmin_classmode_strategy_consistent(): rng = np.random.RandomState(1) X = rng.rand(100, 10)