56
56
57
57
58
58
# Utility Functions
59
+ def _return_float_dtype (X , Y ):
60
+ """
61
+ 1. If dtype of X and Y is float32, then dtype float32 is returned.
62
+ 2. Else dtype float is returned.
63
+ """
64
+ if not issparse (X ) and not isinstance (X , np .ndarray ):
65
+ X = np .asarray (X )
66
+
67
+ if Y is None :
68
+ Y_dtype = X .dtype
69
+ elif not issparse (Y ) and not isinstance (Y , np .ndarray ):
70
+ Y = np .asarray (Y )
71
+ Y_dtype = Y .dtype
72
+ else :
73
+ Y_dtype = Y .dtype
74
+
75
+ if X .dtype == Y_dtype == np .float32 :
76
+ dtype = np .float32
77
+ else :
78
+ dtype = np .float
79
+
80
+ return X , Y , dtype
81
+
59
82
def check_pairwise_arrays (X , Y ):
60
83
""" Set X and Y appropriately and checks inputs
61
84
@@ -85,22 +108,18 @@ def check_pairwise_arrays(X, Y):
85
108
If Y was None, safe_Y will be a pointer to X.
86
109
87
110
"""
111
+ X , Y , dtype = _return_float_dtype (X , Y )
112
+
88
113
if Y is X or Y is None :
89
- X = Y = check_array (X , accept_sparse = 'csr' )
114
+ X = Y = check_array (X , accept_sparse = [ 'csr' , 'csc' , 'coo' ], dtype = dtype )
90
115
else :
91
- X = check_array (X , accept_sparse = 'csr' )
92
- Y = check_array (Y , accept_sparse = 'csr' )
116
+ X = check_array (X , accept_sparse = [ 'csr' , 'csc' , 'coo' ], dtype = dtype )
117
+ Y = check_array (Y , accept_sparse = [ 'csr' , 'csc' , 'coo' ], dtype = dtype )
93
118
if X .shape [1 ] != Y .shape [1 ]:
94
119
raise ValueError ("Incompatible dimension for X and Y matrices: "
95
120
"X.shape[1] == %d while Y.shape[1] == %d" % (
96
121
X .shape [1 ], Y .shape [1 ]))
97
122
98
- if not (X .dtype == Y .dtype == np .float32 ):
99
- if Y is X :
100
- X = Y = check_array (X , ['csr' , 'csc' , 'coo' ], dtype = np .float )
101
- else :
102
- X = check_array (X , ['csr' , 'csc' , 'coo' ], dtype = np .float )
103
- Y = check_array (Y , ['csr' , 'csc' , 'coo' ], dtype = np .float )
104
123
return X , Y
105
124
106
125
@@ -195,6 +214,7 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
195
214
# should not need X_norm_squared because if you could precompute that as
196
215
# well as Y, then you should just pre-compute the output and not even
197
216
# call this function.
217
+
198
218
X , Y = check_pairwise_arrays (X , Y )
199
219
200
220
if Y_norm_squared is not None :
@@ -225,7 +245,8 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
225
245
226
246
227
247
def pairwise_distances_argmin_min (X , Y , axis = 1 , metric = "euclidean" ,
228
- batch_size = 500 , metric_kwargs = None ):
248
+ batch_size = 500 , metric_kwargs = None ,
249
+ check_X_y = True ):
229
250
"""Compute minimum distances between one point and a set of points.
230
251
231
252
This function computes for each row in X, the index of the row of Y which
@@ -280,6 +301,10 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
280
301
metric_kwargs : dict, optional
281
302
Keyword arguments to pass to specified metric function.
282
303
304
+ check_X_y : bool, default True
305
+ Whether or not to check X and y for shape, validity and dtype. Speed
306
+ improvements possible if set to False when called repeatedly.
307
+
283
308
Returns
284
309
-------
285
310
argmin : numpy.ndarray
@@ -300,7 +325,8 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
300
325
elif not callable (metric ) and not isinstance (metric , str ):
301
326
raise ValueError ("'metric' must be a string or a callable" )
302
327
303
- X , Y = check_pairwise_arrays (X , Y )
328
+ if check_X_y :
329
+ X , Y = check_pairwise_arrays (X , Y )
304
330
305
331
if metric_kwargs is None :
306
332
metric_kwargs = {}
@@ -470,7 +496,6 @@ def manhattan_distances(X, Y=None, sum_over_features=True,
470
496
[ 1., 1.]]...)
471
497
"""
472
498
X , Y = check_pairwise_arrays (X , Y )
473
-
474
499
if issparse (X ) or issparse (Y ):
475
500
if not sum_over_features :
476
501
raise TypeError ("sum_over_features=%r not supported"
0 commit comments