8000 TST Add test to make sure BallTree remains picklable even with callab… · scikit-learn/scikit-learn@b42a148 · GitHub
[go: up one dir, main page]

Skip to content

Commit b42a148

Browse files
committed
TST Add test to make sure BallTree remains picklable even with callable metric
1 parent 593d489 commit b42a148

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

sklearn/neighbors/tests/test_ball_tree.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pickle
12
import numpy as np
23
from numpy.testing import assert_array_almost_equal
34
from sklearn.neighbors.ball_tree import (BallTree, NeighborsHeap,
@@ -29,6 +30,10 @@
2930
'sokalsneath']
3031

3132

33+
def dist_func(x1, x2, p):
34+
return np.sum((x1 - x2) ** p) ** (1. / p)
35+
36+
3237
def brute_force_neighbors(X, Y, k, metric, **kwargs):
3338
D = DistanceMetric.get_metric(metric, **kwargs).pairwise(Y, X)
3439
ind = np.argsort(D, axis=1)[:, :k]
@@ -216,19 +221,32 @@ def check_two_point(r, dualtree):
216221

217222

218223
def test_ball_tree_pickle():
219-
import pickle
220224
np.random.seed(0)
221225
X = np.random.random((10, 3))
226+
222227
bt1 = BallTree(X, leaf_size=1)
228+
# Test if BallTree with callable metric is picklable
229+
bt1_pyfunc = BallTree(X, metric=dist_func, leaf_size=1, p=2)
230+
223231
ind1, dist1 = bt1.query(X)
232+
ind1_pyfunc, dist1_pyfunc = bt1_pyfunc.query(X)
224233

225234
def check_pickle_protocol(protocol):
226235
s = pickle.dumps(bt1, protocol=protocol)
227236
bt2 = pickle.loads(s)
237+
238+
s_pyfunc = pickle.dumps(bt1_pyfunc, protocol=protocol)
239+
bt2_pyfunc = pickle.loads(s_pyfunc)
240+
228241
ind2, dist2 = bt2.query(X)
242+
ind2_pyfunc, dist2_pyfunc = bt2_pyfunc.query(X)
243+
229244
assert_array_almost_equal(ind1, ind2)
230245
assert_array_almost_equal(dist1, dist2)
231246

247+
assert_array_almost_equal(ind1_pyfunc, ind2_pyfunc)
248+
assert_array_almost_equal(dist1_pyfunc, dist2_pyfunc)
249+
232250
for protocol in (0, 1, 2):
233251
yield check_pickle_protocol, protocol
234252

0 commit comments

Comments
 (0)
0