@@ -106,7 +106,8 @@ def check_pairwise_arrays(X, Y):
106
106
107
107
108
108
# Distances
109
- def euclidean_distances (X , Y = None , Y_norm_squared = None , squared = False ):
109
+ def euclidean_distances (X , Y = None , Y_norm_squared = None , squared = False ,
110
+ X_norm_squared = None ):
110
111
"""
111
112
Considering the rows of X (and Y=X) as vectors, compute the
112
113
distance matrix between each pair of vectors.
@@ -117,9 +118,9 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
117
118
dist(x, y) = sqrt(dot(x, x) - 2 * dot(x, y) + dot(y, y))
118
119
119
120
This formulation has two main advantages. First, it is computationally
120
- efficient when dealing with sparse data. Second, if x varies but y
121
- remains unchanged, then the right-most dot-product `dot(y, y)` can be
122
- pre-computed .
121
+ efficient when dealing with sparse data. Second, the components `dot(x, x)`
122
+ or `dot(y, y)` can be pre-computed when getting euclidean distances for
123
+ multiple sets .
123
124
124
125
Parameters
125
126
----------
@@ -134,6 +135,10 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
134
135
squared : boolean, optional
135
136
Return squared Euclidean distances.
136
137
138
+ X_norm_squared : array-like, shape = [n_samples_1], optional
139
+ Pre-computed dot-products of vectors in X (e.g.,
140
+ ``(X**2).sum(axis=1)``)
141
+
137
142
Returns
138
143
-------
139
144
distances : {array, sparse matrix}, shape = [n_samples_1, n_samples_2]
@@ -151,24 +156,28 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
151
156
array([[ 1. ],
152
157
[ 1.41421356]])
153
158
"""
154
- # should not need X_norm_squared because if you could precompute that as
155
- # well as Y, then you should just pre-compute the output and not even
156
- # call this function.
157
159
X , Y = check_pairwise_arrays (X , Y )
158
160
159
- if Y_norm_squared is not None :
161
+ if X_norm_squared is not None :
162
+ XX = array2d (X_norm_squared )
163
+ if XX .shape == (1 , X .shape [0 ]):
164
+ XX = XX .T
165
+ elif XX .shape != (X .shape [0 ], 1 ):
166
+ raise ValueError (
167
+ "Incompatible dimensions for X and X_norm_squared" )
168
+ else :
169
+ XX = row_norms (X , squared = True )[:, np .newaxis ]
170
+
171
+ if X is Y : # shortcut in the common case euclidean_distances(X, X)
172
+ YY = XX .T
173
+ elif Y_norm_squared is not None :
160
174
YY = array2d (Y_norm_squared )
161
175
if YY .shape != (1 , Y .shape [0 ]):
162
176
raise ValueError (
163
177
"Incompatible dimensions for Y and Y_norm_squared" )
164
178
else :
165
179
YY = row_norms (Y , squared = True )[np .newaxis , :]
166
180
167
- if X is Y : # shortcut in the common case euclidean_distances(X, X)
168
- XX = YY .T
169
- else :
170
- XX = row_norms (X , squared = True )[:, np .newaxis ]
171
-
172
181
distances = safe_sparse_dot (X , Y .T , dense_output = True )
173
182
distances *= - 2
174
183
distances += XX
0 commit comments