5656
5757
5858# Utility Functions
def _return_float_dtype(X, Y):
    """Convert X and Y to arrays and choose the float dtype to validate with.

    1. If dtype of X and Y is float32, then dtype float32 is returned.
    2. Else dtype float is returned.

    Parameters
    ----------
    X : array-like or sparse matrix
        Converted with ``np.asarray`` unless it is already an ndarray or a
        sparse matrix.

    Y : array-like, sparse matrix or None
        Converted like X. When None, X's dtype alone decides the result.

    Returns
    -------
    X : ndarray or sparse matrix
        The (possibly converted) first input.

    Y : ndarray, sparse matrix or None
        The (possibly converted) second input; None is passed through.

    dtype : type
        ``np.float32`` if both inputs are float32, otherwise ``float``.
    """
    if not issparse(X) and not isinstance(X, np.ndarray):
        X = np.asarray(X)

    if Y is None:
        # Only one input: its dtype stands in for both sides of the check.
        Y_dtype = X.dtype
    else:
        if not issparse(Y) and not isinstance(Y, np.ndarray):
            Y = np.asarray(Y)
        Y_dtype = Y.dtype

    if X.dtype == Y_dtype == np.float32:
        dtype = np.float32
    else:
        # Builtin float is what the ``np.float`` alias referred to; the alias
        # itself was removed from NumPy, so it must not be used here.
        dtype = float

    return X, Y, dtype
81+
82+
5983def check_pairwise_arrays (X , Y ):
6084 """ Set X and Y appropriately and checks inputs
6185
@@ -85,22 +109,18 @@ def check_pairwise_arrays(X, Y):
85109 If Y was None, safe_Y will be a pointer to X.
86110
87111 """
112+ X , Y , dtype = _return_float_dtype (X , Y )
113+
88114 if Y is X or Y is None :
89- X = Y = check_array (X , accept_sparse = 'csr' )
115+ X = Y = check_array (X , accept_sparse = 'csr' , dtype = dtype )
90116 else :
91- X = check_array (X , accept_sparse = 'csr' )
92- Y = check_array (Y , accept_sparse = 'csr' )
117+ X = check_array (X , accept_sparse = 'csr' , dtype = dtype )
118+ Y = check_array (Y , accept_sparse = 'csr' , dtype = dtype )
93119 if X .shape [1 ] != Y .shape [1 ]:
94120 raise ValueError ("Incompatible dimension for X and Y matrices: "
95121 "X.shape[1] == %d while Y.shape[1] == %d" % (
96122 X .shape [1 ], Y .shape [1 ]))
97123
98- if not (X .dtype == Y .dtype == np .float32 ):
99- if Y is X :
100- X = Y = check_array (X , ['csr' , 'csc' , 'coo' ], dtype = np .float )
101- else :
102- X = check_array (X , ['csr' , 'csc' , 'coo' ], dtype = np .float )
103- Y = check_array (Y , ['csr' , 'csc' , 'coo' ], dtype = np .float )
104124 return X , Y
105125
106126
@@ -225,7 +245,8 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
225245
226246
227247def pairwise_distances_argmin_min (X , Y , axis = 1 , metric = "euclidean" ,
228- batch_size = 500 , metric_kwargs = None ):
248+ batch_size = 500 , metric_kwargs = None ,
249+ check_X_y = True ):
229250 """Compute minimum distances between one point and a set of points.
230251
231252 This function computes for each row in X, the index of the row of Y which
@@ -280,6 +301,10 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
280301 metric_kwargs : dict, optional
281302 Keyword arguments to pass to specified metric function.
282303
304+ check_X_y : bool, default True
305+ Whether or not to check X and Y for shape, validity and dtype. Speed
306+ improvements are possible if set to False when called repeatedly.
307+
283308 Returns
284309 -------
285310 argmin : numpy.ndarray
@@ -300,7 +325,8 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
300325 elif not callable (metric ) and not isinstance (metric , str ):
301326 raise ValueError ("'metric' must be a string or a callable" )
302327
303- X , Y = check_pairwise_arrays (X , Y )
328+ if check_X_y :
329+ X , Y = check_pairwise_arrays (X , Y )
304330
305331 if metric_kwargs is None :
306332 metric_kwargs = {}
0 commit comments