|
6 | 6 | # Andreas Mueller <amueller@ais.uni-bonn.de>
|
7 | 7 | # Philippe Gervais <philippe.gervais@inria.fr>
|
8 | 8 | # Lars Buitinck <larsmans@gmail.com>
|
| 9 | +# Joel Nothman <joel.nothman@gmail.com> |
9 | 10 | # License: BSD 3 clause
|
10 | 11 |
|
| 12 | +import itertools |
| 13 | + |
11 | 14 | import numpy as np
|
12 | 15 | from scipy.spatial import distance
|
13 | 16 | from scipy.sparse import csr_matrix
|
|
16 | 19 | from ..utils import check_array
|
17 | 20 | from ..utils import gen_even_slices
|
18 | 21 | from ..utils import gen_batches
|
| 22 | +from ..utils.fixes import partial |
19 | 23 | from ..utils.extmath import row_norms, safe_sparse_dot
|
20 | 24 | from ..preprocessing import normalize
|
21 | 25 | from ..externals.joblib import Parallel
|
@@ -951,13 +955,51 @@ def _parallel_pairwise(X, Y, func, n_jobs, **kwds):
|
951 | 955 | if Y is None:
|
952 | 956 | Y = X
|
953 | 957 |
|
| 958 | + if n_jobs == 1: |
| 959 | + # Special case to avoid picklability checks in delayed |
| 960 | + return func(X, Y, **kwds) |
| 961 | + |
| 962 | + # TODO: in some cases, backend='threading' may be appropriate |
| 963 | + fd = delayed(func) |
954 | 964 | ret = Parallel(n_jobs=n_jobs, verbose=0)(
|
955 |
| - delayed(func)(X, Y[s], **kwds) |
| 965 | + fd(X, Y[s], **kwds) |
956 | 966 | for s in gen_even_slices(Y.shape[0], n_jobs))
|
957 | 967 |
|
958 | 968 | return np.hstack(ret)
|
959 | 969 |
|
960 | 970 |
|
| 971 | +def _pairwise_callable(X, Y, metric, **kwds): |
| 972 | + """Handle the callable case for pairwise_{distances,kernels} |
| 973 | + """ |
| 974 | + X, Y = check_pairwise_arrays(X, Y) |
| 975 | + |
| 976 | + if X is Y: |
| 977 | + # Only calculate metric for upper triangle |
| 978 | + out = np.zeros((X.shape[0], Y.shape[0]), dtype='float') |
| 979 | + iterator = itertools.combinations(range(X.shape[0]), 2) |
| 980 | + for i, j in iterator: |
| 981 | + out[i, j] = metric(X[i], Y[j], **kwds) |
| 982 | + |
| 983 | + # Make symmetric |
| 984 | + # NB: out += out.T will produce incorrect results |
| 985 | + out = out + out.T |
| 986 | + |
| 987 | + # Calculate diagonal |
| 988 | + # NB: nonzero diagonals are allowed for both metrics and kernels |
| 989 | + for i in range(X.shape[0]): |
| 990 | + x = X[i] |
| 991 | + out[i, i] = metric(x, x, **kwds) |
| 992 | + |
| 993 | + else: |
| 994 | + # Calculate all cells |
| 995 | + out = np.empty((X.shape[0], Y.shape[0]), dtype='float') |
| 996 | + iterator = itertools.product(range(X.shape[0]), range(Y.shape[0])) |
| 997 | + for i, j in iterator: |
| 998 | + out[i, j] = metric(X[i], Y[j], **kwds) |
| 999 | + |
| 1000 | + return out |
| 1001 | + |
| 1002 | + |
961 | 1003 | _VALID_METRICS = ['euclidean', 'l2', 'l1', 'manhattan', 'cityblock',
|
962 | 1004 | 'braycurtis', 'canberra', 'chebyshev', 'correlation',
|
963 | 1005 | 'cosine', 'dice', 'hamming', 'jaccard', 'kulsinski',
|
@@ -1053,41 +1095,19 @@ def pairwise_distances(X, Y=None, metric="euclidean", n_jobs=1, **kwds):
|
1053 | 1095 | return X
|
1054 | 1096 | elif metric in PAIRWISE_DISTANCE_FUNCTIONS:
|
1055 | 1097 | func = PAIRWISE_DISTANCE_FUNCTIONS[metric]
|
1056 |
| - if n_jobs == 1: |
1057 |
| - return func(X, Y, **kwds) |
1058 |
| - else: |
1059 |
| - return _parallel_pairwise(X, Y, func, n_jobs, **kwds) |
1060 | 1098 | elif callable(metric):
|
1061 |
| - # Check matrices first (this is usually done by the metric). |
1062 |
| - X, Y = check_pairwise_arrays(X, Y) |
1063 |
| - n_x, n_y = X.shape[0], Y.shape[0] |
1064 |
| - # Calculate distance for each element in X and Y. |
1065 |
| - # FIXME: can use n_jobs here too |
1066 |
| - # FIXME: np.zeros can be replaced by np.empty |
1067 |
| - D = np.zeros((n_x, n_y), dtype='float') |
1068 |
| - for i in range(n_x): |
1069 |
| - start = 0 |
1070 |
| - if X is Y: |
1071 |
| - start = i |
1072 |
| - for j in range(start, n_y): |
1073 |
| - # distance assumed to be symmetric. |
1074 |
| - D[i][j] = metric(X[i], Y[j], **kwds) |
1075 |
| - if X is Y: |
1076 |
| - D[j][i] = D[i][j] |
1077 |
| - return D |
| 1099 | + func = partial(_pairwise_callable, metric=metric, **kwds) |
1078 | 1100 | else:
|
1079 |
| - # Note: the distance module doesn't support sparse matrices! |
1080 |
| - if type(X) is csr_matrix: |
| 1101 | + if issparse(X) or issparse(Y): |
1081 | 1102 | raise TypeError("scipy distance metrics do not"
|
1082 | 1103 | " support sparse matrices.")
|
1083 |
| - if Y is None: |
| 1104 | + X, Y = check_pairwise_arrays(X, Y) |
| 1105 | + if n_jobs == 1 and X is Y: |
1084 | 1106 | return distance.squareform(distance.pdist(X, metric=metric,
|
1085 | 1107 | **kwds))
|
1086 |
| - else: |
1087 |
| - if type(Y) is csr_matrix: |
1088 |
| - raise TypeError("scipy distance metrics do not" |
1089 |
| - " support sparse matrices.") |
1090 |
| - return distance.cdist(X, Y, metric=metric, **kwds) |
| 1108 | + func = partial(distance.cdist, metric=metric, **kwds) |
| 1109 | + |
| 1110 | + return _parallel_pairwise(X, Y, func, n_jobs, **kwds) |
1091 | 1111 |
|
1092 | 1112 |
|
1093 | 1113 | # Helper functions - distance
|
@@ -1214,25 +1234,9 @@ def pairwise_kernels(X, Y=None, metric="linear", filter_params=False,
|
1214 | 1234 | kwds = dict((k, kwds[k]) for k in kwds
|
1215 | 1235 | if k in KERNEL_PARAMS[metric])
|
1216 | 1236 | func = PAIRWISE_KERNEL_FUNCTIONS[metric]
|
1217 |
| - if n_jobs == 1: |
1218 |
| - return func(X, Y, **kwds) |
1219 |
| - else: |
1220 |
| - return _parallel_pairwise(X, Y, func, n_jobs, **kwds) |
1221 | 1237 | elif callable(metric):
|
1222 |
| - # Check matrices first (this is usually done by the metric). |
1223 |
| - X, Y = check_pairwise_arrays(X, Y) |
1224 |
| - n_x, n_y = X.shape[0], Y.shape[0] |
1225 |
| - # Calculate kernel for each element in X and Y. |
1226 |
| - K = np.zeros((n_x, n_y), dtype='float') |
1227 |
| - for i in range(n_x): |
1228 |
| - start = 0 |
1229 |
| - if X is Y: |
1230 |
| - start = i |
1231 |
| - for j in range(start, n_y): |
1232 |
| - # Kernel assumed to be symmetric. |
1233 |
| - K[i][j] = metric(X[i], Y[j], **kwds) |
1234 |
| - if X is Y: |
1235 |
| - K[j][i] = K[i][j] |
1236 |
| - return K |
| 1238 | + func = partial(_pairwise_callable, metric=metric, **kwds) |
1237 | 1239 | else:
|
1238 | 1240 | raise ValueError("Unknown kernel %r" % metric)
|
| 1241 | + |
| 1242 | + return _parallel_pairwise(X, Y, func, n_jobs, **kwds) |
0 commit comments