56
56
57
57
58
58
# Utility Functions
59
+ def _return_float_dtype (X , Y ):
60
+ """
61
+ 1. If dtype of X and Y is float32, then dtype float32 is returned.
62
+ 2. Else dtype float is returned.
63
+ """
64
+ if not issparse (X ) and not isinstance (X , np .ndarray ):
65
+ X = np .asarray (X )
66
+
67
+ if Y is None :
68
+ Y_dtype = X .dtype
69
+ elif not issparse (Y ) and not isinstance (Y , np .ndarray ):
70
+ Y = np .asarray (Y )
71
+ Y_dtype = Y .dtype
72
+ else :
73
+ Y_dtype = Y .dtype
74
+
75
+ if X .dtype == Y_dtype == np .float32 :
76
+ dtype = np .float32
77
+ else :
78
+ dtype = np .float
79
+
80
+ return X , Y , dtype
81
+
82
+
59
83
def check_pairwise_arrays (X , Y ):
60
84
""" Set X and Y appropriately and checks inputs
61
85
@@ -85,22 +109,18 @@ def check_pairwise_arrays(X, Y):
85
109
If Y was None, safe_Y will be a pointer to X.
86
110
87
111
"""
112
+ X , Y , dtype = _return_float_dtype (X , Y )
113
+
88
114
if Y is X or Y is None :
89
- X = Y = check_array (X , accept_sparse = 'csr' )
115
+ X = Y = check_array (X , accept_sparse = 'csr' , dtype = dtype )
90
116
else :
91
- X = check_array (X , accept_sparse = 'csr' )
92
- Y = check_array (Y , accept_sparse = 'csr' )
117
+ X = check_array (X , accept_sparse = 'csr' , dtype = dtype )
118
+ Y = check_array (Y , accept_sparse = 'csr' , dtype = dtype )
93
119
if X .shape [1 ] != Y .shape [1 ]:
94
120
raise ValueError ("Incompatible dimension for X and Y matrices: "
95
121
"X.shape[1] == %d while Y.shape[1] == %d" % (
96
122
X .shape [1 ], Y .shape [1 ]))
97
123
98
- if not (X .dtype == Y .dtype == np .float32 ):
99
- if Y is X :
100
- X = Y = check_array (X , ['csr' , 'csc' , 'coo' ], dtype = np .float )
101
- else :
102
- X = check_array (X , ['csr' , 'csc' , 'coo' ], dtype = np .float )
103
- Y = check_array (Y , ['csr' , 'csc' , 'coo' ], dtype = np .float )
104
124
return X , Y
105
125
106
126
@@ -225,7 +245,8 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
225
245
226
246
227
247
def pairwise_distances_argmin_min (X , Y , axis = 1 , metric = "euclidean" ,
228
- batch_size = 500 , metric_kwargs = None ):
248
+ batch_size = 500 , metric_kwargs = None ,
249
+ check_X_y = True ):
229
250
"""Compute minimum distances between one point and a set of points.
230
251
231
252
This function computes for each row in X, the index of the row of Y which
@@ -280,6 +301,10 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
280
301
metric_kwargs : dict, optional
281
302
Keyword arguments to pass to specified metric function.
282
303
304
+ check_X_y : bool, default True
305
+ Whether or not to check X and Y for shape, validity and dtype. Speed
306
+ improvements possible if set to False when called repeatedly.
307
+
283
308
Returns
284
309
-------
285
310
argmin : numpy.ndarray
@@ -300,7 +325,8 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
300
325
elif not callable (metric ) and not isinstance (metric , str ):
301
326
raise ValueError ("'metric' must be a string or a callable" )
302
327
303
- X , Y = check_pairwise_arrays (X , Y )
328
+ if check_X_y :
329
+ X , Y = check_pairwise_arrays (X , Y )
304
330
305
331
if metric_kwargs is None :
306
332
metric_kwargs = {}
0 commit comments