|
| 1 | +import pickle |
1 | 2 | import numpy as np
|
2 | 3 | from numpy.testing import assert_array_almost_equal
|
3 | 4 | from sklearn.neighbors.ball_tree import (BallTree, NeighborsHeap,
|
|
29 | 30 | 'sokalsneath']
|
30 | 31 |
|
31 | 32 |
|
| 33 | +def dist_func(x1, x2, p): |
| 34 | + return np.sum((x1 - x2) ** p) ** (1. / p) |
| 35 | + |
| 36 | + |
32 | 37 | def brute_force_neighbors(X, Y, k, metric, **kwargs):
|
33 | 38 | D = DistanceMetric.get_metric(metric, **kwargs).pairwise(Y, X)
|
34 | 39 | ind = np.argsort(D, axis=1)[:, :k]
|
@@ -216,19 +221,32 @@ def check_two_point(r, dualtree):
|
216 | 221 |
|
217 | 222 |
|
218 | 223 | def test_ball_tree_pickle():
|
219 |
| - import pickle |
220 | 224 | np.random.seed(0)
|
221 | 225 | X = np.random.random((10, 3))
|
| 226 | + |
222 | 227 | bt1 = BallTree(X, leaf_size=1)
|
| 228 | + # Test if BallTree with callable metric is picklable |
| 229 | + bt1_pyfunc = BallTree(X, metric=dist_func, leaf_size=1, p=2) |
| 230 | + |
223 | 231 | ind1, dist1 = bt1.query(X)
|
| 232 | + ind1_pyfunc, dist1_pyfunc = bt1_pyfunc.query(X) |
224 | 233 |
|
225 | 234 | def check_pickle_protocol(protocol):
|
226 | 235 | s = pickle.dumps(bt1, protocol=protocol)
|
227 | 236 | bt2 = pickle.loads(s)
|
| 237 | + |
| 238 | + s_pyfunc = pickle.dumps(bt1_pyfunc, protocol=protocol) |
| 239 | + bt2_pyfunc = pickle.loads(s_pyfunc) |
| 240 | + |
228 | 241 | ind2, dist2 = bt2.query(X)
|
| 242 | + ind2_pyfunc, dist2_pyfunc = bt2_pyfunc.query(X) |
| 243 | + |
229 | 244 | assert_array_almost_equal(ind1, ind2)
|
230 | 245 | assert_array_almost_equal(dist1, dist2)
|
231 | 246 |
|
| 247 | + assert_array_almost_equal(ind1_pyfunc, ind2_pyfunc) |
| 248 | + assert_array_almost_equal(dist1_pyfunc, dist2_pyfunc) |
| 249 | + |
232 | 250 | for protocol in (0, 1, 2):
|
233 | 251 | yield check_pickle_protocol, protocol
|
234 | 252 |
|
|
0 commit comments