support X_norm_squared in euclidean_distances · scikit-learn/scikit-learn@4a1b0cc · GitHub
[go: up one dir, main page]

Skip to content

Commit 4a1b0cc

Browse files
committed
support X_norm_squared in euclidean_distances
1 parent 7224c31 commit 4a1b0cc

File tree

2 files changed

+36
-12
lines changed

2 files changed

+36
-12
lines changed

sklearn/metrics/pairwise.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,8 @@ def check_paired_arrays(X, Y):
133133

134134

135135
# Pairwise distances
136-
def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
136+
def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False,
137+
X_norm_squared=None):
137138
"""
138139
Considering the rows of X (and Y=X) as vectors, compute the
139140
distance matrix between each pair of vectors.
@@ -145,8 +146,8 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
145146
146147
This formulation has two advantages over other ways of computing distances.
147148
First, it is computationally efficient when dealing with sparse data.
148-
Second, if x varies but y remains unchanged, then the right-most dot
149-
product `dot(y, y)` can be pre-computed.
149+
Second, if one argument varies but the other remains unchanged, then
150+
`dot(x, x)` and/or `dot(y, y)` can be pre-computed.
150151
151152
However, this is not the most precise way of doing this computation, and
152153
the distance matrix returned by this function may not be exactly
@@ -167,6 +168,10 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
167168
squared : boolean, optional
168169
Return squared Euclidean distances.
169170
171+
X_norm_squared : array-like, shape = [n_samples_1], optional
172+
Pre-computed dot-products of vectors in X (e.g.,
173+
``(X**2).sum(axis=1)``)
174+
170175
Returns
171176
-------
172177
distances : {array, sparse matrix}, shape (n_samples_1, n_samples_2)
@@ -188,24 +193,28 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
188193
--------
189194
paired_distances : distances between pairs of elements of X and Y.
190195
"""
191-
# should not need X_norm_squared because if you could precompute that as
192-
# well as Y, then you should just pre-compute the output and not even
193-
# call this function.
194196
X, Y = check_pairwise_arrays(X, Y)
195197

196-
if Y_norm_squared is not None:
198+
if X_norm_squared is not None:
199+
XX = check_array(X_norm_squared)
200+
if XX.shape == (1, X.shape[0]):
201+
XX = XX.T
202+
elif XX.shape != (X.shape[0], 1):
203+
raise ValueError(
204+
"Incompatible dimensions for X and X_norm_squared")
205+
else:
206+
XX = row_norms(X, squared=True)[:, np.newaxis]
207+
208+
if X is Y: # shortcut in the common case euclidean_distances(X, X)
209+
YY = XX.T
210+
elif Y_norm_squared is not None:
197211
YY = check_array(Y_norm_squared)
198212
if YY.shape != (1, Y.shape[0]):
199213
raise ValueError(
200214
"Incompatible dimensions for Y and Y_norm_squared")
201215
else:
202216
YY = row_norms(Y, squared=True)[np.newaxis, :]
203217

204-
if X is Y: # shortcut in the common case euclidean_distances(X, X)
205-
XX = YY.T
206-
else:
207-
XX = row_norms(X, squared=True)[:, np.newaxis]
208-
209218
distances = safe_sparse_dot(X, Y.T, dense_output=True)
210219
distances *= -2
211220
distances += XX

sklearn/metrics/tests/test_pairwise.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,21 @@ def test_euclidean_distances():
334334
D = euclidean_distances(X, Y)
335335
assert_array_almost_equal(D, [[1., 2.]])
336336

337+
rng = np.random.RandomState(0)
338+
X = rng.random_sample((10, 4))
339+
Y = rng.random_sample((20, 4))
340+
X_norm_sq = (X ** 2).sum(axis=1)
341+
Y_norm_sq = (Y ** 2).sum(axis=1)
342+
343+
D1 = euclidean_distances(X, Y)
344+
D2 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq)
345+
D3 = euclidean_distances(X, Y, Y_norm_squared=Y_norm_sq)
346+
D4 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq,
347+
Y_norm_squared=Y_norm_sq)
348+
assert_array_almost_equal(D1, D2)
349+
assert_array_almost_equal(D1, D3)
350+
assert_array_almost_equal(D1, D4)
351+
337352

338353
# Paired distances
339354

0 commit comments

Comments
 (0)
0