[MRG+1] feature: add beta-threshold early stopping for decision tree … · scikit-learn/scikit-learn@376aa50 · GitHub
[go: up one dir, main page]

Skip to content

Commit 376aa50

Browse files
nelson-liu authored and glouppe committed
[MRG+1] feature: add beta-threshold early stopping for decision tree growth (#6954)
* feature: add beta-threshold early stopping for decision tree growth
* check if value of beta is greater than or equal to 0
* test if default value of beta is 0 and edit input validation error message
* feature: separately validate beta for reg. and clf., and add tests for it
* feature: add beta to forest-based ensemble methods
* feature: add separate condition to determine that beta is float
* feature: add beta to gradient boosting estimators
* rename parameter to min_impurity_split, edit input validation and associated tests
* chore: fix spacing in forest and force recompilation of grad boosting extension
* remove trivial comment in grad boost and add whats new
* edit wording in test comment / rebuild
* rename constant with the same name as our parameter
* edit line length for what's new
* remove constant and set min_impurity_split to 1e-7 by default
* fix docstrings for new default
* fix defaults in gradientboosting and forest classes
1 parent d829091 commit 376aa50

File tree

7 files changed

+146
-18
lines changed

7 files changed

+146
-18
lines changed

doc/whats_new.rst

+5
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ New features
125125
<https://github.com/scikit-learn/scikit-learn/pull/6667>`_) by `Nelson
126126
Liu`_.
127127

128+
- Added weighted impurity-based early stopping criterion for decision tree
129+
growth. (`#6954
130+
<https://github.com/scikit-learn/scikit-learn/pull/6954>`_) by `Nelson
131+
Liu`_
132+
128133
Enhancements
129134
............
130135

sklearn/ensemble/forest.py

+35-5
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,10 @@ class RandomForestClassifier(ForestClassifier):
805805
If None then unlimited number of leaf nodes.
806806
If not None then ``max_depth`` will be ignored.
807807
808+
min_impurity_split : float, optional (default=1e-7)
809+
Threshold for early stopping in tree growth. A node will split
810+
if its impurity is above the threshold, otherwise it is a leaf.
811+
808812
bootstrap : boolean, optional (default=True)
809813
Whether bootstrap samples are used when building trees.
810814
@@ -899,6 +903,7 @@ def __init__(self,
899903
min_weight_fraction_leaf=0.,
900904
max_features="auto",
901905
max_leaf_nodes=None,
906+
min_impurity_split=1e-7,
902907
bootstrap=True,
903908
oob_score=False,
904909
n_jobs=1,
@@ -911,7 +916,7 @@ def __init__(self,
911916
n_estimators=n_estimators,
912917
estimator_params=("criterion", "max_depth", "min_samples_split",
913918
"min_samples_leaf", "min_weight_fraction_leaf",
914-
"max_features", "max_leaf_nodes",
919+
"max_features", "max_leaf_nodes", "min_impurity_split",
915920
"random_state"),
916921
bootstrap=bootstrap,
917922
oob_score=oob_score,
@@ -928,6 +933,7 @@ def __init__(self,
928933
self.min_weight_fraction_leaf = min_weight_fraction_leaf
929934
self.max_features = max_features
930935
self.max_leaf_nodes = max_leaf_nodes
936+
self.min_impurity_split = min_impurity_split
931937

932938

933939
class RandomForestRegressor(ForestRegressor):
@@ -1001,6 +1007,10 @@ class RandomForestRegressor(ForestRegressor):
10011007
If None then unlimited number of leaf nodes.
10021008
If not None then ``max_depth`` will be ignored.
10031009
1010+
min_impurity_split : float, optional (default=1e-7)
1011+
Threshold for early stopping in tree growth. A node will split
1012+
if its impurity is above the threshold, otherwise it is a leaf.
1013+
10041014
bootstrap : boolean, optional (default=True)
10051015
Whether bootstrap samples are used when building trees.
10061016
@@ -1064,6 +1074,7 @@ def __init__(self,
10641074
min_weight_fraction_leaf=0.,
10651075
max_features="auto",
10661076
max_leaf_nodes=None,
1077+
min_impurity_split=1e-7,
10671078
bootstrap=True,
10681079
oob_score=False,
10691080
n_jobs=1,
@@ -1075,7 +1086,7 @@ def __init__(self,
10751086
n_estimators=n_estimators,
10761087
estimator_params=("criterion", "max_depth", "min_samples_split",
10771088
"min_samples_leaf", "min_weight_fraction_leaf",
1078-
"max_features", "max_leaf_nodes",
1089+
"max_features", "max_leaf_nodes", "min_impurity_split",
10791090
"random_state"),
10801091
bootstrap=bootstrap,
10811092
oob_score=oob_score,
@@ -1091,6 +1102,7 @@ def __init__(self,
10911102
self.min_weight_fraction_leaf = min_weight_fraction_leaf
10921103
self.max_features = max_features
10931104
self.max_leaf_nodes = max_leaf_nodes
1105+
self.min_impurity_split = min_impurity_split
10941106

10951107

10961108
class ExtraTreesClassifier(ForestClassifier):
@@ -1160,6 +1172,10 @@ class ExtraTreesClassifier(ForestClassifier):
11601172
If None then unlimited number of leaf nodes.
11611173
If not None then ``max_depth`` will be ignored.
11621174
1175+
min_impurity_split : float, optional (default=1e-7)
1176+
Threshold for early stopping in tree growth. A node will split
1177+
if its impurity is above the threshold, otherwise it is a leaf.
1178+
11631179
bootstrap : boolean, optional (default=False)
11641180
Whether bootstrap samples are used when building trees.
11651181
@@ -1255,6 +1271,7 @@ def __init__(self,
12551271
min_weight_fraction_leaf=0.,
12561272
max_features="auto",
12571273
max_leaf_nodes=None,
1274+
min_impurity_split=1e-7,
12581275
bootstrap=False,
12591276
oob_score=False,
12601277
n_jobs=1,
@@ -1267,7 +1284,7 @@ def __init__(self,
12671284
n_estimators=n_estimators,
12681285
estimator_params=("criterion", "max_depth", "min_samples_split",
12691286
"min_samples_leaf", "min_weight_fraction_leaf",
1270-
"max_features", "max_leaf_nodes",
1287+
"max_features", "max_leaf_nodes", "min_impurity_split",
12711288
"random_state"),
12721289
bootstrap=bootstrap,
12731290
oob_score=oob_score,
@@ -1284,6 +1301,7 @@ def __init__(self,
12841301
self.min_weight_fraction_leaf = min_weight_fraction_leaf
12851302
self.max_features = max_features
12861303
self.max_leaf_nodes = max_leaf_nodes
1304+
self.min_impurity_split = min_impurity_split
12871305

12881306

12891307
class ExtraTreesRegressor(ForestRegressor):
@@ -1355,6 +1373,10 @@ class ExtraTreesRegressor(ForestRegressor):
13551373
If None then unlimited number of leaf nodes.
13561374
If not None then ``max_depth`` will be ignored.
13571375
1376+
min_impurity_split : float, optional (default=1e-7)
1377+
Threshold for early stopping in tree growth. A node will split
1378+
if its impurity is above the threshold, otherwise it is a leaf.
1379+
13581380
bootstrap : boolean, optional (default=False)
13591381
Whether bootstrap samples are used when building trees.
13601382
@@ -1419,6 +1441,7 @@ def __init__(self,
14191441
min_weight_fraction_leaf=0.,
14201442
max_features="auto",
14211443
max_leaf_nodes=None,
1444+
min_impurity_split=1e-7,
14221445
bootstrap=False,
14231446
oob_score=False,
14241447
n_jobs=1,
@@ -1430,7 +1453,7 @@ def __init__(self,
14301453
n_estimators=n_estimators,
14311454
estimator_params=("criterion", "max_depth", "min_samples_split",
14321455
"min_samples_leaf", "min_weight_fraction_leaf",
1433-
"max_features", "max_leaf_nodes",
1456+
"max_features", "max_leaf_nodes", "min_impurity_split",
14341457
"random_state"),
14351458
bootstrap=bootstrap,
14361459
oob_score=oob_score,
@@ -1446,6 +1469,7 @@ def __init__(self,
14461469
self.min_weight_fraction_leaf = min_weight_fraction_leaf
14471470
self.max_features = max_features
14481471
self.max_leaf_nodes = max_leaf_nodes
1472+
self.min_impurity_split = min_impurity_split
14491473

14501474

14511475
class RandomTreesEmbedding(BaseForest):
@@ -1500,6 +1524,10 @@ class RandomTreesEmbedding(BaseForest):
15001524
If None then unlimited number of leaf nodes.
15011525
If not None then ``max_depth`` will be ignored.
15021526
1527+
min_impurity_split : float, optional (default=1e-7)
1528+
Threshold for early stopping in tree growth. A node will split
1529+
if its impurity is above the threshold, otherwise it is a leaf.
1530+
15031531
sparse_output : bool, optional (default=True)
15041532
Whether or not to return a sparse CSR matrix, as default behavior,
15051533
or to return a dense array compatible with dense pipeline operators.
@@ -1544,6 +1572,7 @@ def __init__(self,
15441572
min_samples_leaf=1,
15451573
min_weight_fraction_leaf=0.,
15461574
max_leaf_nodes=None,
1575+
min_impurity_split=1e-7,
15471576
sparse_output=True,
15481577
n_jobs=1,
15491578
random_state=None,
@@ -1554,7 +1583,7 @@ def __init__(self,
15541583
n_estimators=n_estimators,
15551584
estimator_params=("criterion", "max_depth", "min_samples_split",
15561585
"min_samples_leaf", "min_weight_fraction_leaf",
1557-
"max_features", "max_leaf_nodes",
1586+
"max_features", "max_leaf_nodes", "min_impurity_split",
15581587
"random_state"),
15591588
bootstrap=False,
15601589
oob_score=False,
@@ -1570,6 +1599,7 @@ def __init__(self,
15701599
self.min_weight_fraction_leaf = min_weight_fraction_leaf
15711600
self.max_features = 1
15721601
self.max_leaf_nodes = max_leaf_nodes
1602+
self.min_impurity_split = min_impurity_split
15731603
self.sparse_output = sparse_output
15741604

15751605
def _set_oob_score(self, X, y):

sklearn/ensemble/gradient_boosting.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ class BaseGradientBoosting(six.with_metaclass(ABCMeta, BaseEnsemble,
722722
@abstractmethod
723723
def __init__(self, loss, learning_rate, n_estimators, criterion,
724724
min_samples_split, min_samples_leaf, min_weight_fraction_leaf,
725-
max_depth, init, subsample, max_features,
725+
max_depth, min_impurity_split, init, subsample, max_features,
726726
random_state, alpha=0.9, verbose=0, max_leaf_nodes=None,
727727
warm_start=False, presort='auto'):
728728

@@ -736,6 +736,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion,
736736
self.subsample = subsample
737737
self.max_features = max_features
738738
self.max_depth = max_depth
739+
self.min_impurity_split = min_impurity_split
739740
self.init = init
740741
self.random_state = random_state
741742
self.alpha = alpha
@@ -1358,6 +1359,10 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
13581359
If None then unlimited number of leaf nodes.
13591360
If not None then ``max_depth`` will be ignored.
13601361
1362+
min_impurity_split : float, optional (default=1e-7)
1363+
Threshold for early stopping in tree growth. A node will split
1364+
if its impurity is above the threshold, otherwise it is a leaf.
1365+
13611366
init : BaseEstimator, None, optional (default=None)
13621367
An estimator object that is used to compute the initial
13631368
predictions. ``init`` has to provide ``fit`` and ``predict``.
@@ -1437,8 +1442,8 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
14371442
def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
14381443
subsample=1.0, criterion='friedman_mse', min_samples_split=2,
14391444
min_samples_leaf=1, min_weight_fraction_leaf=0.,
1440-
max_depth=3, init=None, random_state=None,
1441-
max_features=None, verbose=0,
1445+
max_depth=3, min_impurity_split=1e-7, init=None,
1446+
random_state=None, max_features=None, verbose=0,
14421447
max_leaf_nodes=None, warm_start=False,
14431448
presort='auto'):
14441449

@@ -1450,7 +1455,9 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
14501455
max_depth=max_depth, init=init, subsample=subsample,
14511456
max_features=max_features,
14521457
random_state=random_state, verbose=verbose,
1453-
max_leaf_nodes=max_leaf_nodes, warm_start=warm_start,
1458+
max_leaf_nodes=max_leaf_nodes,
1459+
min_impurity_split=min_impurity_split,
1460+
warm_start=warm_start,
14541461
presort=presort)
14551462

14561463
def _validate_y(self, y):
@@ -1711,6 +1718,10 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
17111718
Best nodes are defined as relative reduction in impurity.
17121719
If None then unlimited number of leaf nodes.
17131720
1721+
min_impurity_split : float, optional (default=1e-7)
1722+
Threshold for early stopping in tree growth. A node will split
1723+
if its impurity is above the threshold, otherwise it is a leaf.
1724+
17141725
alpha : float (default=0.9)
17151726
The alpha-quantile of the huber loss function and the quantile
17161727
loss function. Only if ``loss='huber'`` or ``loss='quantile'``.
@@ -1791,7 +1802,7 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
17911802
def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
17921803
subsample=1.0, criterion='friedman_mse', min_samples_split=2,
17931804
min_samples_leaf=1, min_weight_fraction_leaf=0.,
1794-
max_depth=3, init=None, random_state=None,
1805+
max_depth=3, min_impurity_split=1e-7, init=None, random_state=None,
17951806
max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None,
17961807
warm_start=False, presort='auto'):
17971808

@@ -1801,7 +1812,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
18011812
min_samples_leaf=min_samples_leaf,
18021813
min_weight_fraction_leaf=min_weight_fraction_leaf,
18031814
max_depth=max_depth, init=init, subsample=subsample,
1804-
max_features=max_features,
1815+
max_features=max_features, min_impurity_split=min_impurity_split,
18051816
random_state=random_state, alpha=alpha, verbose=verbose,
18061817
max_leaf_nodes=max_leaf_nodes, warm_start=warm_start,
18071818
presort=presort)

sklearn/tree/_tree.pxd

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Joel Nothman <joel.nothman@gmail.com>
55
# Arnaud Joly <arnaud.v.joly@gmail.com>
66
# Jacob Schreiber <jmschreiber91@gmail.com>
7+
# Nelson Liu <nelson@nelsonliu.me>
78
#
89
# License: BSD 3 clause
910

@@ -95,6 +96,7 @@ cdef class TreeBuilder:
9596
cdef SIZE_t min_samples_leaf # Minimum number of samples in a leaf
9697
cdef double min_weight_leaf # Minimum weight in a leaf
9798
cdef SIZE_t max_depth # Maximal tree depth
99+
cdef double min_impurity_split # Impurity threshold for early stopping
98100

99101
cpdef build(self, Tree tree, object X, np.ndarray y,
100102
np.ndarray sample_weight=*,

sklearn/tree/_tree.pyx

+11-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# Joel Nothman <joel.nothman@gmail.com>
1313
# Fares Hedayati <fares.hedayati@gmail.com>
1414
# Jacob Schreiber <jmschreiber91@gmail.com>
15+
# Nelson Liu <nelson@nelsonliu.me>
1516
#
1617
# License: BSD 3 clause
1718

@@ -63,7 +64,6 @@ TREE_UNDEFINED = -2
6364
cdef SIZE_t _TREE_LEAF = TREE_LEAF
6465
cdef SIZE_t _TREE_UNDEFINED = TREE_UNDEFINED
6566
cdef SIZE_t INITIAL_STACK_SIZE = 10
66-
cdef DTYPE_t MIN_IMPURITY_SPLIT = 1e-7
6767

6868
# Repeat struct definition for numpy
6969
NODE_DTYPE = np.dtype({
@@ -131,12 +131,13 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
131131

132132
def __cinit__(self, Splitter splitter, SIZE_t min_samples_split,
133133
SIZE_t min_samples_leaf, double min_weight_leaf,
134-
SIZE_t max_depth):
134+
SIZE_t max_depth, double min_impurity_split):
135135
self.splitter = splitter
136136
self.min_samples_split = min_samples_split
137137
self.min_samples_leaf = min_samples_leaf
138138
self.min_weight_leaf = min_weight_leaf
139139
self.max_depth = max_depth
140+
self.min_impurity_split = min_impurity_split
140141

141142
cpdef build(self, Tree tree, object X, np.ndarray y,
142143
np.ndarray sample_weight=None,
@@ -166,6 +167,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
166167
cdef SIZE_t min_samples_leaf = self.min_samples_leaf
167168
cdef double min_weight_leaf = self.min_weight_leaf
168169
cdef SIZE_t min_samples_split = self.min_samples_split
170+
cdef double min_impurity_split = self.min_impurity_split
169171

170172
# Recursive partition (without actual recursion)
171173
splitter.init(X, y, sample_weight_ptr, X_idx_sorted)
@@ -223,7 +225,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
223225
impurity = splitter.node_impurity()
224226
first = 0
225227

226-
is_leaf = is_leaf or (impurity <= MIN_IMPURITY_SPLIT)
228+
is_leaf = (is_leaf or
229+
(impurity <= min_impurity_split))
227230

228231
if not is_leaf:
229232
splitter.node_split(impurity, &split, &n_constant_features)
@@ -289,13 +292,15 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
289292

290293
def __cinit__(self, Splitter splitter, SIZE_t min_samples_split,
291294
SIZE_t min_samples_leaf, min_weight_leaf,
292-
SIZE_t max_depth, SIZE_t max_leaf_nodes):
295+
SIZE_t max_depth, SIZE_t max_leaf_nodes,
296+
double min_impurity_split):
293297
self.splitter = splitter
294298
self.min_samples_split = min_samples_split
295299
self.min_samples_leaf = min_samples_leaf
296300
self.min_weight_leaf = min_weight_leaf
297301
self.max_depth = max_depth
298302
self.max_leaf_nodes = max_leaf_nodes
303+
self.min_impurity_split = min_impurity_split
299304

300305
cpdef build(self, Tree tree, object X, np.ndarray y,
301306
np.ndarray sample_weight=None,
@@ -421,6 +426,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
421426
cdef SIZE_t n_node_samples
422427
cdef SIZE_t n_constant_features = 0
423428
cdef double weighted_n_samples = splitter.weighted_n_samples
429+
cdef double min_impurity_split = self.min_impurity_split
424430
cdef double weighted_n_node_samples
425431
cdef bint is_leaf
426432
cdef SIZE_t n_left, n_right
@@ -436,7 +442,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
436442
(n_node_samples < self.min_samples_split) or
437443
(n_node_samples < 2 * self.min_samples_leaf) or
438444
(weighted_n_node_samples < self.min_weight_leaf) or
439-
(impurity <= MIN_IMPURITY_SPLIT))
445+
(impurity <= min_impurity_split))
440446

441447
if not is_leaf:
442448
splitter.node_split(impurity, &split, &n_constant_features)

0 commit comments

Comments
 (0)
0