[MRG+1] feature: add beta-threshold early stopping for decision tree growth by nelson-liu · Pull Request #6954 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG+1] feature: add beta-threshold early stopping for decision tree growth #6954

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jul 27, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ New features
<https://github.com/scikit-learn/scikit-learn/pull/6667>`_) by `Nelson
Liu`_.

- Added weighted impurity-based early stopping criterion for decision tree
growth. (`#6954
<https://github.com/scikit-learn/scikit-learn/pull/6954>`_) by `Nelson
Liu`_.

Enhancements
............

Expand Down
40 changes: 35 additions & 5 deletions sklearn/ensemble/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,10 @@ class RandomForestClassifier(ForestClassifier):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.

min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold; otherwise it is a leaf.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

versionadded is missing.


bootstrap : boolean, optional (default=True)
Whether bootstrap samples are used when building trees.

Expand Down Expand Up @@ -899,6 +903,7 @@ def __init__(self,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
min_impurity_split=1e-7,
bootstrap=True,
oob_score=False,
n_jobs=1,
Expand All @@ -911,7 +916,7 @@ def __init__(self,
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
bootstrap=bootstrap,
oob_score=oob_score,
Expand All @@ -928,6 +933,7 @@ def __init__(self,
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split


class RandomForestRegressor(ForestRegressor):
Expand Down Expand Up @@ -1001,6 +1007,10 @@ class RandomForestRegressor(ForestRegressor):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.

min_impurity_split : float, optional (default=1e-7)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

versionadded is missing.

Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold; otherwise it is a leaf.

bootstrap : boolean, optional (default=True)
Whether bootstrap samples are used when building trees.

Expand Down Expand Up @@ -1064,6 +1074,7 @@ def __init__(self,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
min_impurity_split=1e-7,
bootstrap=True,
oob_score=False,
n_jobs=1,
Expand All @@ -1075,7 +1086,7 @@ def __init__(self,
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
bootstrap=bootstrap,
oob_score=oob_score,
Expand All @@ -1091,6 +1102,7 @@ def __init__(self,
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split


class ExtraTreesClassifier(ForestClassifier):
Expand Down Expand Up @@ -1160,6 +1172,10 @@ class ExtraTreesClassifier(ForestClassifier):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.

min_impurity_split : float, optional (default=1e-7)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

versionadded is missing.

Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold; otherwise it is a leaf.

bootstrap : boolean, optional (default=False)
Whether bootstrap samples are used when building trees.

Expand Down Expand Up @@ -1255,6 +1271,7 @@ def __init__(self,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
min_impurity_split=1e-7,
bootstrap=False,
oob_score=False,
n_jobs=1,
Expand All @@ -1267,7 +1284,7 @@ def __init__(self,
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
bootstrap=bootstrap,
oob_score=oob_score,
Expand All @@ -1284,6 +1301,7 @@ def __init__(self,
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split


class ExtraTreesRegressor(ForestRegressor):
Expand Down Expand Up @@ -1355,6 +1373,10 @@ class ExtraTreesRegressor(ForestRegressor):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.

min_impurity_split : float, optional (default=1e-7)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

versionadded is missing.

Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold; otherwise it is a leaf.

bootstrap : boolean, optional (default=False)
Whether bootstrap samples are used when building trees.

Expand Down Expand Up @@ -1419,6 +1441,7 @@ def __init__(self,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
min_impurity_split=1e-7,
bootstrap=False,
oob_score=False,
n_jobs=1,
Expand All @@ -1430,7 +1453,7 @@ def __init__(self,
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
bootstrap=bootstrap,
oob_score=oob_score,
Expand All @@ -1446,6 +1469,7 @@ def __init__(self,
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split


class RandomTreesEmbedding(BaseForest):
Expand Down Expand Up @@ -1500,6 +1524,10 @@ class RandomTreesEmbedding(BaseForest):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.

min_impurity_split : float, optional (default=1e-7)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

versionadded is missing.

Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold; otherwise it is a leaf.

sparse_output : bool, optional (default=True)
Whether or not to return a sparse CSR matrix, as default behavior,
or to return a dense array compatible with dense pipeline operators.
Expand Down Expand Up @@ -1544,6 +1572,7 @@ def __init__(self,
min_samples_leaf=1,
min_weight_fraction_leaf=0.,
max_leaf_nodes=None,
min_impurity_split=1e-7,
sparse_output=True,
n_jobs=1,
random_state=None,
Expand All @@ -1554,7 +1583,7 @@ def __init__(self,
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
bootstrap=False,
oob_score=False,
Expand All @@ -1570,6 +1599,7 @@ def __init__(self,
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = 1
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split
self.sparse_output = sparse_output

def _set_oob_score(self, X, y):
Expand Down
23 changes: 17 additions & 6 deletions sklearn/ensemble/gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ class BaseGradientBoosting(six.with_metaclass(ABCMeta, BaseEnsemble,
@abstractmethod
def __init__(self, loss, learning_rate, n_estimators, criterion,
min_samples_split, min_samples_leaf, min_weight_fraction_leaf,
max_depth, init, subsample, max_features,
max_depth, min_impurity_split, init, subsample, max_features,
random_state, alpha=0.9, verbose=0, max_leaf_nodes=None,
warm_start=False, presort='auto'):

Expand All @@ -736,6 +736,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion,
self.subsample = subsample
self.max_features = max_features
self.max_depth = max_depth
self.min_impurity_split = min_impurity_split
self.init = init
self.random_state = random_state
self.alpha = alpha
Expand Down Expand Up @@ -1358,6 +1359,10 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
If None then unlimited number of leaf nodes.
If not None then ``max_depth`` will be ignored.

min_impurity_split : float, optional (default=1e-7)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

versionadded is missing.

Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold; otherwise it is a leaf.

init : BaseEstimator, None, optional (default=None)
An estimator object that is used to compute the initial
predictions. ``init`` has to provide ``fit`` and ``predict``.
Expand Down Expand Up @@ -1437,8 +1442,8 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
subsample=1.0, criterion='friedman_mse', min_samples_split=2,
min_samples_leaf=1, min_weight_fraction_leaf=0.,
max_depth=3, init=None, random_state=None,
max_features=None, verbose=0,
max_depth=3, min_impurity_split=1e-7, init=None,
random_state=None, max_features=None, verbose=0,
max_leaf_nodes=None, warm_start=False,
presort='auto'):

Expand All @@ -1450,7 +1455,9 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
max_depth=max_depth, init=init, subsample=subsample,
max_features=max_features,
random_state=random_state, verbose=verbose,
max_leaf_nodes=max_leaf_nodes, warm_start=warm_start,
max_leaf_nodes=max_leaf_nodes,
min_impurity_split=min_impurity_split,
warm_start=warm_start,
presort=presort)

def _validate_y(self, y):
Expand Down Expand Up @@ -1711,6 +1718,10 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
Best nodes are defined as relative reduction in impurity.
If None then unlimited number of leaf nodes.

min_impurity_split : float, optional (default=1e-7)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

versionadded is missing.

Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold; otherwise it is a leaf.

alpha : float (default=0.9)
The alpha-quantile of the huber loss function and the quantile
loss function. Only if ``loss='huber'`` or ``loss='quantile'``.
Expand Down Expand Up @@ -1791,7 +1802,7 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
subsample=1.0, criterion='friedman_mse', min_samples_split=2,
min_samples_leaf=1, min_weight_fraction_leaf=0.,
max_depth=3, init=None, random_state=None,
max_depth=3, min_impurity_split=1e-7, init=None, random_state=None,
max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None,
warm_start=False, presort='auto'):

Expand All @@ -1801,7 +1812,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
min_samples_leaf=min_samples_leaf,
min_weight_fraction_leaf=min_weight_fraction_leaf,
max_depth=max_depth, init=init, subsample=subsample,
max_features=max_features,
max_features=max_features, min_impurity_split=min_impurity_split,
random_state=random_state, alpha=alpha, verbose=verbose,
max_leaf_nodes=max_leaf_nodes, warm_start=warm_start,
presort=presort)
Expand Down
2 changes: 2 additions & 0 deletions sklearn/tree/_tree.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Joel Nothman <joel.nothman@gmail.com>
# Arnaud Joly <arnaud.v.joly@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
# Nelson Liu <nelson@nelsonliu.me>
#
# License: BSD 3 clause

Expand Down Expand Up @@ -95,6 +96,7 @@ cdef class TreeBuilder:
cdef SIZE_t min_samples_leaf # Minimum number of samples in a leaf
cdef double min_weight_leaf # Minimum weight in a leaf
cdef SIZE_t max_depth # Maximal tree depth
cdef double min_impurity_split # Impurity threshold for early stopping

cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray sample_weight=*,
Expand Down
16 changes: 11 additions & 5 deletions sklearn/tree/_tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# Joel Nothman <joel.nothman@gmail.com>
# Fares Hedayati <fares.hedayati@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
# Nelson Liu <nelson@nelsonliu.me>
#
# License: BSD 3 clause

Expand Down Expand Up @@ -63,7 +64,6 @@ TREE_UNDEFINED = -2
cdef SIZE_t _TREE_LEAF = TREE_LEAF
cdef SIZE_t _TREE_UNDEFINED = TREE_UNDEFINED
cdef SIZE_t INITIAL_STACK_SIZE = 10
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was that not documented beforehand? I feel like that should have been in the docs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nope, it wasn't documented beforehand as far as I saw.

cdef DTYPE_t MIN_IMPURITY_SPLIT = 1e-7

# Repeat struct definition for numpy
NODE_DTYPE = np.dtype({
Expand Down Expand Up @@ -131,12 +131,13 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):

def __cinit__(self, Splitter splitter, SIZE_t min_samples_split,
SIZE_t min_samples_leaf, double min_weight_leaf,
SIZE_t max_depth):
SIZE_t max_depth, double min_impurity_split):
self.splitter = splitter
self.min_samples_split = min_samples_split
self.min_samples_leaf = min_samples_leaf
self.min_weight_leaf = min_weight_leaf
self.max_depth = max_depth
self.min_impurity_split = min_impurity_split

cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray sample_weight=None,
Expand Down Expand Up @@ -166,6 +167,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
cdef SIZE_t min_samples_leaf = self.min_samples_leaf
cdef double min_weight_leaf = self.min_weight_leaf
cdef SIZE_t min_samples_split = self.min_samples_split
cdef double min_impurity_split = self.min_impurity_split

# Recursive partition (without actual recursion)
splitter.init(X, y, sample_weight_ptr, X_idx_sorted)
Expand Down Expand Up @@ -223,7 +225,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
impurity = splitter.node_impurity()
first = 0

is_leaf = is_leaf or (impurity <= MIN_IMPURITY_SPLIT)
is_leaf = (is_leaf or
(impurity <= min_impurity_split))

if not is_leaf:
splitter.node_split(impurity, &split, &n_constant_features)
Expand Down Expand Up @@ -289,13 +292,15 @@ cdef class BestFirstTreeBuilder(TreeBuilder):

def __cinit__(self, Splitter splitter, SIZE_t min_samples_split,
SIZE_t min_samples_leaf, min_weight_leaf,
SIZE_t max_depth, SIZE_t max_leaf_nodes):
SIZE_t max_depth, SIZE_t max_leaf_nodes,
double min_impurity_split):
self.splitter = splitter
self.min_samples_split = min_samples_split
self.min_samples_leaf = min_samples_leaf
self.min_weight_leaf = min_weight_leaf
self.max_depth = max_depth
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split

cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray sample_weight=None,
Expand Down Expand Up @@ -421,6 +426,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
cdef SIZE_t n_node_samples
cdef SIZE_t n_constant_features = 0
cdef double weighted_n_samples = splitter.weighted_n_samples
cdef double min_impurity_split = self.min_impurity_split
cdef double weighted_n_node_samples
cdef bint is_leaf
cdef SIZE_t n_left, n_right
Expand All @@ -436,7 +442,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
(n_node_samples < self.min_samples_split) or
(n_node_samples < 2 * self.min_samples_leaf) or
(weighted_n_node_samples < self.min_weight_leaf) or
(impurity <= MIN_IMPURITY_SPLIT))
(impurity <= min_impurity_split))

if not is_leaf:
splitter.node_split(impurity, &split, &n_constant_features)
Expand Down
Loading
0