8000 support X_norm_squared in euclidean_distances · scikit-learn/scikit-learn@8d8b434 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8d8b434

Browse files
committed
support X_norm_squared in euclidean_distances
1 parent 0bf7536 commit 8d8b434

File tree

2 files changed

+46
-12
lines changed

2 files changed

+46
-12
lines changed

sklearn/metrics/pairwise.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ def check_paired_arrays(X, Y):
145145

146146

147147
# Pairwise distances
148-
def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
148+
def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False,
149+
X_norm_squared=None):
149150
"""
150151
Considering the rows of X (and Y=X) as vectors, compute the
151152
distance matrix between each pair of vectors.
@@ -157,8 +158,8 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
157158
158159
This formulation has two advantages over other ways of computing distances.
159160
First, it is computationally efficient when dealing with sparse data.
160-
Second, if x varies but y remains unchanged, then the right-most dot
161-
product `dot(y, y)` can be pre-computed.
161+
Second, if one argument varies but the other remains unchanged, then
162+
`dot(x, x)` and/or `dot(y, y)` can be pre-computed.
162163
163164
However, this is not the most precise way of doing this computation, and
164165
the distance matrix returned by this function may not be exactly
@@ -179,6 +180,10 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
179180
squared : boolean, optional
180181
Return squared Euclidean distances.
181182
183+
X_norm_squared : array-like, shape = [n_samples_1], optional
184+
Pre-computed dot-products of vectors in X (e.g.,
185+
``(X**2).sum(axis=1)``)
186+
182187
Returns
183188
-------
184189
distances : {array, sparse matrix}, shape (n_samples_1, n_samples_2)
@@ -200,24 +205,28 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
200205
--------
201206
paired_distances : distances betweens pairs of elements of X and Y.
202207
"""
203-
# should not need X_norm_squared because if you could precompute that as
204-
# well as Y, then you should just pre-compute the output and not even
205-
# call this function.
206208
X, Y = check_pairwise_arrays(X, Y)
207209

208-
if Y_norm_squared is not None:
210+
if X_norm_squared is not None:
211+
XX = check_array(X_norm_squared)
212+
if XX.shape == (1, X.shape[0]):
213+
XX = XX.T
214+
elif XX.shape != (X.shape[0], 1):
215+
raise ValueError(
216+
"Incompatible dimensions for X and X_norm_squared")
217+
else:
218+
XX = row_norms(X, squared=True)[:, np.newaxis]
219+
220+
if X is Y: # shortcut in the common case euclidean_distances(X, X)
221+
YY = XX.T
222+
elif Y_norm_squared is not None:
209223
YY = check_array(Y_norm_squared)
210224
if YY.shape != (1, Y.shape[0]):
211225
raise ValueError(
212226
"Incompatible dimensions for Y and Y_norm_squared")
213227
else:
214228
YY = row_norms(Y, squared=True)[np.newaxis, :]
215229

216-
if X is Y: # shortcut in the common case euclidean_distances(X, X)
217-
XX = YY.T
218-
else:
219-
XX = row_norms(X, squared=True)[:, np.newaxis]
220-
221230
distances = safe_sparse_dot(X, Y.T, dense_output=True)
222231
distances *= -2
223232
distances += XX

sklearn/metrics/tests/test_pairwise.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,31 @@ def test_euclidean_distances():
363363
D = euclidean_distances(X, Y)
364364
assert_array_almost_equal(D, [[1., 2.]])
365365

366+
rng = np.random.RandomState(0)
367+
X = rng.random_sample((10, 4))
368+
Y = rng.random_sample((20, 4))
369+
X_norm_sq = (X ** 2).sum(axis=1)
370+
Y_norm_sq = (Y ** 2).sum(axis=1)
371+
372+
# check that we still get the right answers with {X,Y}_norm_squared
373+
D1 = euclidean_distances(X, Y)
374+
D2 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq)
375+
D3 = euclidean_distances(X, Y, Y_norm_squared=Y_norm_sq)
376+
D4 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq,
377+
Y_norm_squared=Y_norm_sq)
378+
assert_array_almost_equal(D2, D1)
379+
assert_array_almost_equal(D3, D1)
380+
assert_array_almost_equal(D4, D1)
381+
382+
# check we get the wrong answer with wrong {X,Y}_norm_squared
383+
X_norm_sq *= 0.5
384+
Y_norm_sq *= 0.5
385+
wrong_D = euclidean_distances(X, Y,
386+
X_norm_squared=np.zeros_like(X_norm_sq),
387+
Y_norm_squared=np.zeros_like(Y_norm_sq))
388+
assert_greater(np.max(np.abs(wrong_D - D1)), .01)
389+
390+
366391

367392
# Paired distances
368393

0 commit comments

Comments
 (0)
0