[MRG] Save sample_weight_arr instead of sample_weight in KernelDensity (#13772) · scikit-learn/scikit-learn@36e4fda · GitHub
[go: up one dir, main page]

Skip to content

Commit 36e4fda

Browse files
aditya1702 authored and
ogrisel committed
[MRG] Save sample_weight_arr instead of sample_weight in KernelDensity (#13772)
1 parent a300e7c commit 36e4fda

File tree

3 files changed

+55
-36
lines changed

3 files changed

+55
-36
lines changed

doc/whats_new/v0.21.rst

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
Version 0.21.1
88
==============
99

10-
**May 2019**
10+
**17 May 2019**
1111

1212

1313
This is a bug-fix release with some minor documentation improvements and
@@ -24,6 +24,13 @@ Changelog
2424
``Y == None``.
2525
:issue:`13864` by :user:`Paresh Mathur <rick2047>`.
2626

27+
:mod:`sklearn.neighbors`
28+
......................
29+
30+
- |Fix| Fixed a bug in :class:`neighbors.KernelDensity` which could not be
31+
restored from a pickle if ``sample_weight`` had been used.
32+
:issue:`13772` by :user:`Aditya Vyas <aditya1702>`.
33+
2734

2835
.. _changes_0_21:
2936

sklearn/neighbors/binary_tree.pxi

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,10 +1064,17 @@ cdef class BinaryTree:
10641064

10651065
def __init__(self, data,
10661066
leaf_size=40, metric='minkowski', sample_weight=None, **kwargs):
1067-
self.data_arr = np.asarray(data, dtype=DTYPE, order='C')
1068-
self.data = get_memview_DTYPE_2D(self.data_arr)
1067+
# validate data
1068+
if data.size == 0:
1069+
raise ValueError("X is an empty array")
1070+
1071+
if leaf_size < 1:
1072+
raise ValueError("leaf_size must be greater than or equal to 1")
10691073

1074+
n_samples = data.shape[0]
1075+
n_features = data.shape[1]
10701076

1077+
self.data_arr = np.asarray(data, dtype=DTYPE, order='C')
10711078
self.leaf_size = leaf_size
10721079
self.dist_metric = DistanceMetric.get_metric(metric, **kwargs)
10731080
self.euclidean = (self.dist_metric.__class__.__name__
@@ -1079,26 +1086,6 @@ cdef class BinaryTree:
10791086
'{BinaryTree}'.format(metric=metric,
10801087
**DOC_DICT))
10811088

1082-
# validate data
1083-
if self.data.size == 0:
1084-
raise ValueError("X is an empty array")
1085-
1086-
if leaf_size < 1:
1087-
raise ValueError("leaf_size must be greater than or equal to 1")
1088-
1089-
n_samples = self.data.shape[0]
1090-
n_features = self.data.shape[1]
1091-
1092-
1093-
if sample_weight is not None:
1094-
self.sample_weight_arr = np.asarray(sample_weight, dtype=DTYPE, order='C')
1095-
self.sample_weight = get_memview_DTYPE_1D(self.sample_weight_arr)
1096-
self.sum_weight = np.sum(self.sample_weight)
1097-
else:
1098-
self.sample_weight = None
1099-
self.sum_weight = <DTYPE_t> n_samples
1100-
1101-
11021089
# determine number of levels in the tree, and from this
11031090
# the number of nodes in the tree. This results in leaf nodes
11041091
# with numbers of points between leaf_size and 2 * leaf_size
@@ -1107,15 +1094,34 @@ cdef class BinaryTree:
11071094

11081095
# allocate arrays for storage
11091096
self.idx_array_arr = np.arange(n_samples, dtype=ITYPE)
1110-
self.idx_array = get_memview_ITYPE_1D(self.idx_array_arr)
1111-
11121097
self.node_data_arr = np.zeros(self.n_nodes, dtype=NodeData)
1113-
self.node_data = get_memview_NodeData_1D(self.node_data_arr)
1098+
1099+
self._update_sample_weight(n_samples, sample_weight)
1100+
self._update_memviews()
11141101

11151102
# Allocate tree-specific data
11161103
allocate_data(self, self.n_nodes, n_features)
11171104
self._recursive_build(0, 0, n_samples)
11181105

1106+
def _update_sample_weight(self, n_samples, sample_weight):
1107+
if sample_weight is not None:
1108+
self.sample_weight_arr = np.asarray(
1109+
sample_weight, dtype=DTYPE, order='C')
1110+
self.sample_weight = get_memview_DTYPE_1D(
1111+
self.sample_weight_arr)
1112+
self.sum_weight = np.sum(self.sample_weight)
1113+
else:
1114+
self.sample_weight = None
1115+
self.sample_weight_arr = np.empty(1, dtype=DTYPE, order='C')
1116+
self.sum_weight = <DTYPE_t> n_samples
1117+
1118+
def _update_memviews(self):
1119+
self.data = get_memview_DTYPE_2D(self.data_arr)
1120+
self.idx_array = get_memview_ITYPE_1D(self.idx_array_arr)
1121+
self.node_data = get_memview_NodeData_1D(self.node_data_arr)
1122+
self.node_bounds = get_memview_DTYPE_3D(self.node_bounds_arr)
1123+
1124+
11191125
def __reduce__(self):
11201126
"""
11211127
reduce method used for pickling
@@ -1126,6 +1132,13 @@ cdef class BinaryTree:
11261132
"""
11271133
get state for pickling
11281134
"""
1135+
if self.sample_weight is not None:
1136+
# pass the numpy array
1137+
sample_weight_arr = self.sample_weight_arr
1138+
else:
1139+
# pass None to avoid confusion with the empty place holder
1140+
# of size 1 from __cinit__
1141+
sample_weight_arr = None
11291142
return (self.data_arr,
11301143
self.idx_array_arr,
11311144
self.node_data_arr,
@@ -1138,7 +1151,7 @@ cdef class BinaryTree:
11381151
int(self.n_splits),
11391152
int(self.n_calls),
11401153
self.dist_metric,
1141-
self.sample_weight)
1154+
sample_weight_arr)
11421155

11431156
def __setstate__(self, state):
11441157
"""
@@ -1148,12 +1161,6 @@ cdef class BinaryTree:
11481161
self.idx_array_arr = state[1]
11491162
self.node_data_arr = state[2]
11501163
self.node_bounds_arr = state[3]
1151-
1152-
self.data = get_memview_DTYPE_2D(self.data_arr)
1153-
self.idx_array = get_memview_ITYPE_1D(self.idx_array_arr)
1154-
self.node_data = get_memview_NodeData_1D(self.node_data_arr)
1155-
self.node_bounds = get_memview_DTYPE_3D(self.node_bounds_arr)
1156-
11571164
self.leaf_size = state[4]
11581165
self.n_levels = state[5]
11591166
self.n_nodes = state[6]
@@ -1162,9 +1169,13 @@ cdef class BinaryTree:
11621169
self.n_splits = state[9]
11631170
self.n_calls = state[10]
11641171
self.dist_metric = state[11]
1172+
sample_weight_arr = state[12]
1173+
11651174
self.euclidean = (self.dist_metric.__class__.__name__
11661175
== 'EuclideanDistance')
1167-
self.sample_weight = state[12]
1176+
n_samples = self.data_arr.shape[0]
1177+
self._update_sample_weight(n_samples, sample_weight_arr)
1178+
self._update_memviews()
11681179

11691180
def get_tree_stats(self):
11701181
return (self.n_trims, self.n_leaves, self.n_splits)

sklearn/neighbors/tests/test_kde.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,15 @@ def test_kde_sample_weights():
205205
assert_allclose(scores_scaled_weight, scores_weight)
206206

207207

208-
def test_pickling(tmpdir):
208+
@pytest.mark.parametrize('sample_weight', [None, [0.1, 0.2, 0.3]])
209+
def test_pickling(tmpdir, sample_weight):
209210
# Make sure that predictions are the same before and after pickling. Used
210211
# to be a bug because sample_weights wasn't pickled and the resulting tree
211212
# would miss some info.
212213

213214
kde = KernelDensity()
214215
data = np.reshape([1., 2., 3.], (-1, 1))
215-
kde.fit(data)
216+
kde.fit(data, sample_weight=sample_weight)
216217

217218
X = np.reshape([1.1, 2.1], (-1, 1))
218219
scores = kde.score_samples(X)

0 commit comments

Comments
 (0)
0