8000 FIX pairwise_distances_argmin_min wrong with metric="euclidean" (#12481) · lithuak/scikit-learn@c2f17d0 · GitHub
[go: up one dir, main page]

Skip to content

Commit c2f17d0

Browse files
jeremiedbbqinhanmin2014
authored andcommitted
FIX pairwise_distances_argmin_min wrong with metric="euclidean" (scikit-learn#12481)
1 parent 6555631 commit c2f17d0

File tree

3 files changed

+45
-32
lines changed

3 files changed

+45
-32
lines changed

doc/whats_new/v0.20.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ Changelog
103103
:class:`decomposition.IncrementalPCA` when using float32 datasets.
104104
:issue:`12338` by :user:`bauks <bauks>`.
105105

106+
:mod:`sklearn.metrics`
107+
......................
108+
109+
- |Fix| Fixed a bug in :func:`pairwise.pairwise_distances_argmin_min` which
110+
returned the square root of the distance when the metric parameter was set to
111+
"euclidean". :issue:`12481` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
112+
106113
Miscellaneous
107114
.............
108115

sklearn/metrics/pairwise.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,6 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
358358
indices = np.concatenate(indices)
359359
values = np.concatenate(values)
360360

361-
if metric == "euclidean" and not metric_kwargs.get("squared", False):
362-
np.sqrt(values, values)
363361
return indices, values
364362

365363

sklearn/metrics/tests/test_pairwise.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -343,49 +343,57 @@ def test_paired_distances_callable():
343343
def test_pairwise_distances_argmin_min():
344344
# Check pairwise minimum distances computation for any metric
345345
X = [[0], [1]]
346-
Y = [[-1], [2]]
346+
Y = [[-2], [3]]
347347

348348
Xsp = dok_matrix(X)
349349
Ysp = csr_matrix(Y, dtype=np.float32)
350350

351-
# euclidean metric
352-
D, E = pairwise_distances_argmin_min(X, Y, metric="euclidean")
353-
D2 = pairwise_distances_argmin(X, Y, metric="euclidean")
354-
assert_array_almost_equal(D, [0, 1])
355-
assert_array_almost_equal(D2, [0, 1])
356-
assert_array_almost_equal(D, [0, 1])
357-
assert_array_almost_equal(E, [1., 1.])
351+
expected_idx = [0, 1]
352+
expected_vals = [2, 2]
353+
expected_vals_sq = [4, 4]
358354

355+
# euclidean metric
356+
idx, vals = pairwise_distances_argmin_min(X, Y, metric="euclidean")
357+
idx2 = pairwise_distances_argmin(X, Y, metric="euclidean")
358+
assert_array_almost_equal(idx, expected_idx)
359+
assert_array_almost_equal(idx2, expected_idx)
360+
assert_array_almost_equal(vals, expected_vals)
359361
# sparse matrix case
360-
Dsp, Esp = pairwise_distances_argmin_min(Xsp, Ysp, metric="euclidean")
361-
assert_array_equal(Dsp, D)
362-
assert_array_equal(Esp, E)
362+
idxsp, valssp = pairwise_distances_argmin_min(Xsp, Ysp, metric="euclidean")
363+
assert_array_almost_equal(idxsp, expected_idx)
364+
assert_array_almost_equal(valssp, expected_vals)
363365
# We don't want np.matrix here
364-
assert_equal(type(Dsp), np.ndarray)
365-
assert_equal(type(Esp), np.ndarray)
366+
assert_equal(type(idxsp), np.ndarray)
367+
assert_equal(type(valssp), np.ndarray)
368+
369+
# euclidean metric squared
370+
idx, vals = pairwise_distances_argmin_min(X, Y, metric="euclidean",
371+
metric_kwargs={"squared": True})
372+
assert_array_almost_equal(idx, expected_idx)
373+
assert_array_almost_equal(vals, expected_vals_sq)
366374

367375
# Non-euclidean scikit-learn metric
368-
D, E = pairwise_distances_argmin_min(X, Y, metric="manhattan")
369-
D2 = pairwise_distances_argmin(X, Y, metric="manhattan")
370-
assert_array_almost_equal(D, [0, 1])
371-
assert_array_almost_equal(D2, [0, 1])
372-
assert_array_almost_equal(E, [1., 1.])
373-
D, E = pairwise_distances_argmin_min(Xsp, Ysp, metric="manhattan")
374-
D2 = pairwise_distances_argmin(Xsp, Ysp, metric="manhattan")
375-
assert_array_almost_equal(D, [0, 1])
376-
assert_array_almost_equal(E, [1., 1.])
376+
idx, vals = pairwise_distances_argmin_min(X, Y, metric="manhattan")
377+
idx2 = pairwise_distances_argmin(X, Y, metric="manhattan")
378+
assert_array_almost_equal(idx, expected_idx)
379+
assert_array_almost_equal(idx2, expected_idx)
380+
assert_array_almost_equal(vals, expected_vals)
381+
# sparse matrix case
382+
idxsp, valssp = pairwise_distances_argmin_min(Xsp, Ysp, metric="manhattan")
383+
assert_array_almost_equal(idxsp, expected_idx)
384+
assert_array_almost_equal(valssp, expected_vals)
377385

378386
# Non-euclidean Scipy distance (callable)
379-
D, E = pairwise_distances_argmin_min(X, Y, metric=minkowski,
380-
metric_kwargs={"p": 2})
381-
assert_array_almost_equal(D, [0, 1])
382-
assert_array_almost_equal(E, [1., 1.])
387+
idx, vals = pairwise_distances_argmin_min(X, Y, metric=minkowski,
388+
metric_kwargs={"p": 2})
389+
assert_array_almost_equal(idx, expected_idx)
390+
assert_array_almost_equal(vals, expected_vals)
383391

384392
# Non-euclidean Scipy distance (string)
385-
D, E = pairwise_distances_argmin_min(X, Y, metric="minkowski",
386-
metric_kwargs={"p": 2})
387-
assert_array_almost_equal(D, [0, 1])
388-
assert_array_almost_equal(E, [1., 1.])
393+
idx, vals = pairwise_distances_argmin_min(X, Y, metric="minkowski",
394+
metric_kwargs={"p": 2})
395+
assert_array_almost_equal(idx, expected_idx)
396+
assert_array_almost_equal(vals, expected_vals)
389397

390398
# Compare with naive implementation
391399
rng = np.random.RandomState(0)

0 commit comments

Comments
 (0)
0