8000 MAINT Plug `PairwiseDistancesArgKmin` as a back-end (#22288) · scikit-learn/scikit-learn@5a4d710 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5a4d710

Browse files
authored
MAINT Plug PairwiseDistancesArgKmin as a back-end (#22288)
* Forward pairwise_dist_chunk_size in the configuration * Flip finalized results for PairwiseDistancesArgKmin The previous would have made the code more complex by introducing some boilerplate for the interface plugs. Having it this way actually simplifies the code. This also removes the haversine branch for test_pairwise_distances_argkmin Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> * Plug PairwiseDistancesArgKmin as a back-end * Adapt test accordingly * Add whats_new entry * Change input validation order for kneighbors * Remove duplicated test_neighbors_distance_metric_deprecation Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> * Adapt the documentation Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> * Add mahalanobis case to test fixtures * Correct whats_new entry * CLN Remove unneeded private metric attribute This was needed when 'fast_sqeuclidean' and 'fast_euclidean' were present to choose the best implementation based on the user specification. Those metric have been removed since then, making this attribute useless. Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> * TST Assert FutureWarning instead of DeprecationWarning in test_neighbors_metrics * MAINT Add use_pairwise_dist_activate to scikit-learn config * TST Add a test for the 'brute' backends' results' consistency Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> * fixup! MAINT Add use_pairwise_dist_activate to scikit-learn config * fixup! fixup! MAINT Add use_pairwise_dist_activate to scikit-learn config * TST Filter FutureWarning for WMinkowskiDistance * MAINT pin numpydoc in arm for now (#22292) * fixup! TST Filter FutureWarning for WMinkowskiDistance * Revert keywords arguments removal for the GEMM trick for 'euclidean' * MAINT pin max numpydoc for now (#22286) * Add 'haversine' to CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> * fixup! Add 'haversine' to CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS * Apply suggestions from code review Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> * MAINT Document some config parameters for maintenance Also rename one of them. Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> * FIX Support and test one of 'sqeuclidean' specification Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> * FIX Various typos fix and correct haversine 'haversine' is not supported by cdist. * Directly use get_config * CLN Apply comments from review Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com> Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com> * Motivate swapped returned values Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com> Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> * TST Remove mahalanobis from test fixtures Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com> * MNT Add comment regaduction functions' signatures Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> * TST Complete test for `pairwise_distance_{argmin,argmin_min}` (#22371) * DOC Add sub-pull requests to the whats_new entry
1 parent 998c712 commit 5a4d710

File tree

11 files changed

+305
-58
lines changed

11 files changed

+305
-58
lines changed

doc/whats_new/v1.1.rst

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,38 @@ Changelog
617617
left corner of the HTML representation to show how the elements are
618618
clickable. :pr:`21298` by `Thomas Fan`_.
619619

620+
Miscellaneous
621+
.............
622+
623+
- |Efficiency| Low-level routines for reductions on pairwise distances
624+
for dense float64 datasets have been refactored. The following functions
625+
and estimators now benefit from improved performances, in particular on
626+
multi-cores machines:
627+
- :func:`sklearn.metrics.pairwise_distances_argmin`
628+
- :func:`sklearn.metrics.pairwise_distances_argmin_min`
629+
- :class:`sklearn.cluster.AffinityPropagation`
630+
- :class:`sklearn.cluster.Birch`
631+
- :class:`sklearn.cluster.MeanShift`
632+
- :class:`sklearn.cluster.OPTICS`
633+
- :class:`sklearn.cluster.SpectralClustering`
634+
- :func:`sklearn.feature_selection.mutual_info_regression`
635+
- :class:`sklearn.neighbors.KNeighborsClassifier`
636+
- :class:`sklearn.neighbors.KNeighborsRegressor`
637+
- :class:`sklearn.neighbors.LocalOutlierFactor`
638+
- :class:`sklearn.neighbors.NearestNeighbors`
639+
- :class:`sklearn.manifold.Isomap`
640+
- :class:`sklearn.manifold.LocallyLinearEmbedding`
641+
- :class:`sklearn.manifold.TSNE`
642+
- :func:`sklearn.manifold.trustworthiness`
643+
- :class:`sklearn.semi_supervised.LabelPropagation`
644+
- :class:`sklearn.semi_supervised.LabelSpreading`
645+
646+
For instance :class:`sklearn.neighbors.NearestNeighbors.kneighbors`
647+
can be up to ×20 faster than in the previous versions'.
648+
649+
:pr:`21987`, :pr:`22064`, :pr:`22065` and :pr:`22288`
650+
by :user:`Julien Jerphanion <jjerphan>`
651+
620652
- |Fix| :func:`check_scalar` raises an error when `include_boundaries={"left", "right"}`
621653
and the boundaries are not set.
622654
:pr:`22027` by `Marie Lanternier <mlant>`.

sklearn/_config.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"pairwise_dist_chunk_size": int(
1313
os.environ.get("SKLEARN_PAIRWISE_DIST_CHUNK_SIZE", 256)
1414
),
15+
"enable_cython_pairwise_dist": True,
1516
}
1617
_threadlocal = threading.local()
1718

@@ -48,6 +49,7 @@ def set_config(
4849
print_changed_only=None,
4950
display=None,
5051
pairwise_dist_chunk_size=None,
52+
enable_cython_pairwise_dist=None,
5153
):
5254
"&quo 10000 t;"Set global scikit-learn configuration
5355
@@ -88,9 +90,23 @@ def set_config(
8890
.. versionadded:: 0.23
8991
9092
pairwise_dist_chunk_size : int, default=None
91-
The number of vectors per chunk for PairwiseDistancesReduction.
93+
The number of row vectors per chunk for PairwiseDistancesReduction.
9294
Default is 256 (suitable for most of modern laptops' caches and architectures).
9395
96+
Intended for easier benchmarking and testing of scikit-learn internals.
97+
End users are not expected to benefit from customizing this configuration
98+
setting.
99+
100+
.. versionadded:: 1.1
101+
102+
enable_cython_pairwise_dist : bool, default=None
103+
Use PairwiseDistancesReduction when possible.
104+
Default is True.
105+
106+
Intended for easier benchmarking and testing of scikit-learn internals.
107+
End users are not expected to benefit from customizing this configuration
108+
setting.
109+
94110
.. versionadded:: 1.1
95111
96112
See Also
@@ -110,6 +126,8 @@ def set_config(
110126
local_config["display"] = display
111127
if pairwise_dist_chunk_size is not None:
112128
local_config["pairwise_dist_chunk_size"] = pairwise_dist_chunk_size
129+
if enable_cython_pairwise_dist is not None:
130+
local_config["enable_cython_pairwise_dist"] = enable_cython_pairwise_dist
113131

114132

115133
@contextmanager
@@ -120,6 +138,7 @@ def config_context(
120138
print_changed_only=None,
121139
display=None,
122140
pairwise_dist_chunk_size=None,
141+
enable_cython_pairwise_dist=None,
123142
):
124143
"""Context manager for global scikit-learn configuration.
125144
@@ -162,6 +181,20 @@ def config_context(
162181
The number of vectors per chunk for PairwiseDistancesReduction.
163182
Default is 256 (suitable for most of modern laptops' caches and architectures).
164183
184+
Intended for easier benchmarking and testing of scikit-learn internals.
185+
End users are not expected to benefit from customizing this configuration
186+
setting.
187+
188+
.. versionadded:: 1.1
189+
190+
enable_cython_pairwise_dist : bool, default=None
191+
Use PairwiseDistancesReduction when possible.
192+
Default is True.
193+
194+
Intended for easier benchmarking and testing of scikit-learn internals.
195+
End users are not expected to benefit from customizing this configuration
196+
setting.
197+
165198
.. versionadded:: 1.1
166199
167200
Yields
@@ -197,6 +230,8 @@ def config_context(
197230
working_memory=working_memory,
198231
print_changed_only=print_changed_only,
199232
display=display,
233+
pairwise_dist_chunk_size=pairwise_dist_chunk_size,
234+
enable_cython_pairwise_dist=enable_cython_pairwise_dist,
200235
)
201236

202237
try:

sklearn/metrics/_dist_metrics.pyx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ METRIC_MAPPING = {'euclidean': EuclideanDistance,
7171
'pyfunc': PyFuncDistance}
7272

7373
BOOL_METRICS = [
74+
"hamming",
7475
"matching",
7576
"jaccard",
7677
"dice",

sklearn/metrics/_pairwise_distances_reduction.pyx

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,8 @@ cdef class PairwiseDistancesReduction:
211211
True if the PairwiseDistancesReduction can be used, else False.
212212
"""
213213
# TODO: support sparse arrays and 32 bits
214-
return (not issparse(X) and X.dtype == np.float64 and
214+
return (get_config().get("enable_cython_pairwise_dist", True) and
215+
not issparse(X) and X.dtype == np.float64 and
215216
not issparse(Y) and Y.dtype == np.float64 and
216217
metric in cls.valid_metrics())
217218

@@ -621,10 +622,10 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
621622
Indices of the argkmin for each vector in X.
622623
623624
If return_distance=True:
624-
- argkmin_indices : ndarray of shape (n_samples_X, k)
625-
Indices of the argkmin for each vector in X.
626625
- argkmin_distances : ndarray of shape (n_samples_X, k)
627626
Distances to the argkmin for each vector in X.
627+
- argkmin_indices : ndarray of shape (n_samples_X, k)
628+
Indices of the argkmin for each vector in X.
628629
629630
Notes
630631
-----
@@ -642,7 +643,7 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
642643
# Note (jjerphan): Some design thoughts for future extensions.
643644
# This factory comes to handle specialisations for the given arguments.
644645
# For future work, this might can be an entrypoint to specialise operations
645-
# for various back-end and/or hardware and/or datatypes, and/or fused
646+
# for various backend and/or hardware and/or datatypes, and/or fused
646647
# {sparse, dense}-datasetspair etc.
647648
if (
648649
metric in ("euclidean", "sqeuclidean")
@@ -883,7 +884,11 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
883884
# We need to recompute distances because we relied on
884885
# surrogate distances for the reduction.
885886
self.compute_exact_distances()
886-
return np.asarray(self.argkmin_indices), np.asarray(self.argkmin_distances)
887+
888+
# Values are returned identically to the way `KNeighborsMixin.kneighbors`
889+
# returns values. This is counter-intuitive but this allows not using
890+
# complex adaptations where `PairwiseDistancesArgKmin.compute` is called.
891+
return np.asarray(self.argkmin_distances), np.asarray(self.argkmin_indices)
887892

888893
return np.asarray(self.argkmin_indices)
889894

sklearn/metrics/pairwise.py

Lines changed: 90 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from scipy.sparse import issparse
2020
from joblib import Parallel, effective_n_jobs
2121

22+
from .. import config_context
2223
from ..utils.validation import _num_samples
2324
from ..utils.validation import check_non_negative
2425
from ..utils import check_array
@@ -31,6 +32,7 @@
3132
from ..utils.fixes import delayed
3233
from ..utils.fixes import sp_version, parse_version
3334

35+
from ._pairwise_distances_reduction import PairwiseDistancesArgKmin
3436
from ._pairwise_fast import _chi2_kernel_fast, _sparse_manhattan
3537
from ..exceptions import DataConversionWarning
3638

@@ -576,12 +578,23 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, 10000 YY=None, batch_size=None):
576578
return distances
577579

578580

581+
# start is specified in the signature of `_argmin_min_reduce`
582+
# and of `_argmin_reduce` but is not used.
583+
# This is because the higher order `pairwise_distances_chunked`
584+
# function needs reduction functions that are passed as argument
585+
# to have a two arguments signature.
586+
587+
579588
def _argmin_min_reduce(dist, start):
580589
indices = dist.argmin(axis=1)
581590
values = dist[np.arange(dist.shape[0]), indices]
582591
return indices, values
583592

584593

594+
def _argmin_reduce(dist, start):
595+
return dist.argmin(axis=1)
596+
597+
585598
def pairwise_distances_argmin_min(
586599
X, Y, *, axis=1, metric="euclidean", metric_kwargs=None
587600
):
@@ -654,19 +667,44 @@ def pairwise_distances_argmin_min(
654667
"""
655668
X, Y = check_pairwise_arrays(X, Y)
656669

657-
if metric_kwargs is None:
658-
metric_kwargs = {}
659-
660670
if axis == 0:
661671
X, Y = Y, X
662672

663-
indices, values = zip(
664-
*pairwise_distances_chunked(
665-
X, Y, reduce_func=_argmin_min_reduce, metric=metric, **metric_kwargs
673+
if metric_kwargs is None:
674+
metric_kwargs = {}
675+
676+
if PairwiseDistancesArgKmin.is_usable_for(X, Y, metric):
677+
# This is an adaptor for one "sqeuclidean" specification.
678+
# For this backend, we can directly use "sqeuclidean".
679+
if metric_kwargs.get("squared", False) and metric == "euclidean":
680+
metric = "sqeuclidean"
681+
metric_kwargs = {}
682+
683+
values, indices = PairwiseDistancesArgKmin.compute(
684+
X=X,
685+
Y=Y,
686+
k=1,
687+
metric=metric,
688+
metric_kwargs=metric_kwargs,
689+
strategy="auto",
690+
return_distance=True,
666691
)
667-
)
668-
indices = np.concatenate(indices)
669-
values = np.concatenate(values)
692+
values = values.flatten()
693+
indices = indices.flatten()
694+
else:
695+
# TODO: once PairwiseDistancesArgKmin supports sparse input matrices and 32 bit,
696+
# we won't need to fallback to pairwise_distances_chunked anymore.
697+
698+
# Turn off check for finiteness because this is costly and because arrays
699+
# have already been validated.
700+
with config_context(assume_finite=True):
701+
indices, values = zip(
702+
*pairwise_distances_chunked(
703+
X, Y, reduce_func=_argmin_min_reduce, metric=metric, **metric_kwargs
704+
)
705+
)
706+
indices = np.concatenate(indices)
707+
values = np.concatenate(values)
670708

671709
return indices, values
672710

@@ -738,9 +776,49 @@ def pairwise_distances_argmin(X, Y, *, axis=1, metric="euclidean", metric_kwargs
738776
if metric_kwargs is None:
739777
metric_kwargs = {}
740778

741-
return pairwise_distances_argmin_min(
742-
X, Y, axis=axis, metric=metric, metric_kwargs=metric_kwargs
743-
)[0]
779+
X, Y = check_pairwise_arrays(X, Y)
780+
781+
if axis == 0:
782+
X, Y = Y, X
783+
784+
if metric_kwargs is None:
785+
metric_kwargs = {}
786+
787+
if PairwiseDistancesArgKmin.is_usable_for(X, Y, metric):
788+
# This is an adaptor for one "sqeuclidean" specification.
789+
# For this backend, we can directly use "sqeuclidean".
790+
if metric_kwargs.get("squared", False) and metric == "euclidean":
791+
metric = "sqeuclidean"
792+
metric_kwargs = {}
793+
794+
indices = PairwiseDistancesArgKmin.compute(
795+
X=X,
796+
Y=Y,
797+
k=1,
798+
metric=metric,
799+
metric_kwargs=metric_kwargs,
800+
strategy="auto",
801+
return_distance=False,
802+
)
803+
indices = indices.flatten()
804+
else:
805+
# TODO: once PairwiseDistancesArgKmin supports sparse input matrices and 32 bit,
806+
# we won't need to fallback to pairwise_distances_chunked anymore.
807+
808+
# Turn off check for finiteness because this is costly and because arrays
809+
# have already been validated.
810+
with config_context(assume_finite=True):
811+
indices = np.concatenate(
812+
list(
813+
# This returns a np.ndarray generator whose arrays we need
814+
# to flatten into one.
815+
pairwise_distances_chunked(
816+
X, Y, reduce_func=_argmin_reduce, metric=metric, **metric_kwargs
817+
)
818+
)
819+
)
820+
821+
return indices
744822

745823

746824
def haversine_distances(X, Y=None):

sklearn/metrics/tests/test_dist_metrics.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import scipy.sparse as sp
1111
from scipy.spatial.distance import cdist
1212
from sklearn.metrics import DistanceMetric
13+
from sklearn.metrics._dist_metrics import BOOL_METRICS
1314
from sklearn.utils import check_random_state
1415
from sklearn.utils._testing import create_memmap_backed_data
1516
from sklearn.utils.fixes import sp_version, parse_version
@@ -38,17 +39,6 @@ def dist_func(x1, x2, p):
3839
V = rng.random_sample((d, d))
3940
VI = np.dot(V, V.T)
4041

41-
BOOL_METRICS = [
42-
"hamming",
43-
"matching",
44-
"jaccard",
45-
"dice",
46-
"kulsinski",
47-
"rogerstanimoto",
48-
"russellrao",
49-
"sokalmichener",
50-
"sokalsneath",
51-
]
5242

5343
METRICS_DEFAULT_PARAMS = [
5444
("euclidean", {}),

sklearn/metrics/tests/test_pairwise.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -432,10 +432,11 @@ def test_paired_distances_callable():
432432
paired_distances(X, Y)
433433

434434

435-
def test_pairwise_distances_argmin_min():
435+
@pytest.mark.parametrize("dtype", (np.float32, np.float64))
436+
def test_pairwise_distances_argmin_min(dtype):
436437
# Check pairwise minimum distances computation for any metric
437-
X = [[0], [1]]
438-
Y = [[-2], [3]]
438+
X = np.asarray([[0], [1]], dtype=dtype)
439+
Y = np.asarray([[-2], [3]], dtype=dtype)
439440

440441
Xsp = dok_matrix(X)
441442
Ysp = csr_matrix(Y, dtype=np.float32)
@@ -458,12 +459,23 @@ def test_pairwise_distances_argmin_min():
458459
assert type(idxsp) == np.ndarray
459460
assert type(valssp) == np.ndarray
460461

461-
# euclidean metric squared
462-
idx, vals = pairwise_distances_argmin_min(
462+
# Squared Euclidean metric
463+
idx, vals = pairwise_distances_argmin_min(X, Y, metric="sqeuclidean")
464+
idx2, vals2 = pairwise_distances_argmin_min(
463465
X, Y, metric="euclidean", metric_kwargs={"squared": True}
464466
)
465-
assert_array_almost_equal(idx, expected_idx)
467+
idx3 = pairwise_distances_argmin(X, Y, metric="sqeuclidean")
468+
idx4 = pairwise_distances_argmin(
469+
X, Y, metric="euclidean", metric_kwargs={"squared": True}
470+
)
471+
466472
assert_array_almost_equal(vals, expected_vals_sq)
473+
assert_array_almost_equal(vals2, expected_vals_sq)
474+
475+
assert_array_almost_equal(idx, expected_idx)
476+
assert_array_almost_equal(idx2, expected_idx)
477+
assert_array_almost_equal(idx3, expected_idx)
478+
assert_array_almost_equal(idx4, expected_idx)
467479

468480
# Non-euclidean scikit-learn metric
469481
idx, vals = pairwise_distances_argmin_min(X, Y, metric="manhattan")

0 commit comments

Comments
 (0)
0