FIX Pickled sample_weights in BinaryTree (#11774) · scikit-learn/scikit-learn@d990f72 · GitHub
[go: up one dir, main page]

Skip to content

Commit d990f72

Browse files
NicolasHug and jnothman
authored and committed
FIX Pickled sample_weights in BinaryTree (#11774)
1 parent 790af8d commit d990f72

File tree

5 files changed

+32
-2
lines changed

5 files changed

+32
-2
lines changed

doc/whats_new/v0.20.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,10 @@ Support for Python 3.3 has been officially dropped.
761761
faster construction and querying times.
762762
:issue:`11556` by :user:`Jake VanderPlas <jakevdp>`
763763

764+
- |Fix| Fixed a bug in `neighbors.KDTree` and `neighbors.BallTree` where
765+
pickled tree objects would change their type to the super class `BinaryTree`.
766+
:issue:`11774` by :user:`Nicolas Hug <NicolasHug>`.
767+
764768

765769
:mod:`sklearn.neural_network`
766770
.............................

sklearn/neighbors/binary_tree.pxi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,7 +1119,7 @@ cdef class BinaryTree:
11191119
"""
11201120
reduce method used for pickling
11211121
"""
1122-
return (newObj, (BinaryTree,), self.__getstate__())
1122+
return (newObj, (type(self),), self.__getstate__())
11231123

11241124
def __getstate__(self):
11251125
"""
@@ -1136,7 +1136,8 @@ cdef class BinaryTree:
11361136
int(self.n_leaves),
11371137
int(self.n_splits),
11381138
int(self.n_calls),
1139-
self.dist_metric)
1139+
self.dist_metric,
1140+
self.sample_weight)
11401141

11411142
def __setstate__(self, state):
11421143
"""
@@ -1162,6 +1163,7 @@ cdef class BinaryTree:
11621163
self.dist_metric = state[11]
11631164
self.euclidean = (self.dist_metric.__class__.__name__
11641165
== 'EuclideanDistance')
1166+
self.sample_weight = state[12]
11651167

11661168
def get_tree_stats(self):
11671169
return (self.n_trims, self.n_leaves, self.n_splits)

sklearn/neighbors/tests/test_ball_tree.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ def check_pickle_protocol(protocol):
228228
assert_array_almost_equal(ind1_pyfunc, ind2_pyfunc)
229229
assert_array_almost_equal(dist1_pyfunc, dist2_pyfunc)
230230

231+
assert isinstance(bt2, BallTree)
232+
231233
for protocol in (0, 1, 2):
232234
check_pickle_protocol(protocol)
233235

sklearn/neighbors/tests/test_kd_tree.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def check_pickle_protocol(protocol):
187187
ind2, dist2 = kdt2.query(X)
188188
assert_array_almost_equal(ind1, ind2)
189189
assert_array_almost_equal(dist1, dist2)
190+
assert isinstance(kdt2, KDTree)
190191

191192
check_pickle_protocol(protocol)
192193

sklearn/neighbors/tests/test_kde.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sklearn.datasets import make_blobs
1111
from sklearn.model_selection import GridSearchCV
1212
from sklearn.preprocessing import StandardScaler
13+
from sklearn.externals import joblib
1314

1415

1516
def compute_kernel_slow(Y, X, kernel, h):
@@ -202,3 +203,23 @@ def test_kde_sample_weights():
202203
kde.fit(X, sample_weight=(scale_factor * weights))
203204
scores_scaled_weight = kde.score_samples(test_points)
204205
assert_allclose(scores_scaled_weight, scores_weight)
206+
207+
208+
def test_pickling(tmpdir):
209+
# Make sure that predictions are the same before and after pickling. Used
210+
# to be a bug because sample_weight wasn't pickled and the resulting tree
211+
# would miss some info.
212+
213+
kde = KernelDensity()
214+
data = np.reshape([1., 2., 3.], (-1, 1))
215+
kde.fit(data)
216+
217+
X = np.reshape([1.1, 2.1], (-1, 1))
218+
scores = kde.score_samples(X)
219+
220+
file_path = str(tmpdir.join('dump.pkl'))
221+
joblib.dump(kde, file_path)
222+
kde = joblib.load(file_path)
223+
scores_pickled = kde.score_samples(X)
224+
225+
assert_allclose(scores, scores_pickled)

0 commit comments

Comments
 (0)
0