Merge pull request #4085 from jnothman/pairwise_parallel · scikit-learn/scikit-learn@da90a0f · GitHub
[go: up one dir, main page]

Skip to content

Commit da90a0f

Browse files
committed
Merge pull request #4085 from jnothman/pairwise_parallel
[MRG] ENH use parallelism for all metrics in pairwise_{kernels,distances}
2 parents bf203de + 0c9377e commit da90a0f

File tree

3 files changed

+121
-64
lines changed

3 files changed

+121
-64
lines changed

sklearn/metrics/pairwise.py

Lines changed: 53 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
# Andreas Mueller <amueller@ais.uni-bonn.de>
77
# Philippe Gervais <philippe.gervais@inria.fr>
88
# Lars Buitinck <larsmans@gmail.com>
9+
# Joel Nothman <joel.nothman@gmail.com>
910
# License: BSD 3 clause
1011

12+
import itertools
13+
1114
import numpy as np
1215
from scipy.spatial import distance
1316
from scipy.sparse import csr_matrix
@@ -16,6 +19,7 @@
1619
from ..utils import check_array
1720
from ..utils import gen_even_slices
1821
from ..utils import gen_batches
22+
from ..utils.fixes import partial
1923
from ..utils.extmath import row_norms, safe_sparse_dot
2024
from ..preprocessing import normalize
2125
from ..externals.joblib import Parallel
@@ -951,13 +955,51 @@ def _parallel_pairwise(X, Y, func, n_jobs, **kwds):
951955
if Y is None:
952956
Y = X
953957

958+
if n_jobs == 1:
959+
# Special case to avoid picklability checks in delayed
960+
return func(X, Y, **kwds)
961+
962+
# TODO: in some cases, backend='threading' may be appropriate
963+
fd = delayed(func)
954964
ret = Parallel(n_jobs=n_jobs, verbose=0)(
955-
delayed(func)(X, Y[s], **kwds)
965+
fd(X, Y[s], **kwds)
956966
for s in gen_even_slices(Y.shape[0], n_jobs))
957967

958968
return np.hstack(ret)
959969

960970

971+
def _pairwise_callable(X, Y, metric, **kwds):
972+
"""Handle the callable case for pairwise_{distances,kernels}
973+
"""
974+
X, Y = check_pairwise_arrays(X, Y)
975+
976+
if X is Y:
977+
# Only calculate metric for upper triangle
978+
out = np.zeros((X.shape[0], Y.shape[0]), dtype='float')
979+
iterator = itertools.combinations(range(X.shape[0]), 2)
980+
for i, j in iterator:
981+
out[i, j] = metric(X[i], Y[j], **kwds)
982+
983+
# Make symmetric
984+
# NB: out += out.T will produce incorrect results
985+
out = out + out.T
986+
987+
# Calculate diagonal
988+
# NB: nonzero diagonals are allowed for both metrics and kernels
989+
for i in range(X.shape[0]):
990+
x = X[i]
991+
out[i, i] = metric(x, x, **kwds)
992+
993+
else:
994+
# Calculate all cells
995+
out = np.empty((X.shape[0], Y.shape[0]), dtype='float')
996+
iterator = itertools.product(range(X.shape[0]), range(Y.shape[0]))
997+
for i, j in iterator:
998+
out[i, j] = metric(X[i], Y[j], **kwds)
999+
1000+
return out
1001+
1002+
9611003
_VALID_METRICS = ['euclidean', 'l2', 'l1', 'manhattan', 'cityblock',
9621004
'braycurtis', 'canberra', 'chebyshev', 'correlation',
9631005
'cosine', 'dice', 'hamming', 'jaccard', 'kulsinski',
@@ -1053,41 +1095,19 @@ def pairwise_distances(X, Y=None, metric="euclidean", n_jobs=1, **kwds):
10531095
return X
10541096
elif metric in PAIRWISE_DISTANCE_FUNCTIONS:
10551097
func = PAIRWISE_DISTANCE_FUNCTIONS[metric]
1056-
if n_jobs == 1:
1057-
return func(X, Y, **kwds)
1058-
else:
1059-
return _parallel_pairwise(X, Y, func, n_jobs, **kwds)
10601098
elif callable(metric):
1061-
# Check matrices first (this is usually done by the metric).
1062-
X, Y = check_pairwise_arrays(X, Y)
1063-
n_x, n_y = X.shape[0], Y.shape[0]
1064-
# Calculate distance for each element in X and Y.
1065-
# FIXME: can use n_jobs here too
1066-
# FIXME: np.zeros can be replaced by np.empty
1067-
D = np.zeros((n_x, n_y), dtype='float')
1068-
for i in range(n_x):
1069-
start = 0
1070-
if X is Y:
1071-
start = i
1072-
for j in range(start, n_y):
1073-
# distance assumed to be symmetric.
1074-
D[i][j] = metric(X[i], Y[j], **kwds)
1075-
if X is Y:
1076-
D[j][i] = D[i][j]
1077-
return D
1099+
func = partial(_pairwise_callable, metric=metric, **kwds)
10781100
else:
1079-
# Note: the distance module doesn't support sparse matrices!
1080-
if type(X) is csr_matrix:
1101+
if issparse(X) or issparse(Y):
10811102
raise TypeError("scipy distance metrics do not"
10821103
" support sparse matrices.")
1083-
if Y is None:
1104+
X, Y = check_pairwise_arrays(X, Y)
1105+
if n_jobs == 1 and X is Y:
10841106
return distance.squareform(distance.pdist(X, metric=metric,
10851107
**kwds))
1086-
else:
1087-
if type(Y) is csr_matrix:
1088-
raise TypeError("scipy distance metrics do not"
1089-
" support sparse matrices.")
1090-
return distance.cdist(X, Y, metric=metric, **kwds)
1108+
func = partial(distance.cdist, metric=metric, **kwds)
1109+
1110+
return _parallel_pairwise(X, Y, func, n_jobs, **kwds)
10911111

10921112

10931113
# Helper functions - distance
@@ -1214,25 +1234,9 @@ def pairwise_kernels(X, Y=None, metric="linear", filter_params=False,
12141234
kwds = dict((k, kwds[k]) for k in kwds
12151235
if k in KERNEL_PARAMS[metric])
12161236
func = PAIRWISE_KERNEL_FUNCTIONS[metric]
1217-
if n_jobs == 1:
1218-
return func(X, Y, **kwds)
1219-
else:
1220-
return _parallel_pairwise(X, Y, func, n_jobs, **kwds)
12211237
elif callable(metric):
1222-
# Check matrices first (this is usually done by the metric).
1223-
X, Y = check_pairwise_arrays(X, Y)
1224-
n_x, n_y = X.shape[0], Y.shape[0]
1225-
# Calculate kernel for each element in X and Y.
1226-
K = np.zeros((n_x, n_y), dtype='float')
1227-
for i in range(n_x):
1228-
start = 0
1229-
if X is Y:
1230-
start = i
1231-
for j in range(start, n_y):
1232-
# Kernel assumed to be symmetric.
1233-
K[i][j] = metric(X[i], Y[j], **kwds)
1234-
if X is Y:
1235-
K[j][i] = K[i][j]
1236-
return K
1238+
func = partial(_pairwise_callable, metric=metric, **kwds)
12371239
else:
12381240
raise ValueError("Unknown kernel %r" % metric)
1241+
1242+
return _parallel_pairwise(X, Y, func, n_jobs, **kwds)

sklearn/metrics/tests/test_pairwise.py

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from numpy import linalg
33

44
from scipy.sparse import dok_matrix, csr_matrix, issparse
5-
from scipy.spatial.distance import cosine, cityblock, minkowski
5+
from scipy.spatial.distance import cosine, cityblock, minkowski, wminkowski
66

77
from sklearn.utils.testing import assert_greater
88
from sklearn.utils.testing import assert_array_almost_equal
@@ -118,29 +118,61 @@ def test_pairwise_distances():
118118
assert_raises(ValueError, pairwise_distances, X, Y, metric="blah")
119119

120120

121-
def test_pairwise_parallel():
121+
def check_pairwise_parallel(func, metric, kwds):
122122
rng = np.random.RandomState(0)
123-
for func in (np.array, csr_matrix):
124-
X = func(rng.random_sample((5, 4)))
125-
Y = func(rng.random_sample((3, 4)))
126-
127-
S = euclidean_distances(X)
128-
S2 = _parallel_pairwise(X, None, euclidean_distances, n_jobs=3)
123+
for make_data in (np.array, csr_matrix):
124+
X = make_data(rng.random_sample((5, 4)))
125+
Y = make_data(rng.random_sample((3, 4)))
126+
127+
try:
128+
S = func(X, metric=metric, n_jobs=1, **kwds)
129+
except (TypeError, ValueError) as exc:
130+
# Not all metrics support sparse input
131+
# ValueError may be triggered by bad callable
132+
if make_data is csr_matrix:
133+
assert_raises(type(exc), func, X, metric=metric,
134+
n_jobs=2, **kwds)
135+
continue
136+
else:
137+
raise
138+
S2 = func(X, metric=metric, n_jobs=2, **kwds)
129139
assert_array_almost_equal(S, S2)
130140

131-
S = euclidean_distances(X, Y)
132-
S2 = _parallel_pairwise(X, Y, euclidean_distances, n_jobs=3)
141+
S = func(X, Y, metric=metric, n_jobs=1, **kwds)
142+
S2 = func(X, Y, metric=metric, n_jobs=2, **kwds)
133143
assert_array_almost_equal(S, S2)
134144

135145

146+
def test_pairwise_parallel():
147+
wminkowski_kwds = {'w': np.arange(1, 5).astype('double'), 'p': 1}
148+
metrics = [(pairwise_distances, 'euclidean', {}),
149+
(pairwise_distances, wminkowski, wminkowski_kwds),
150+
(pairwise_distances, 'wminkowski', wminkowski_kwds),
151+
(pairwise_kernels, 'polynomial', {'degree': 1}),
152+
(pairwise_kernels, callable_rbf_kernel, {'gamma': .1}),
153+
]
154+
for func, metric, kwds in metrics:
155+
yield check_pairwise_parallel, func, metric, kwds
156+
157+
158+
def test_pairwise_callable_nonstrict_metric():
159+
"""pairwise_distances should allow callable metric where metric(x, x) != 0
160+
161+
Knowing that the callable is a strict metric would allow the diagonal to
162+
be left uncalculated and set to 0.
163+
"""
164+
assert_equal(pairwise_distances([[1]], metric=lambda x, y: 5)[0, 0], 5)
165+
166+
167+
def callable_rbf_kernel(x, y, **kwds):
168+
""" Callable version of pairwise.rbf_kernel. """
169+
K = rbf_kernel(np.atleast_2d(x), np.atleast_2d(y), **kwds)
170+
return K
171+
172+
136173
def test_pairwise_kernels():
137174
""" Test the pairwise_kernels helper function. """
138175

139-
def callable_rbf_kernel(x, y, **kwds):
140-
""" Callable version of pairwise.rbf_kernel. """
141-
K = rbf_kernel(np.atleast_2d(x), np.atleast_2d(y), **kwds)
142-
return K
143-
144176
rng = np.random.RandomState(0)
145177
X = rng.random_sample((5, 4))
146178
Y = rng.random_sample((2, 4))

sklearn/utils/fixes.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
import inspect
1414
import warnings
15+
import sys
16+
import functools
1517

1618
import numpy as np
1719
import scipy.sparse as sp
@@ -314,3 +316,22 @@ def in1d(ar1, ar2, assume_unique=False, invert=False):
314316
from ._scipy_sparse_lsqr_backport import lsqr as sparse_lsqr
315317
else:
316318
from scipy.sparse.linalg import lsqr as sparse_lsqr
319+
320+
321+
if sys.version_info < (2, 7, 0):
322+
# partial cannot be pickled in Python 2.6
323+
# http://bugs.python.org/issue1398
324+
class partial(object):
325+
def __init__(self, func, *args, **keywords):
326+
functools.update_wrapper(self, func)
327+
self.func = func
328+
self.args = args
329+
self.keywords = keywords
330+
331+
def __call__(self, *args, **keywords):
332+
args = self.args + args
333+
kwargs = self.keywords.copy()
334+
kwargs.update(keywords)
335+
return self.func(*args, **kwargs)
336+
else:
337+
from functools import partial

0 commit comments

Comments
 (0)
0