8000 [MRG+1] ENH/FIX Introduce min_impurity_decrease param for early stopp… · massich/scikit-learn@08e384c · GitHub
[go: up one dir, main page]

Skip to content

Commit 08e384c

Browse files
raghavrv authored and Joan Massich committed
[MRG+1] ENH/FIX Introduce min_impurity_decrease param for early stopping based on impurity; Deprecate min_impurity_split (scikit-learn#8449)
[MRG+2] ENH/FIX Introduce min_impurity_decrease param for early stopping based on impurity; Deprecate min_impurity_split
1 parent b5580de commit 08e384c

File tree

9 files changed

+380
-82
lines changed

9 files changed

+380
-82
lines changed

doc/whats_new.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,12 @@ API changes summary
309309
**sklearn.utils.estimator_checks** to check their consistency.
310310
:issue:`7578` by :user:`Shubham Bhardwaj <shubham0704>`
311311

312+
- All tree based estimators now accept a ``min_impurity_decrease``
313+
parameter in lieu of the ``min_impurity_split``, which is now deprecated.
314+
The ``min_impurity_decrease`` helps stop splitting the nodes in which
315+
the weighted impurity decrease from splitting is no longer at least
316+
``min_impurity_decrease``. :issue:`8449` by `Raghav RV`_
317+
312318

313319
.. _changes_0_18_1:
314320

sklearn/ensemble/forest.py

Lines changed: 108 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -813,11 +813,23 @@ class RandomForestClassifier(ForestClassifier):
813813
Best nodes are defined as relative reduction in impurity.
814814
If None then unlimited number of leaf nodes.
815815
816-
min_impurity_split : float, optional (default=1e-7)
817-
Threshold for early stopping in tree growth. A node will split
818-
if its impurity is above the threshold, otherwise it is a leaf.
816+
min_impurity_decrease : float, optional (default=0.)
817+
A node will be split if this split induces a decrease of the impurity
818+
greater than or equal to this value.
819819
820-
.. versionadded:: 0.18
820+
The weighted impurity decrease equation is the following::
821+
822+
N_t / N * (impurity - N_t_R / N_t * right_impurity
823+
- N_t_L / N_t * left_impurity)
824+
825+
where ``N`` is the total number of samples, ``N_t`` is the number of
826+
samples at the current node, ``N_t_L`` is the number of samples in the
827+
left child, and ``N_t_R`` is the number of samples in the right child.
828+
829+
``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
830+
if ``sample_weight`` is passed.
831+
832+
.. versionadded:: 0.19
821833
822834
bootstrap : boolean, optional (default=True)
823835
Whether bootstrap samples are used when building trees.
@@ -922,7 +934,8 @@ def __init__(self,
922934
min_weight_fraction_leaf=0.,
923935
max_features="auto",
924936
max_leaf_nodes=None,
925-
min_impurity_split=1e-7,
937+
min_impurity_decrease=0.,
938+
min_impurity_split=None,
926939
bootstrap=True,
927940
oob_score=False,
928941
n_jobs=1,
@@ -935,7 +948,8 @@ def __init__(self,
935948
n_estimators=n_estimators,
936949
estimator_params=("criterion", "max_depth", "min_samples_split",
937950
"min_samples_leaf", "min_weight_fraction_leaf",
938-
"max_features", "max_leaf_nodes", "min_impurity_split",
951+
"max_features", "max_leaf_nodes",
952+
"min_impurity_decrease", "min_impurity_split",
939953
"random_state"),
940954
bootstrap=bootstrap,
941955
oob_score=oob_score,
@@ -952,6 +966,7 @@ def __init__(self,
952966
self.min_weight_fraction_leaf = min_weight_fraction_leaf
953967
self.max_features = max_features
954968
self.max_leaf_nodes = max_leaf_nodes
969+
self.min_impurity_decrease = min_impurity_decrease
955970
self.min_impurity_split = min_impurity_split
956971

957972

@@ -1034,11 +1049,23 @@ class RandomForestRegressor(ForestRegressor):
10341049
Best nodes are defined as relative reduction in impurity.
10351050
If None then unlimited number of leaf nodes.
10361051
1037-
min_impurity_split : float, optional (default=1e-7)
1038-
Threshold for early stopping in tree growth. A node will split
1039-
if its impurity is above the threshold, otherwise it is a leaf.
1052+
min_impurity_decrease : float, optional (default=0.)
1053+
A node will be split if this split induces a decrease of the impurity
1054+
greater than or equal to this value.
10401055
1041-
.. versionadded:: 0.18
1056+
The weighted impurity decrease equation is the following::
1057+
1058+
N_t / N * (impurity - N_t_R / N_t * right_impurity
1059+
- N_t_L / N_t * left_impurity)
1060+
1061+
where ``N`` is the total number of samples, ``N_t`` is the number of
1062+
samples at the current node, ``N_t_L`` is the number of samples in the
1063+
left child, and ``N_t_R`` is the number of samples in the right child.
1064+
1065+
``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
1066+
if ``sample_weight`` is passed.
1067+
1068+
.. versionadded:: 0.19
10421069
10431070
bootstrap : boolean, optional (default=True)
10441071
Whether bootstrap samples are used when building trees.
@@ -1112,7 +1139,8 @@ def __init__(self,
11121139
min_weight_fraction_leaf=0.,
11131140
max_features="auto",
11141141
max_leaf_nodes=None,
1115-
min_impurity_split=1e-7,
1142+
min_impurity_decrease=0.,
1143+
min_impurity_split=None,
11161144
bootstrap=True,
11171145
oob_score=False,
11181146
n_jobs=1,
@@ -1124,7 +1152,8 @@ def __init__(self,
11241152
n_estimators=n_estimators,
11251153
estimator_params=("criterion", "max_depth", "min_samples_split",
11261154
"min_samples_leaf", "min_weight_fraction_leaf",
1127-
"max_features", "max_leaf_nodes", "min_impurity_split",
1155+
"max_features", "max_leaf_nodes",
1156+
"min_impurity_decrease", "min_impurity_split",
11281157
"random_state"),
11291158
bootstrap=bootstrap,
11301159
oob_score=oob_score,
@@ -1140,6 +1169,7 @@ def __init__(self,
11401169
self.min_weight_fraction_leaf = min_weight_fraction_leaf
11411170
self.max_features = max_features
11421171
self.max_leaf_nodes = max_leaf_nodes
1172+
self.min_impurity_decrease = min_impurity_decrease
11431173
self.min_impurity_split = min_impurity_split
11441174

11451175

@@ -1215,11 +1245,23 @@ class ExtraTreesClassifier(ForestClassifier):
12151245
Best nodes are defined as relative reduction in impurity.
12161246
If None then unlimited number of leaf nodes.
12171247
1218-
min_impurity_split : float, optional (default=1e-7)
1219-
Threshold for early stopping in tree growth. A node will split
1220-
if its impurity is above the threshold, otherwise it is a leaf.
1248+
min_impurity_decrease : float, optional (default=0.)
1249+
A node will be split if this split induces a decrease of the impurity
1250+
greater than or equal to this value.
12211251
1222-
.. versionadded:: 0.18
1252+
The weighted impurity decrease equation is the following::
1253+
1254+
N_t / N * (impurity - N_t_R / N_t * right_impurity
1255+
- N_t_L / N_t * left_impurity)
1256+
1257+
where ``N`` is the total number of samples, ``N_t`` is the number of
1258+
samples at the current node, ``N_t_L`` is the number of samples in the
1259+
left child, and ``N_t_R`` is the number of samples in the right child.
1260+
1261+
``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
1262+
if ``sample_weight`` is passed.
1263+
1264+
.. versionadded:: 0.19
12231265
12241266
bootstrap : boolean, optional (default=False)
12251267
Whether bootstrap samples are used when building trees.
@@ -1316,7 +1358,8 @@ def __init__(self,
13161358
min_weight_fraction_leaf=0.,
13171359
max_features="auto",
13181360
max_leaf_nodes=None,
1319-
min_impurity_split=1e-7,
1361+
min_impurity_decrease=0.,
1362+
min_impurity_split=None,
13201363
bootstrap=False,
13211364
oob_score=False,
13221365
n_jobs=1,
@@ -1329,7 +1372,8 @@ def __init__(self,
13291372
n_estimators=n_estimators,
13301373
estimator_params=("criterion", "max_depth", "min_samples_split",
13311374
"min_samples_leaf", "min_weight_fraction_leaf",
1332-
"max_features", "max_leaf_nodes", "min_impurity_split",
1375+
"max_features", "max_leaf_nodes",
1376+
"min_impurity_decrease", "min_impurity_split",
13331377
"random_state"),
13341378
bootstrap=bootstrap,
13351379
oob_score=oob_score,
@@ -1346,6 +1390,7 @@ def __init__(self,
13461390
self.min_weight_fraction_leaf = min_weight_fraction_leaf
13471391
self.max_features = max_features
13481392
self.max_leaf_nodes = max_leaf_nodes
1393+
self.min_impurity_decrease = min_impurity_decrease
13491394
self.min_impurity_split = min_impurity_split
13501395

13511396

@@ -1426,11 +1471,23 @@ class ExtraTreesRegressor(ForestRegressor):
14261471
Best nodes are defined as relative reduction in impurity.
14271472
If None then unlimited number of leaf nodes.
14281473
1429-
min_impurity_split : float, optional (default=1e-7)
1430-
Threshold for early stopping in tree growth. A node will split
1431-
if its impurity is above the threshold, otherwise it is a leaf.
1474+
min_impurity_decrease : float, optional (default=0.)
1475+
A node will be split if this split induces a decrease of the impurity
1476+
greater than or equal to this value.
14321477
1433-
.. versionadded:: 0.18
1478+
The weighted impurity decrease equation is the following::
1479+
1480+
N_t / N * (impurity - N_t_R / N_t * right_impurity
1481+
- N_t_L / N_t * left_impurity)
1482+
1483+
where ``N`` is the total number of samples, ``N_t`` is the number of
1484+
samples at the current node, ``N_t_L`` is the number of samples in the
1485+
left child, and ``N_t_R`` is the number of samples in the right child.
1486+
1487+
``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
1488+
if ``sample_weight`` is passed.
1489+
1490+
.. versionadded:: 0.19
14341491
14351492
bootstrap : boolean, optional (default=False)
14361493
Whether bootstrap samples are used when building trees.
@@ -1496,7 +1553,8 @@ def __init__(self,
14961553
min_weight_fraction_leaf=0.,
14971554
max_features="auto",
14981555
max_leaf_nodes=None,
1499-
min_impurity_split=1e-7,
1556+
min_impurity_decrease=0.,
1557+
min_impurity_split=None,
15001558
bootstrap=False,
15011559
oob_score=False,
15021560
n_jobs=1,
@@ -1508,7 +1566,8 @@ def __init__(self,
15081566
n_estimators=n_estimators,
15091567
estimator_params=("criterion", "max_depth", "min_samples_split",
15101568
"min_samples_leaf", "min_weight_fraction_leaf",
1511-
"max_features", "max_leaf_nodes", "min_impurity_split",
1569+
"max_features", "max_leaf_nodes",
1570+
"min_impurity_decrease", "min_impurity_split",
15121571
"random_state"),
15131572
bootstrap=bootstrap,
15141573
oob_score=oob_score,
@@ -1524,6 +1583,7 @@ def __init__(self,
15241583
self.min_weight_fraction_leaf = min_weight_fraction_leaf
15251584
self.max_features = max_features
15261585
self.max_leaf_nodes = max_leaf_nodes
1586+
self.min_impurity_decrease = min_impurity_decrease
15271587
self.min_impurity_split = min_impurity_split
15281588

15291589

@@ -1584,11 +1644,26 @@ class RandomTreesEmbedding(BaseForest):
15841644
Best nodes are defined as relative reduction in impurity.
15851645
If None then unlimited number of leaf nodes.
15861646
1587-
min_impurity_split : float, optional (default=1e-7)
1588-
Threshold for early stopping in tree growth. A node will split
1589-
if its impurity is above the threshold, otherwise it is a leaf.
1647+
min_impurity_decrease : float, optional (default=0.)
1648+
A node will be split if this split induces a decrease of the impurity
1649+
greater than or equal to this value.
15901650
1591-
.. versionadded:: 0.18
1651+
The weighted impurity decrease equation is the following::
1652+
1653+
N_t / N * (impurity - N_t_R / N_t * right_impurity
1654+
- N_t_L / N_t * left_impurity)
1655+
1656+
where ``N`` is the total number of samples, ``N_t`` is the number of
1657+
samples at the current node, ``N_t_L`` is the number of samples in the
1658+
left child, and ``N_t_R`` is the number of samples in the right child.
1659+
1660+
``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
1661+
if ``sample_weight`` is passed.
1662+
1663+
.. versionadded:: 0.19
1664+
1665+
bootstrap : boolean, optional (default=True)
1666+
Whether bootstrap samples are used when building trees.
15921667
15931668
sparse_output : bool, optional (default=True)
15941669
Whether or not to return a sparse CSR matrix, as default behavior,
@@ -1634,7 +1709,8 @@ def __init__(self,
16341709
min_samples_leaf=1,
16351710
min_weight_fraction_leaf=0.,
16361711
max_leaf_nodes=None,
1637-
min_impurity_split=1e-7,
1712+
min_impurity_decrease=0.,
1713+
min_impurity_split=None,
16381714
sparse_output=True,
16391715
n_jobs=1,
16401716
random_state=None,
@@ -1645,7 +1721,8 @@ def __init__(self,
16451721
n_estimators=n_estimators,
16461722
estimator_params=("criterion", "max_depth", "min_samples_split",
16471723
"min_samples_leaf", "min_weight_fraction_leaf",
1648-
"max_features", "max_leaf_nodes", "min_impurity_split",
1724+
"max_features", "max_leaf_nodes",
1725+
"min_impurity_decrease", "min_impurity_split",
16491726
"random_state"),
16501727
bootstrap=False,
16511728
oob_score=False,
@@ -1661,6 +1738,7 @@ def __init__(self,
16611738
self.min_weight_fraction_leaf = min_weight_fraction_leaf
16621739
self.max_features = 1
16631740
self.max_leaf_nodes = max_leaf_nodes
1741+
self.min_impurity_decrease = min_impurity_decrease
16641742
self.min_impurity_split = min_impurity_split
16651743
self.sparse_output = sparse_output
16661744

0 commit comments

Comments
 (0)
0