8000 support X_norm_squared in euclidean_distances · scikit-learn/scikit-learn@8d2978c · GitHub
[go: up one dir, main page]

Skip to content

Commit 8d2978c

Browse files
committed
support X_norm_squared in euclidean_distances
1 parent e54c54a commit 8d2978c

File tree

2 files changed

+37
-13
lines changed

2 files changed

+37
-13
lines changed

sklearn/metrics/pairwise.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ def check_pairwise_arrays(X, Y):
106106

107107

108108
# Distances
109-
def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
109+
def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False,
110+
X_norm_squared=None):
110111
"""
111112
Considering the rows of X (and Y=X) as vectors, compute the
112113
distance matrix between each pair of vectors.
@@ -117,9 +118,9 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
117118
dist(x, y) = sqrt(dot(x, x) - 2 * dot(x, y) + dot(y, y))
118119
119120
This formulation has two main advantages. First, it is computationally
120-
efficient when dealing with sparse data. Second, if x varies but y
121-
remains unchanged, then the right-most dot-product `dot(y, y)` can be
122-
pre-computed.
121+
efficient when dealing with sparse data. Second, the components `dot(x, x)`
122+
or `dot(y, y)` can be pre-computed when getting euclidean distances for
123+
multiple sets.
123124
124125
Parameters
125126
----------
@@ -134,6 +135,10 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
134135
squared : boolean, optional
135136
Return squared Euclidean distances.
136137
138+
X_norm_squared : array-like, shape = [n_samples_1], optional
139+
Pre-computed dot-products of vectors in X (e.g.,
140+
``(X**2).sum(axis=1)``)
141+
137142
Returns
138143
-------
139144
distances : {array, sparse matrix}, shape = [n_samples_1, n_samples_2]
@@ -151,24 +156,28 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
151156
array([[ 1. ],
152157
[ 1.41421356]])
153158
"""
154-
# should not need X_norm_squared because if you could precompute that as
155-
# well as Y, then you should just pre-compute the output and not even
156-
# call this function.
157159
X, Y = check_pairwise_arrays(X, Y)
158160

159-
if Y_norm_squared is not None:
161+
if X_norm_squared is not None:
162+
XX = array2d(X_norm_squared)
163+
if XX.shape == (1, X.shape[0]):
164+
XX = XX.T
165+
elif XX.shape != (X.shape[0], 1):
166+
raise ValueError(
167+
"Incompatible dimensions for X and X_norm_squared")
168+
else:
169+
XX = row_norms(X, squared=True)[:, np.newaxis]
170+
171+
if X is Y: # shortcut in the common case euclidean_distances(X, X)
172+
YY = XX.T
173+
elif Y_norm_squared is not None:
160174
YY = array2d(Y_norm_squared)
161175
if YY.shape != (1, Y.shape[0]):
162176
raise ValueError(
163177
"Incompatible dimensions for Y and Y_norm_squared")
164178
else:
165179
YY = row_norms(Y, squared=True)[np.newaxis, :]
166180

167-
if X is Y: # shortcut in the common case euclidean_distances(X, X)
168-
XX = YY.T
169-
else:
170-
XX = row_norms(X, squared=True)[:, np.newaxis]
171-
172181
distances = safe_sparse_dot(X, Y.T, dense_output=True)
173182
distances *= -2
174183
distances += XX

sklearn/metrics/tests/test_pairwise.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,21 @@ def test_euclidean_distances():
260260
D = euclidean_distances(X, Y)
261261
assert_array_almost_equal(D, [[1., 2.]])
262262

263+
rng = np.random.RandomState(0)
264+
X = rng.random_sample((10, 4))
265+
Y = rng.random_sample((20, 4))
266+
X_norm_sq = (X ** 2).sum(axis=1)
267+
Y_norm_sq = (Y ** 2).sum(axis=1)
268+
269+
D1 = euclidean_distances(X, Y)
270+
D2 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq)
271+
D3 = euclidean_distances(X, Y, Y_norm_squared=Y_norm_sq)
272+
D4 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq,
273+
Y_norm_squared=Y_norm_sq)
274+
assert_array_almost_equal(D1, D2)
275+
assert_array_almost_equal(D1, D3)
276+
assert_array_almost_equal(D1, D4)
277+
263278

264279
def test_chi_square_kernel():
265280
rng = np.random.RandomState(0)

0 commit comments

Comments
 (0)
0