8000 [MRG] Save sample_weight_arr instead of sample_weight in KernelDensit… · InterferencePattern/scikit-learn@936a9fa · GitHub
[go: up one dir, main page]

Skip to content
8000

Commit 936a9fa

Browse files
aditya1702jnothman
authored andcommitted
[MRG] Save sample_weight_arr instead of sample_weight in KernelDensity (scikit-learn#13772)
1 parent 319f27a commit 936a9fa

File tree

3 files changed

+54
-35
lines changed

3 files changed

+54
-35
lines changed

doc/whats_new/v0.20.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ Changelog
2222
making ``shuffle=True`` ineffective.
2323
:issue:`13124` by :user:`Hanmin Qin <qinhanmin2014>`.
2424

25+
:mod:`sklearn.neighbors`
26+
......................
27+
28+
- |Fix| Fixed a bug in :class:`neighbors.KernelDensity` which could not be
29+
restored from a pickle if ``sample_weight`` had been used.
30+
:issue:`13772` by :user:`Aditya Vyas <aditya1702>`.
31+
2532
.. _changes_0_20_3:
2633

2734
Version 0.20.3

sklearn/neighbors/binary_tree.pxi

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,10 +1063,17 @@ cdef class BinaryTree:
10631063

10641064
def __init__(self, data,
10651065
leaf_size=40, metric='minkowski', sample_weight=None, **kwargs):
1066-
self.data_arr = np.asarray(data, dtype=DTYPE, order='C')
1067-
self.data = get_memview_DTYPE_2D(self.data_arr)
1066+
# validate data
1067+
if data.size == 0:
1068+
raise ValueError("X is an empty array")
1069+
1070+
if leaf_size < 1:
1071+
raise ValueError("leaf_size must be greater than or equal to 1")
10681072

1073+
n_samples = data.shape[0]
1074+
n_features = data.shape[1]
10691075

1076+
self.data_arr = np.asarray(data, dtype=DTYPE, order='C')
10701077
self.leaf_size = leaf_size
10711078
self.dist_metric = DistanceMetric.get_metric(metric, **kwargs)
10721079
self.euclidean = (self.dist_metric.__class__.__name__
@@ -1078,26 +1085,6 @@ cdef class BinaryTree:
10781085
'{BinaryTree}'.format(metric=metric,
10791086
**DOC_DICT))
10801087

1081-
# validate data
1082-
if self.data.size == 0:
1083-
raise ValueError("X is an empty array")
1084-
1085-
if leaf_size < 1:
1086-
raise ValueError("leaf_size must be greater than or equal to 1")
1087-
1088-
n_samples = self.data.shape[0]
1089-
n_features = self.data.shape[1]
1090-
1091-
1092-
if sample_weight is not None:
1093-
self.sample_weight_arr = np.asarray(sample_weight, dtype=DTYPE, order='C')
1094-
self.sample_weight = get_memview_DTYPE_1D(self.sample_weight_arr)
1095-
self.sum_weight = np.sum(self.sample_weight)
1096-
else:
1097-
self.sample_weight = None
1098-
self.sum_weight = <DTYPE_t> n_samples
1099-
1100-
11011088
# determine number of levels in the tree, and from this
11021089
# the number of nodes in the tree. This results in leaf nodes
11031090
# with numbers of points between leaf_size and 2 * leaf_size
@@ -1106,15 +1093,34 @@ cdef class BinaryTree:
11061093

11071094
# allocate arrays for storage
11081095
self.idx_array_arr = np.arange(n_samples, dtype=ITYPE)
1109-
self.idx_array = get_memview_ITYPE_1D(self.idx_array_arr)
1110-
11111096
self.node_data_arr = np.zeros(self.n_nodes, dtype=NodeData)
1112-
self.node_data = get_memview_NodeData_1D(self.node_data_arr)
1097+
1098+
self._update_sample_weight(n_samples, sample_weight)
1099+
self._update_memviews()
11131100

11141101
# Allocate tree-specific data
11151102
allocate_data(self, self.n_nodes, n_features)
11161103
self._recursive_build(0, 0, n_samples)
11171104

1105+
def _update_sample_weight(self, n_samples, sample_weight):
1106+
if sample_weight is not None:
1107+
self.sample_weight_arr = np.asarray(
1108+
sample_weight, dtype=DTYPE, order='C')
1109+
self.sample_weight = get_memview_DTYPE_1D(
1110+
self.sample_weight_arr)
1111+
self.sum_weight = np.sum(self.sample_weight)
1112+
else:
1113+
self.sample_weight = None
1114+
self.sample_weight_arr = np.empty(1, dtype=DTYPE, order='C')
1115+
self.sum_weight = <DTYPE_t> n_samples
1116+
1117+
def _update_memviews(self):
1118+
self.data = get_memview_DTYPE_2D(self.data_arr)
1119+
self.idx_array = get_memview_ITYPE_1D(self.idx_array_arr)
1120+
self.node_data = get_memview_NodeData_1D(self.node_data_arr)
1121+
self.node_bounds = get_memview_DTYPE_3D(self.node_bounds_arr)
1122+
1123+
11181124
def __reduce__(self):
11191125
"""
11201126
reduce method used for pickling
@@ -1125,6 +1131,13 @@ cdef class BinaryTree:
11251131
"""
11261132
get state for pickling
11271133
"""
1134+
if self.sample_weight is not None:
1135+
# pass the numpy array
1136+
sample_weight_arr = self.sample_weight_arr
1137+
else:
1138+
# pass None to avoid confusion with the empty place holder
1139+
# of size 1 from __cinit__
1140+
sample_weight_arr = None
11281141
return (self.data_arr,
11291142
self.idx_array_arr,
11301143
self.node_data_arr,
@@ -1137,7 +1150,7 @@ cdef class BinaryTree:
11371150
int(self.n_splits),
11381151
int(self.n_calls),
11391152
self.dist_metric,
1140-
self.sample_weight)
1153+
sample_weight_arr)
11411154

11421155
def __setstate__(self, state):
11431156
"""
@@ -1147,12 +1160,6 @@ cdef class BinaryTree:
11471160
self.idx_array_arr = state[1]
11481161
self.node_data_arr = state[2]
11491162
self.node_bounds_arr = state[3]
1150-
1151-
self.data = get_memview_DTYPE_2D(self.data_arr)
1152-
self.idx_array = get_memview_ITYPE_1D(self.idx_array_arr)
1153-
self.node_data = get_memview_NodeData_1D(self.node_data_arr)
1154-
self.node_bounds = get_memview_DTYPE_3D(self.node_bounds_arr)
1155-
11561163
self.leaf_size = state[4]
11571164
self.n_levels = state[5]
11581165
self.n_nodes = state[6]
@@ -1161,9 +1168,13 @@ cdef class BinaryTree:
11611168
self.n_splits = state[9]
11621169
self.n_calls = state[10]
11631170
self.dist_metric = state[11]
1171+
sample_weight_arr = state[12]
1172+
11641173
self.euclidean = (self.dist_metric.__class__.__name__
11651174
== 'EuclideanDistance')
1166-
self.sample_weight = state[12]
1175+
n_samples = self.data_arr.shape[0]
1176+
self._update_sample_weight(n_samples, sample_weight_arr)
1177+
self._update_memviews()
11671178

11681179
def get_tree_stats(self):
11691180
return (self.n_trims, self.n_leaves, self.n_splits)

sklearn/neighbors/tests/test_kde.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,15 @@ def test_kde_sample_weights():
205205
assert_allclose(scores_scaled_weight, scores_weight)
206206

207207

208-
def test_pickling(tmpdir):
208+
@pytest.mark.parametrize('sample_weight', [None, [0.1, 0.2, 0.3]])
209+
def test_pickling(tmpdir, sample_weight):
209210
# Make sure that predictions are the same before and after pickling. Used
210211
# to be a bug because sample_weights wasn't pickled and the resulting tree
211212
# would miss some info.
212213

213214
kde = KernelDensity()
214215
data = np.reshape([1., 2., 3.], (-1, 1))
215-
kde.fit(data)
216+
kde.fit(data, sample_weight=sample_weight)
216217

217218
X = np.reshape([1.1, 2.1], (-1, 1))
218219
scores = kde.score_samples(X)

0 commit comments

Comments
 (0)
0