8000 Fix #6420 Cloning decision tree estimators breaks criterion objects · scikit-learn/scikit-learn@da5a85b · GitHub
[go: up one dir, main page]

Skip to content

Commit da5a85b

Browse files
committed
Fix #6420 Cloning decision tree estimators breaks criterion objects
1 parent 5e4f524 commit da5a85b

File tree

6 files changed

+46
-8
lines changed

6 files changed

+46
-8
lines changed

doc/whats_new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ Bug fixes
5656
<https://github.com/scikit-learn/scikit-learn/pull/6178>`_) by `Bertrand
5757
Thirion`_
5858

59+
- Tree splitting criterion classes cloning/pickling are now memory safe
60+
(`#7680 <https://github.com/scikit-learn/scikit-learn/pull/7680>`_).
61+
By `Ibraim Ganiev`_.
62+
5963
.. _changes_0_18_1:
6064

6165
Version 0.18.1

sklearn/tree/_criterion.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ cdef class Criterion:
3434
cdef SIZE_t end
3535

3636
cdef SIZE_t n_outputs # Number of outputs
37+
cdef SIZE_t n_samples # Number of samples
3738
cdef SIZE_t n_node_samples # Number of samples in the node (end-start)
3839
cdef double weighted_n_samples # Weighted number of samples (in total)
3940
cdef double weighted_n_node_samples # Weighted number of samples in the node

sklearn/tree/_criterion.pyx

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ cdef class ClassificationCriterion(Criterion):
235235
self.end = 0
236236

237237
self.n_outputs = n_outputs
238+
self.n_samples = 0
238239
self.n_node_samples = 0
239240
self.weighted_n_node_samples = 0.0
240241
self.weighted_n_left = 0.0
@@ -273,11 +274,10 @@ cdef class ClassificationCriterion(Criterion):
273274

274275
def __dealloc__(self):
275276
"""Destructor."""
276-
277277
free(self.n_classes)
278278

279279
def __reduce__(self):
280-
return (ClassificationCriterion,
280+
return (type(self),
281281
(self.n_outputs,
282282
sizet_ptr_to_ndarray(self.n_classes, self.n_outputs)),
283283
self.__getstate__())
@@ -710,6 +710,7 @@ cdef class RegressionCriterion(Criterion):
710710
self.end = 0
711711

712712
self.n_outputs = n_outputs
713+
self.n_samples = n_samples
713714
self.n_node_samples = 0
714715
self.weighted_n_node_samples = 0.0
715716
self.weighted_n_left = 0.0
@@ -734,7 +735,7 @@ cdef class RegressionCriterion(Criterion):
734735
raise MemoryError()
735736

736737
def __reduce__(self):
737-
return (RegressionCriterion, (self.n_outputs,), self.__getstate__())
738+
return (type(self), (self.n_outputs, self.n_samples), self.__getstate__())
738739

739740
cdef void init(self, DOUBLE_t* y, SIZE_t y_stride, DOUBLE_t* sample_weight,
740741
double weighted_n_samples, SIZE_t* samples, SIZE_t start,
@@ -881,6 +882,7 @@ cdef class MSE(RegressionCriterion):
881882
882883
MSE = var_left + var_right
883884
"""
885+
884886
cdef double node_impurity(self) nogil:
885887
"""Evaluate the impurity of the current node, i.e. the impurity of
886888
samples[start:end]."""
@@ -1004,6 +1006,7 @@ cdef class MAE(RegressionCriterion):
10041006
self.end = 0
10051007

10061008
self.n_outputs = n_outputs
1009+
self.n_samples = n_samples
10071010
self.n_node_samples = 0
10081011
self.weighted_n_node_samples = 0.0
10091012
self.weighted_n_left = 0.0

sklearn/tree/_tree.pyx

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -547,8 +547,7 @@ cdef class Tree:
547547
# (i.e. through `_resize` or `__setstate__`)
548548
property n_classes:
549549
def __get__(self):
550-
# it's small; copy for memory safety
551-
return sizet_ptr_to_ndarray(self.n_classes, self.n_outputs).copy()
550+
return sizet_ptr_to_ndarray(self.n_classes, self.n_outputs)
552551

553552
property children_left:
554553
def __get__(self):

sklearn/tree/_utils.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ cdef inline UINT32_t our_rand_r(UINT32_t* seed) nogil:
6262

6363

6464
cdef inline np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size):
65-
"""Encapsulate data into a 1D numpy array of intp's."""
65+
"""Return copied data as 1D numpy array of intp's."""
6666
cdef np.npy_intp shape[1]
6767
shape[0] = <np.npy_intp> size
68-
return np.PyArray_SimpleNewFromData(1, shape, np.NPY_INTP, data)
68+
return np.PyArray_SimpleNewFromData(1, shape, np.NPY_INTP, data).copy()
6969

7070

7171
cdef inline SIZE_t rand_int(SIZE_t low, SIZE_t high,

sklearn/tree/tests/test_tree.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Testing for the tree module (sklearn.tree).
33
"""
4+
import copy
45
import pickle
56
from functools import partial
67
from itertools import product
@@ -42,12 +43,14 @@
4243

4344
from sklearn import tree
4445
from sklearn.tree._tree import TREE_LEAF
46+
from sklearn.tree.tree import CRITERIA_CLF
47+
from sklearn.tree.tree import CRITERIA_REG
4548
from sklearn import datasets
4649

4750
from sklearn.utils import compute_sample_weight
4851

4952
CLF_CRITERIONS = ("gini", "entropy")
50-
REG_CRITERIONS = ("mse", "mae")
53+
REG_CRITERIONS = ("mse", "mae", "friedman_mse")
5154

5255
CLF_TREES = {
5356
"DecisionTreeClassifier": DecisionTreeClassifier,
@@ -1597,6 +1600,7 @@ def test_no_sparse_y_support():
15971600
for name in ALL_TREES:
15981601
yield (check_no_sparse_y_support, name)
15991602

1603+
16001604
def test_mae():
16011605
# check MAE criterion produces correct results
16021606
# on small toy dataset
@@ -1609,3 +1613,30 @@ def test_mae():
16091613
dt_mae.fit([[3],[5],[3],[8],[5]],[6,7,3,4,3], [0.6,0.3,0.1,1.0,0.3])
16101614
assert_array_equal(dt_mae.tree_.impurity, [7.0/2.3, 3.0/0.7, 4.0/1.6])
16111615
assert_array_equal(dt_mae.tree_.value.flat, [4.0, 6.0, 4.0])
1616+
1617+
1618+
def test_criterion_copy():
1619+
# Let's check whether copy of our criterion has the same type
1620+
# and properties as original
1621+
n_outputs = 3
1622+
n_classes = np.arange(3, dtype=np.intp)
1623+
n_samples = 100
1624+
1625+
def _pickle_copy(obj):
1626+
return pickle.loads(pickle.dumps(obj))
1627+
for copy_func in [copy.copy, copy.deepcopy, _pickle_copy]:
1628+
for _, typename in CRITERIA_CLF.items():
1629+
criteria = typename(n_outputs, n_classes)
1630+
result = copy_func(criteria).__reduce__()
1631+
typename_, (n_outputs_, n_classes_), _ = result
1632+
assert_equal(typename, typename_)
1633+
assert_equal(n_outputs, n_outputs_)
1634+
assert_array_equal(n_classes, n_classes_)
1635+
1636+
for _, typename in CRITERIA_REG.items():
1637+
criteria = typename(n_outputs, n_samples)
1638+
result = copy_func(criteria).__reduce__()
1639+
typename_, (n_outputs_, n_samples_), _ = result
1640+
assert_equal(typename, typename_)
1641+
assert_equal(n_outputs, n_outputs_)
1642+
assert_equal(n_samples, n_samples_)

0 commit comments

Comments
 (0)
0