10000 [MRG+2] ENH/FIX Introduce min_impurity_decrease param for early stopping based on impurity; Deprecate min_impurity_split by raghavrv · Pull Request #8449 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG+2] ENH/FIX Introduce min_impurity_decrease param for early stopping based on impurity; Deprecate min_impurity_split #8449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Apr 3, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,12 @@ API changes summary
**sklearn.utils.estimator_checks** to check their consistency.
:issue:`7578` by :user:`Shubham Bhardwaj <shubham0704>`

- All tree-based estimators now accept a ``min_impurity_decrease``
parameter in lieu of ``min_impurity_split``, which is now deprecated.
``min_impurity_decrease`` stops a node from being split when the
weighted impurity decrease from the split is not at least
``min_impurity_decrease``. :issue:`8449` by `Raghav RV`_


.. _changes_0_18_1:

Expand Down
138 changes: 108 additions & 30 deletions sklearn/ensemble/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,11 +807,23 @@ class RandomForestClassifier(ForestClassifier):
Best nodes are defined as relative reduction in impurity.
If None then unlimited number of leaf nodes.

min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=0.)
A node will be split if this split induces a decrease of the impurity
greater than or equal to this value.

.. versionadded:: 0.18
The weighted impurity decrease equation is the following::

N_t / N * (impurity - N_t_R / N_t * right_impurity
- N_t_L / N_t * left_impurity)

where ``N`` is the total number of samples, ``N_t`` is the number of
samples at the current node, ``N_t_L`` is the number of samples in the
left child, and ``N_t_R`` is the number of samples in the right child.

``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
if ``sample_weight`` is passed.

.. versionadded:: 0.19

bootstrap : boolean, optional (default=True)
Whether bootstrap samples are used when building trees.
Expand Down Expand Up @@ -916,7 +928,8 @@ def __init__(self,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
min_impurity_split=1e-7,
min_impurity_decrease=0.,
min_impurity_split=None,
bootstrap=True,
oob_score=False,
n_jobs=1,
Expand All @@ -929,7 +942,8 @@ def __init__(self,
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes", "min_impurity_split",
"max_features", "max_leaf_nodes",
"min_impurity_decrease", "min_impurity_split",
"random_state"),
bootstrap=bootstrap,
oob_score=oob_score,
Expand All @@ -946,6 +960,7 @@ def __init__(self,
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_decrease = min_impurity_decrease
self.min_impurity_split = min_impurity_split


Expand Down Expand Up @@ -1028,11 +1043,23 @@ class RandomForestRegressor(ForestRegressor):
Best nodes are defined as relative reduction in impurity.
If None then unlimited number of leaf nodes.

min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=0.)
A node will be split if this split induces a decrease of the impurity
greater than or equal to this value.

.. versionadded:: 0.18
The weighted impurity decrease equation is the following::

N_t / N * (impurity - N_t_R / N_t * right_impurity
- N_t_L / N_t * left_impurity)

where ``N`` is the total number of samples, ``N_t`` is the number of
samples at the current node, ``N_t_L`` is the number of samples in the
left child, and ``N_t_R`` is the number of samples in the right child.

``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
if ``sample_weight`` is passed.

.. versionadded:: 0.19

bootstrap : boolean, optional (default=True)
Whether bootstrap samples are used when building trees.
Expand Down Expand Up @@ -1106,7 +1133,8 @@ def __init__(self,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
min_impurity_split=1e-7,
min_impurity_decrease=0.,
min_impurity_split=None,
bootstrap=True,
oob_score=False,
n_jobs=1,
Expand All @@ -1118,7 +1146,8 @@ def __init__(self,
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes", "min_impurity_split",
"max_features", "max_leaf_nodes",
"min_impurity_decrease", "min_impurity_split",
"random_state"),
bootstrap=bootstrap,
oob_score=oob_score,
Expand All @@ -1134,6 +1163,7 @@ def __init__(self,
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_decrease = min_impurity_decrease
self.min_impurity_split = min_impurity_split


Expand Down Expand Up @@ -1209,11 +1239,23 @@ class ExtraTreesClassifier(ForestClassifier):
Best nodes are defined as relative reduction in impurity.
If None then unlimited number of leaf nodes.

min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=0.)
A node will be split if this split induces a decrease of the impurity
greater than or equal to this value.

.. versionadded:: 0.18
The weighted impurity decrease equation is the following::

N_t / N * (impurity - N_t_R / N_t * right_impurity
- N_t_L / N_t * left_impurity)

where ``N`` is the total number of samples, ``N_t`` is the number of
samples at the current node, ``N_t_L`` is the number of samples in the
left child, and ``N_t_R`` is the number of samples in the right child.

``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
if ``sample_weight`` is passed.

.. versionadded:: 0.19

bootstrap : boolean, optional (default=False)
Whether bootstrap samples are used when building trees.
Expand Down Expand Up @@ -1310,7 +1352,8 @@ def __init__(self,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
min_impurity_split=1e-7,
min_impurity_decrease=0.,
min_impurity_split=None,
bootstrap=False,
oob_score=False,
n_jobs=1,
Expand All @@ -1323,7 +1366,8 @@ def __init__(self,
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes", "min_impurity_split",
"max_features", "max_leaf_nodes",
"min_impurity_decrease", "min_impurity_split",
"random_state"),
bootstrap=bootstrap,
oob_score=oob_score,
Expand All @@ -1340,6 +1384,7 @@ def __init__(self,
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_decrease = min_impurity_decrease
self.min_impurity_split = min_impurity_split


Expand Down Expand Up @@ -1420,11 +1465,23 @@ class ExtraTreesRegressor(ForestRegressor):
Best nodes are defined as relative reduction in impurity.
If None then unlimited number of leaf nodes.

min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=0.)
A node will be split if this split induces a decrease of the impurity
greater than or equal to this value.

.. versionadded:: 0.18
The weighted impurity decrease equation is the following::

N_t / N * (impurity - N_t_R / N_t * right_impurity
- N_t_L / N_t * left_impurity)

where ``N`` is the total number of samples, ``N_t`` is the number of
samples at the current node, ``N_t_L`` is the number of samples in the
left child, and ``N_t_R`` is the number of samples in the right child.

``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
if ``sample_weight`` is passed.

.. versionadded:: 0.19

bootstrap : boolean, optional (default=False)
Whether bootstrap samples are used when building trees.
Expand Down Expand Up @@ -1490,7 +1547,8 @@ def __init__(self,
min_weight_fraction_leaf=0.,
max_features="auto",
max_leaf_nodes=None,
min_impurity_split=1e-7,
min_impurity_decrease=0.,
min_impurity_split=None,
bootstrap=False,
oob_score=False,
n_jobs=1,
Expand All @@ -1502,7 +1560,8 @@ def __init__(self,
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes", "min_impurity_split",
"max_features", "max_leaf_nodes",
"min_impurity_decrease", "min_impurity_split",
"random_state"),
bootstrap=bootstrap,
oob_score=oob_score,
Expand All @@ -1518,6 +1577,7 @@ def __init__(self,
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_decrease = min_impurity_decrease
self.min_impurity_split = min_impurity_split


Expand Down Expand Up @@ -1578,11 +1638,26 @@ class RandomTreesEmbedding(BaseForest):
Best nodes are defined as relative reduction in impurity.
If None then unlimited number of leaf nodes.

min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=0.)
A node will be split if this split induces a decrease of the impurity
greater than or equal to this value.

.. versionadded:: 0.18
The weighted impurity decrease equation is the following::

N_t / N * (impurity - N_t_R / N_t * right_impurity
- N_t_L / N_t * left_impurity)

where ``N`` is the total number of samples, ``N_t`` is the number of
samples at the current node, ``N_t_L`` is the number of samples in the
left child, and ``N_t_R`` is the number of samples in the right child.

``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
if ``sample_weight`` is passed.

.. versionadded:: 0.19

bootstrap : boolean, optional (default=True)
Whether bootstrap samples are used when building trees.

sparse_output : bool, optional (default=True)
Whether or not to return a sparse CSR matrix, as default behavior,
Expand Down Expand Up @@ -1628,7 +1703,8 @@ def __init__(self,
min_samples_leaf=1,
min_weight_fraction_leaf=0.,
max_leaf_nodes=None,
min_impurity_split=1e-7,
min_impurity_decrease=0.,
min_impurity_split=None,
sparse_output=True,
n_jobs=1,
random_state=None,
Expand All @@ -1639,7 +1715,8 @@ def __init__(self,
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes", "min_impurity_split",
"max_features", "max_leaf_nodes",
"min_impurity_decrease", "min_impurity_split",
"random_state"),
bootstrap=False,
oob_score=False,
Expand All @@ -1655,6 +1732,7 @@ def __init__(self,
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = 1
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_decrease = min_impurity_decrease
self.min_impurity_split = min_impurity_split
self.sparse_output = sparse_output

Expand Down
Loading
0