Merge pull request #3811 from MechCoder/fix_repeated_checking · scikit-learn/scikit-learn@3f49cee · GitHub
[go: up one dir, main page]

Skip to content

Commit 3f49cee

Browse files
committed
Merge pull request #3811 from MechCoder/fix_repeated_checking
[MRG] Fix repeated calls of check_pairwise and type casting in pairwise_distances_argmin_min
2 parents 626f672 + b500e3e commit 3f49cee

File tree

2 files changed

+41
-11
lines changed

2 files changed

+41
-11
lines changed

doc/whats_new.rst

+4
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ Enhancements
9191
handle unknown categorical features more gracefully during transform.
9292
By `Manoj Kumar`_
9393

94+
- Added option ``check_X_y`` to :func:`metrics.pairwise_distances_argmin_min`
95+
that can give speed improvements by avoiding repeated checking when set to
96+
False. By `Manoj Kumar`_
97+
9498
Documentation improvements
9599
..........................
96100

sklearn/metrics/pairwise.py

+37-11
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,30 @@
5656

5757

5858
# Utility Functions
59+
def _return_float_dtype(X, Y):
60+
"""
61+
1. If dtype of X and Y is float32, then dtype float32 is returned.
62+
2. Else dtype float is returned.
63+
"""
64+
if not issparse(X) and not isinstance(X, np.ndarray):
65+
X = np.asarray(X)
66+
67+
if Y is None:
68+
Y_dtype = X.dtype
69+
elif not issparse(Y) and not isinstance(Y, np.ndarray):
70+
Y = np.asarray(Y)
71+
Y_dtype = Y.dtype
72+
else:
73+
Y_dtype = Y.dtype
74+
75+
if X.dtype == Y_dtype == np.float32:
76+
dtype = np.float32
77+
else:
78+
dtype = np.float
79+
80+
return X, Y, dtype
81+
82+
5983
def check_pairwise_arrays(X, Y):
6084
""" Set X and Y appropriately and checks inputs
6185
@@ -85,22 +109,18 @@ def check_pairwise_arrays(X, Y):
85109
If Y was None, safe_Y will be a pointer to X.
86110
87111
"""
112+
X, Y, dtype = _return_float_dtype(X, Y)
113+
88114
if Y is X or Y is None:
89-
X = Y = check_array(X, accept_sparse='csr')
115+
X = Y = check_array(X, accept_sparse='csr', dtype=dtype)
90116
else:
91-
X = check_array(X, accept_sparse='csr')
92-
Y = check_array(Y, accept_sparse='csr')
117+
X = check_array(X, accept_sparse='csr', dtype=dtype)
118+
Y = check_array(Y, accept_sparse='csr', dtype=dtype)
93119
if X.shape[1] != Y.shape[1]:
94120
raise ValueError("Incompatible dimension for X and Y matrices: "
95121
"X.shape[1] == %d while Y.shape[1] == %d" % (
96122
X.shape[1], Y.shape[1]))
97123

98-
if not (X.dtype == Y.dtype == np.float32):
99-
if Y is X:
100-
X = Y = check_array(X, ['csr', 'csc', 'coo'], dtype=np.float)
101-
else:
102-
X = check_array(X, ['csr', 'csc', 'coo'], dtype=np.float)
103-
Y = check_array(Y, ['csr', 'csc', 'coo'], dtype=np.float)
104124
return X, Y
105125

106126

@@ -225,7 +245,8 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):
225245

226246

227247
def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
228-
batch_size=500, metric_kwargs=None):
248+
batch_size=500, metric_kwargs=None,
249+
check_X_y=True):
229250
"""Compute minimum distances between one point and a set of points.
230251
231252
This function computes for each row in X, the index of the row of Y which
@@ -280,6 +301,10 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
280301
metric_kwargs : dict, optional
281302
Keyword arguments to pass to specified metric function.
282303
304+
check_X_y : bool, default True
305+
Whether or not to check X and y for shape, validity and dtype. Speed
306+
improvements possible if set to False when called repeatedly.
307+
283308
Returns
284309
-------
285310
argmin : numpy.ndarray
@@ -300,7 +325,8 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
300325
elif not callable(metric) and not isinstance(metric, str):
301326
raise ValueError("'metric' must be a string or a callable")
302327

303-
X, Y = check_pairwise_arrays(X, Y)
328+
if check_X_y:
329+
X, Y = check_pairwise_arrays(X, Y)
304330

305331
if metric_kwargs is None:
306332
metric_kwargs = {}

0 commit comments

Comments
 (0)
0