[MRG] set blockwise diagonals to zero for euclidean distance (#12612) · scikit-learn/scikit-learn@0b8650a · GitHub
[go: up one dir, main page]

Skip to content

Commit 0b8650a

Browse files
amueller authored and ogrisel committed
[MRG] set blockwise diagonals to zero for euclidean distance (#12612)
1 parent ac327c5 commit 0b8650a

File tree

3 files changed

+23
-1
lines changed

3 files changed

+23
-1
lines changed

doc/whats_new/v0.20.rst

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -120,6 +120,10 @@ Changelog
120120
set to "euclidean". :issue:`12481` by
121121
:user:`Jérémie du Boisberranger <jeremiedbb>`.
122122

123+
- |Fix| Fixed a bug in :func:`metrics.pairwise.pairwise_distances_chunked`
124+
which didn't ensure the diagonal is zero for euclidean distances.
125+
:issue:`12612` by :user:`Andreas Müller <amueller>`.
126+
123127
- |API| The :func:`metrics.calinski_harabaz_score` has been renamed to
124128
:func:`metrics.calinski_harabasz_score` and will be removed in version 0.23.
125129
:issue:`12211` by :user:`Lisa Thomas <LisaThomas9>`,

sklearn/metrics/pairwise.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1271,6 +1271,12 @@ def pairwise_distances_chunked(X, Y=None, reduce_func=None,
12711271
X_chunk = X[sl]
12721272
D_chunk = pairwise_distances(X_chunk, Y, metric=metric,
12731273
n_jobs=n_jobs, **kwds)
1274+
if ((X is Y or Y is None)
1275+
and PAIRWISE_DISTANCE_FUNCTIONS.get(metric, None)
1276+
is euclidean_distances):
1277+
# zeroing diagonal, taking care of aliases of "euclidean",
1278+
# i.e. "l2"
1279+
D_chunk.flat[sl.start::_num_samples(X) + 1] = 0
12741280
if reduce_func is not None:
12751281
chunk_size = D_chunk.shape[0]
12761282
D_chunk = reduce_func(D_chunk, sl.start)

sklearn/metrics/tests/test_pairwise.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -473,7 +473,7 @@ def check_pairwise_distances_chunked(X, Y, working_memory, metric='euclidean'):
473473
metric=metric)
474474
assert isinstance(gen, GeneratorType)
475475
blockwise_distances = list(gen)
476-
Y = np.array(X if Y is None else Y)
476+
Y = X if Y is None else Y
477477
min_block_mib = len(Y) * 8 * 2 ** -20
478478

479479
for block in blockwise_distances:
@@ -485,6 +485,18 @@ def check_pairwise_distances_chunked(X, Y, working_memory, metric='euclidean'):
485485
assert_array_almost_equal(blockwise_distances, S)
486486

487487

488+
@pytest.mark.parametrize(
489+
'metric',
490+
('euclidean', 'l2', 'sqeuclidean'))
491+
def test_pairwise_distances_chunked_diagonal(metric):
492+
rng = np.random.RandomState(0)
493+
X = rng.normal(size=(1000, 10), scale=1e10)
494+
chunks = list(pairwise_distances_chunked(X, working_memory=1,
495+
metric=metric))
496+
assert len(chunks) > 1
497+
assert_array_almost_equal(np.diag(np.vstack(chunks)), 0, decimal=10)
498+
499+
488500
@ignore_warnings
489501
def test_pairwise_distances_chunked():
490502
# Test the pairwise_distance helper function.

0 commit comments

Comments
 (0)
0