8000 MAINT Plug `PairwiseDistancesArgKmin` as a back-end by jjerphan · Pull Request #22288 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

MAINT Plug PairwiseDistancesArgKmin as a back-end #22288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
e93ea0f
Forward pairwise_dist_chunk_size in the configuration
jjerphan Jan 14, 2022
e39bf7d
Flip finalized results for PairwiseDistancesArgKmin
jjerphan Jan 14, 2022
c762c40
Plug PairwiseDistancesArgKmin as a back-end
jjerphan Jan 14, 2022
64cde3b
Adapt test accordingly
jjerphan Jan 14, 2022
d392122
Add whats_new entry
jjerphan Jan 14, 2022
5e1d071
Change input validation order for kneighbors
ogrisel Jan 18, 2022
3fa89b6
Merge pull request #8 from ogrisel/kneighbors-input-validation
jjerphan Jan 18, 2022
0e8ebb5
Remove duplicated test_neighbors_distance_metric_deprecation
jjerphan Jan 17, 2022
f15f271
Adapt the documentation
jjerphan Jan 17, 2022
0f0e440
Add mahalanobis case to test fixtures
jjerphan Jan 18, 2022
3448b01
Correct whats_new entry
jjerphan Jan 19, 2022
afdaaa1
CLN Remove unneeded private metric attribute
jjerphan Jan 20, 2022
ba6c463
Merge branch 'pairwise-distances-argkmin' into pairwise-distances-arg…
jjerphan Jan 24, 2022
34566a7
TST Assert FutureWarning instead of DeprecationWarning in
jjerphan Jan 24, 2022
3524735
MAINT Add use_pairwise_dist_activate to scikit-learn config
jjerphan Jan 25, 2022
6b396b0
TST Add a test for the 'brute' backends' results' consistency
jjerphan Jan 25, 2022
aa1f86f
fixup! MAINT Add use_pairwise_dist_activate to scikit-learn config
jjerphan Jan 25, 2022
305d217
fixup! fixup! MAINT Add use_pairwise_dist_activate to scikit-learn co…
jjerphan Jan 25, 2022
eced316
TST Filter FutureWarning for WMinkowskiDistance
jjerphan Jan 25, 2022
b84259d
MAINT pin numpydoc in arm for now (#22292)
thomasjpfan Jan 25, 2022
d63713e
fixup! TST Filter FutureWarning for WMinkowskiDistance
jjerphan Jan 25, 2022
4ad3509
Revert keywords arguments removal for the GEMM trick for 'euclidean'
jjerphan Jan 25, 2022
948b04c
MAINT pin max numpydoc for now (#22286)
ogrisel Jan 24, 2022
ca45a58
Add 'haversine' to CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS
jjerphan Jan 26, 2022
ee3c43d
fixup! Add 'haversine' to CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_M…
jjerphan Jan 26, 2022
c8de77c
Apply suggestions from code review
jjerphan Jan 27, 2022
2feec54
MAINT Document some config parameters for maintenance
jjerphan Jan 27, 2022
7a4c137
FIX Support and test one of 'sqeuclidean' specification
jjerphan Jan 27, 2022
ea762b7
FIX Various typos fix and correct haversine
jjerphan Jan 27, 2022
b9cb0f4
Directly use get_config
jjerphan Jan 27, 2022
e793927
Merge branch 'main' into pairwise-distances-argkmin-plug-contd
jjerphan Jan 28, 2022
3b5e738
Merge branch 'pairwise-distances-argkmin' into pairwise-distances-arg…
jjerphan Jan 28, 2022
2df70b1
CLN Apply comments from review
jjerphan Feb 1, 2022
b6e6f3d
Motivate swapped returned values
jjerphan Feb 1, 2022
16d777f
TST Remove mahalanobis from test fixtures
jjerphan Feb 2, 2022
bd02da0
MNT Add comment regaduction functions' signatures
jjerphan Feb 2, 2022
5ea0427
TST Complete test for `pairwise_distance_{argmin,argmin_min}` (#22371)
jjerphan Feb 8, 2022
5f5a83f
DOC Add sub-pull requests to the whats_new entry
jjerphan Feb 9, 2022
b8e9fb8
Merge branch 'pairwise-distances-argkmin' into pairwise-distances-arg…
jjerphan Feb 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,38 @@ Changelog
left corner of the HTML representation to show how the elements are
clickable. :pr:`21298` by `Thomas Fan`_.

Miscellaneous
.............

- |Efficiency| Low-level routines for reductions on pairwise distances
for dense float64 datasets have been refactored. The following functions
and estimators now benefit from improved performances, in particular on
multi-cores machines:
- :func:`sklearn.metrics.pairwise_distances_argmin`
- :func:`sklearn.metrics.pairwise_distances_argmin_min`
- :class:`sklearn.cluster.AffinityPropagation`
- :class:`sklearn.cluster.Birch`
- :class:`sklearn.cluster.MeanShift`
- :class:`sklearn.cluster.OPTICS`
- :class:`sklearn.cluster.SpectralClustering`
- :func:`sklearn.feature_selection.mutual_info_regression`
- :class:`sklearn.neighbors.KNeighborsClassifier`
- :class:`sklearn.neighbors.KNeighborsRegressor`
- :class:`sklearn.neighbors.LocalOutlierFactor`
- :class:`sklearn.neighbors.NearestNeighbors`
- :class:`sklearn.manifold.Isomap`
- :class:`sklearn.manifold.LocallyLinearEmbedding`
- :class:`sklearn.manifold.TSNE`
- :func:`sklearn.manifold.trustworthiness`
- :class:`sklearn.semi_supervised.LabelPropagation`
- :class:`sklearn.semi_supervised.LabelSpreading`

For instance :class:`sklearn.neighbors.NearestNeighbors.kneighbors`
can be up to ×20 faster than in the previous versions'.

:pr:`21987`, :pr:`22064`, :pr:`22065` and :pr:`22288`
by :user:`Julien Jerphanion <jjerphan>`

- |Fix| :func:`check_scalar` raises an error when `include_boundaries={"left", "right"}`
and the boundaries are not set.
:pr:`22027` by `Marie Lanternier <mlant>`.
Expand Down
37 changes: 36 additions & 1 deletion sklearn/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"pairwise_dist_chunk_size": int(
os.environ.get("SKLEARN_PAIRWISE_DIST_CHUNK_SIZE", 256)
),
"enable_cython_pairwise_dist": True,
}
_threadlocal = threading.local()

Expand Down Expand Up @@ -48,6 +49,7 @@ def set_config(
print_changed_only=None,
display=None,
pairwise_dist_chunk_size=None,
enable_cython_pairwise_dist=None,
):
"""Set global scikit-learn configuration

Expand Down Expand Up @@ -88,9 +90,23 @@ def set_config(
.. versionadded:: 0.23

pairwise_dist_chunk_size : int, default=None
The number of vectors per chunk for PairwiseDistancesReduction.
The number of row vectors per chunk for PairwiseDistancesReduction.
Default is 256 (suitable for most of modern laptops' caches and architectures).

Intended for easier benchmarking and testing of scikit-learn internals.
End users are not expected to benefit from customizing this configuration
setting.

.. versionadded:: 1.1

enable_cython_pairwise_dist : bool, default=None
Use PairwiseDistancesReduction when possible.
Default is True.

Intended for easier benchmarking and testing of scikit-learn internals.
End users are not expected to benefit from customizing this configuration
setting.

.. versionadded:: 1.1

See Also
Expand All @@ -110,6 +126,8 @@ def set_config(
local_config["display"] = display
if pairwise_dist_chunk_size is not None:
local_config["pairwise_dist_chunk_size"] = pairwise_dist_chunk_size
if enable_cython_pairwise_dist is not None:
local_config["enable_cython_pairwise_dist"] = enable_cython_pairwise_dist


@contextmanager
Expand All @@ -120,6 +138,7 @@ def config_context(
print_changed_only=None,
display=None,
pairwise_dist_chunk_size=None,
enable_cython_pairwise_dist=None,
):
"""Context manager for global scikit-learn configuration.

Expand Down Expand Up @@ -162,6 +181,20 @@ def config_context(
The number of vectors per chunk for PairwiseDistancesReduction.
Default is 256 (suitable for most of modern laptops' caches and architectures).

Intended for easier benchmarking and testing of scikit-learn internals.
End users are not expected to benefit from customizing this configuration
setting.

.. versionadded:: 1.1

enable_cython_pairwise_dist : bool, default=None
Use PairwiseDistancesReduction when possible.
Default is True.

Intended for easier benchmarking and testing of scikit-learn internals.
End users are not expected to benefit from customizing this configuration
setting.

.. versionadded:: 1.1

Yields
Expand Down Expand Up @@ -197,6 +230,8 @@ def config_context(
working_memory=working_memory,
print_changed_only=print_changed_only,
display=display,
pairwise_dist_chunk_size=pairwise_dist_chunk_size,
enable_cython_pairwise_dist=enable_cython_pairwise_dist,
)

try:
Expand Down
1 change: 1 addition & 0 deletions sklearn/metrics/_dist_metrics.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ METRIC_MAPPING = {'euclidean': EuclideanDistance,
'pyfunc': PyFuncDistance}

BOOL_METRICS = [
"hamming",
"matching",
"jaccard",
"dice",
Expand Down
15 changes: 10 additions & 5 deletions sklearn/metrics/_pairwise_distances_reduction.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ cdef class PairwiseDistancesReduction:
True if the PairwiseDistancesReduction can be used, else False.
"""
# TODO: support sparse arrays and 32 bits
return (not issparse(X) and X.dtype == np.float64 and
return (get_config().get("enable_cython_pairwise_dist", True) and
not issparse(X) and X.dtype == np.float64 and
not issparse(Y) and Y.dtype == np.float64 and
metric in cls.valid_metrics())

Expand Down Expand Up @@ -621,10 +622,10 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
Indices of the argkmin for each vector in X.

If return_distance=True:
- argkmin_indices : ndarray of shape (n_samples_X, k)
Indices of the argkmin for each vector in X.
- argkmin_distances : ndarray of shape (n_samples_X, k)
Distances to the argkmin for each vector in X.
- argkmin_indices : ndarray of shape (n_samples_X, k)
Indices of the argkmin for each vector in X.

Notes
-----
Expand All @@ -642,7 +643,7 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
# Note (jjerphan): Some design thoughts for future extensions.
# This factory comes to handle specialisations for the given arguments.
# For future work, this might can be an entrypoint to specialise operations
# for various back-end and/or hardware and/or datatypes, and/or fused
# for various backend and/or hardware and/or datatypes, and/or fused
# {sparse, dense}-datasetspair etc.
if (
metric in ("euclidean", "sqeuclidean")
Expand Down Expand Up @@ -883,7 +884,11 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
# We need to recompute distances because we relied on
# surrogate distances for the reduction.
self.compute_exact_distances()
return np.asarray(self.argkmin_indices), np.asarray(self.argkmin_distances)

# Values are returned identically to the way `KNeighborsMixin.kneighbors`
# returns values. This is counter-intuitive but this allows not using
# complex adaptations where `PairwiseDistancesArgKmin.compute` is called.
return np.asarray(self.argkmin_distances), np.asarray(self.argkmin_indices)

return np.asarray(self.argkmin_indices)

Expand Down
102 changes: 90 additions & 12 deletions sklearn/metrics/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from scipy.sparse import issparse
from joblib import Parallel, effective_n_jobs

from .. import config_context
from ..utils.validation import _num_samples
from ..utils.validation import check_non_negative
from ..utils import check_array
Expand All @@ -31,6 +32,7 @@
from ..utils.fixes import delayed
from ..utils.fixes import sp_version, parse_version

from ._pairwise_distances_reduction import PairwiseDistancesArgKmin
from ._pairwise_fast import _chi2_kernel_fast, _sparse_manhattan
from ..exceptions import DataConversionWarning

Expand Down Expand Up @@ -576,12 +578,23 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None):
return distances


# start is specified in the signature of `_argmin_min_reduce`
# and of `_argmin_reduce` but is not used.
# This is because the higher order `pairwise_distances_chunked`
# function needs reduction functions that are passed as argument
# to have a two arguments signature.


def _argmin_min_reduce(dist, start):
indices = dist.argmin(axis=1)
values = dist[np.arange(dist.shape[0]), indices]
return indices, values


def _argmin_reduce(dist, start):
return dist.argmin(axis=1)


def pairwise_distances_argmin_min(
X, Y, *, axis=1, metric="euclidean", metric_kwargs=None
):
Expand Down Expand Up @@ -654,19 +667,44 @@ def pairwise_distances_argmin_min(
"""
X, Y = check_pairwise_arrays(X, Y)

if metric_kwargs is None:
metric_kwargs = {}

if axis == 0:
X, Y = Y, X

indices, values = zip(
*pairwise_distances_chunked(
X, Y, reduce_func=_argmin_min_reduce, metric=metric, **metric_kwargs
if metric_kwargs is None:
metric_kwargs = {}

if PairwiseDistancesArgKmin.is_usable_for(X, Y, metric):
# This is an adaptor for one "sqeuclidean" specification.
# For this backend, we can directly use "sqeuclidean".
if metric_kwargs.get("squared", False) and metric == "euclidean":
metric = "sqeuclidean"
metric_kwargs = {}

values, indices = PairwiseDistancesArgKmin.compute(
X=X,
Y=Y,
k=1,
metric=metric,
metric_kwargs=metric_kwargs,
strategy="auto",
return_distance=True,
)
)
indices = np.concatenate(indices)
values = np.concatenate(values)
values = values.flatten()
indices = indices.flatten()
else:
# TODO: once PairwiseDistancesArgKmin supports sparse input matrices and 32 bit,
# we won't need to fallback to pairwise_distances_chunked anymore.

# Turn off check for finiteness because this is costly and because arrays
# have already been validated.
with config_context(assume_finite=True):
indices, values = zip(
*pairwise_distances_chunked(
X, Y, reduce_func=_argmin_min_reduce, metric=metric, **metric_kwargs
)
)
indices = np.concatenate(indices)
values = np.concatenate(values)

return indices, values

Expand Down Expand Up @@ -738,9 +776,49 @@ def pairwise_distances_argmin(X, Y, *, axis=1, metric="euclidean", metric_kwargs
if metric_kwargs is None:
metric_kwargs = {}

return pairwise_distances_argmin_min(
X, Y, axis=axis, metric=metric, metric_kwargs=metric_kwargs
)[0]
X, Y = check_pairwise_arrays(X, Y)

if axis == 0:
X, Y = Y, X

if metric_kwargs is None:
metric_kwargs = {}

if PairwiseDistancesArgKmin.is_usable_for(X, Y, metric):
# This is an adaptor for one "sqeuclidean" specification.
# For this backend, we can directly use "sqeuclidean".
if metric_kwargs.get("squared", False) and metric == "euclidean":
metric = "sqeuclidean"
metric_kwargs = {}

indices = PairwiseDistancesArgKmin.compute(
X=X,
Y=Y,
k=1,
metric=metric,
metric_kwargs=metric_kwargs,
strategy="auto",
return_distance=False,
)
indices = indices.flatten()
else:
# TODO: once PairwiseDistancesArgKmin supports sparse input matrices and 32 bit,
# we won't need to fallback to pairwise_distances_chunked anymore.

# Turn off check for finiteness because this is costly and because arrays
# have already been validated.
with config_context(assume_finite=True):
indices = np.concatenate(
list(
# This returns a np.ndarray generator whose arrays we need
# to flatten into one.
pairwise_distances_chunked(
X, Y, reduce_func=_argmin_reduce, metric=metric, **metric_kwargs
)
)
)

return indices


def haversine_distances(X, Y=None):
Expand Down
12 changes: 1 addition & 11 deletions sklearn/metrics/tests/test_dist_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import scipy.sparse as sp
from scipy.spatial.distance import cdist
from sklearn.metrics import DistanceMetric
from sklearn.metrics._dist_metrics import BOOL_METRICS
from sklearn.utils import check_random_state
from sklearn.utils._testing import create_memmap_backed_data
from sklearn.utils.fixes import sp_version, parse_version
Expand Down Expand Up @@ -38,17 +39,6 @@ def dist_func(x1, x2, p):
V = rng.random_sample((d, d))
VI = np.dot(V, V.T)

BOOL_METRICS = [
"hamming",
"matching",
"jaccard",
"dice",
"kulsinski",
"rogerstanimoto",
"russellrao",
"sokalmichener",
"sokalsneath",
]

METRICS_DEFAULT_PARAMS = [
("euclidean", {}),
Expand Down
24 changes: 18 additions & 6 deletions sklearn/metrics/tests/test_pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,10 +432,11 @@ def test_paired_distances_callable():
paired_distances(X, Y)


def test_pairwise_distances_argmin_min():
@pytest.mark.parametrize("dtype", (np.float32, np.float64))
def test_pairwise_distances_argmin_min(dtype):
# Check pairwise minimum distances computation for any metric
X = [[0], [1]]
Y = [[-2], [3]]
X = np.asarray([[0], [1]], dtype=dtype)
Y = np.asarray([[-2], [3]], dtype=dtype)

Xsp = dok_matrix(X)
Ysp = csr_matrix(Y, dtype=np.float32)
Expand All @@ -458,12 +459,23 @@ def test_pairwise_distances_argmin_min():
assert type(idxsp) == np.ndarray
assert type(valssp) == np.ndarray

# euclidean metric squared
idx, vals = pairwise_distances_argmin_min(
# Squared Euclidean metric
idx, vals = pairwise_distances_argmin_min(X, Y, metric="sqeuclidean")
idx2, vals2 = pairwise_distances_argmin_min(
X, Y, metric="euclidean", metric_kwargs={"squared": True}
)
1241 assert_array_almost_equal(idx, expected_idx)
idx3 = pairwise_distances_argmin(X, Y, metric="sqeuclidean")
idx4 = pairwise_distances_argmin(
X, Y, metric="euclidean", metric_kwargs={"squared": True}
)

assert_array_almost_equal(vals, expected_vals_sq)
assert_array_almost_equal(vals2, expected_vals_sq)

assert_array_almost_equal(idx, expected_idx)
assert_array_almost_equal(idx2, expected_idx)
assert_array_almost_equal(idx3, expected_idx)
assert_array_almost_equal(idx4, expected_idx)

# Non-euclidean scikit-learn metric
idx, vals = pairwise_distances_argmin_min(X, Y, metric="manhattan")
Expand Down
Loading
0