[MRG] Adds Minimal Cost-Complexity Pruning to Decision Trees #12887
@@ -534,16 +534,57 @@ Mean Absolute Error:
where :math:`X_m` is the training data in node :math:`m`

.. _minimal_cost_complexity_pruning:

Minimal Cost-Complexity Pruning
===============================

Minimal cost-complexity pruning is an algorithm used to prune a tree to avoid
over-fitting, described in Chapter 3 of [BRE]_. This algorithm is parameterized
by :math:`\alpha\ge0`, known as the complexity parameter. The complexity
parameter is used to define the cost-complexity measure, :math:`R_\alpha(T)`,
of a given tree :math:`T`:

.. math::

    R_\alpha(T) = R(T) + \alpha|T|

where :math:`|T|` is the number of terminal nodes in :math:`T` and :math:`R(T)`
is traditionally defined as the total misclassification rate of the terminal
nodes. Alternatively, scikit-learn uses the total sample-weighted impurity of
the terminal nodes for :math:`R(T)`. As shown above, the impurity of a node
depends on the criterion. Minimal cost-complexity pruning finds the subtree of
:math:`T` that minimizes :math:`R_\alpha(T)`.
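
As an illustrative sketch (not part of the documentation added by this diff),
:math:`R_\alpha(T)` can be computed from a fitted tree's ``tree_`` arrays,
assuming :math:`R(T)` is the impurity of each terminal node weighted by its
share of the training samples::

    import numpy as np
    from sklearn.datasets import load_breast_cancer
    from sklearn.tree import DecisionTreeClassifier

    X, y = load_breast_cancer(return_X_y=True)
    tree = DecisionTreeClassifier(random_state=0).fit(X, y).tree_

    is_leaf = tree.children_left == -1              # terminal nodes
    weights = (tree.weighted_n_node_samples[is_leaf]
               / tree.weighted_n_node_samples[0])   # share of training samples
    R_T = np.sum(weights * tree.impurity[is_leaf])  # R(T)

    alpha = 0.01
    R_alpha_T = R_T + alpha * np.sum(is_leaf)       # R(T) + alpha * |T|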

The cost complexity measure of a single node is
:math:`R_\alpha(t)=R(t)+\alpha`. The branch, :math:`T_t`, is defined to be a
tree where node :math:`t` is its root. In general, the impurity of a node
is greater than the sum of impurities of its terminal nodes,
:math:`R(T_t)<R(t)`. However, the cost complexity measure of a node,
:math:`t`, and its branch, :math:`T_t`, can be equal depending on
:math:`\alpha`. We define the effective :math:`\alpha` of a node to be the
value where they are equal, :math:`R_\alpha(T_t)=R_\alpha(t)` or
:math:`\alpha_{eff}(t)=\frac{R(t)-R(T_t)}{|T_t|-1}`. A non-terminal node
with the smallest value of :math:`\alpha_{eff}` is the weakest link and will
be pruned. This process stops when the pruned tree's minimal
:math:`\alpha_{eff}` is greater than the ``ccp_alpha`` parameter.
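
To make the definition concrete, here is a minimal sketch (illustrative only,
not the pruning code added by this PR) of :math:`\alpha_{eff}` for a single
non-terminal node, using the ``tree`` object and sample-weighted impurities
from the sketch above::

    def branch_stats(tree, node):
        # Return (R(T_t), |T_t|): total weighted impurity and number of
        # terminal nodes of the branch rooted at ``node``.
        left, right = tree.children_left[node], tree.children_right[node]
        if left == -1:                                   # terminal node
            weight = (tree.weighted_n_node_samples[node]
                      / tree.weighted_n_node_samples[0])
            return weight * tree.impurity[node], 1
        r_left, n_left = branch_stats(tree, left)
        r_right, n_right = branch_stats(tree, right)
        return r_left + r_right, n_left + n_right

    def alpha_eff(tree, node):
        # R(t): weighted impurity of node ``t`` itself.
        r_t = (tree.weighted_n_node_samples[node]
               / tree.weighted_n_node_samples[0]) * tree.impurity[node]
        r_branch, n_leaves = branch_stats(tree, node)
        # alpha_eff(t) = (R(t) - R(T_t)) / (|T_t| - 1)
        return (r_t - r_branch) / (n_leaves - 1)

The non-terminal node with the smallest ``alpha_eff`` would be the first
candidate for pruning.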

.. topic:: Examples:

    * :ref:`sphx_glr_auto_examples_tree_plot_cost_complexity_pruning.py`

.. topic:: References:

Reviewer comments on this line:

* is there a reference for the pruning to add here?
* and here
* The pruning comes from L. Breiman, J. Friedman, R. Olshen, and C. Stone.
  Classification and Regression Trees (Chapter 3)
* Please mention and link directly to it in the section. There are many
  references; it's not obvious which one is used for what.

    .. [BRE] L. Breiman, J. Friedman, R. Olshen, and C. Stone. Classification
       and Regression Trees. Wadsworth, Belmont, CA, 1984.

    * https://en.wikipedia.org/wiki/Decision_tree_learning

    * https://en.wikipedia.org/wiki/Predictive_analytics

    * L. Breiman, J. Friedman, R. Olshen, and C. Stone. Classification and
      Regression Trees. Wadsworth, Belmont, CA, 1984.

    * J.R. Quinlan. C4.5: Programs for Machine Learning. Morgan Kaufmann, 1993.

    * T. Hastie, R. Tibshirani and J. Friedman. Elements of Statistical
      Learning, Springer, 2009.

@@ -0,0 +1,106 @@
""" | ||
======================================================== | ||
Post pruning decision trees with cost complexity pruning | ||
======================================================== | ||
|
||
.. currentmodule:: sklearn.tree | ||
|
||
The :class:`DecisionTreeClassifier` provides parameters such as | ||
``min_samples_leaf`` and ``max_depth`` to prevent a tree from overfiting. Cost | ||
complexity pruning provides another option to control the size of a tree. In | ||
:class:`DecisionTreeClassifier`, this pruning technique is parameterized by the | ||
cost complexity parameter, ``ccp_alpha``. Greater values of ``ccp_alpha`` | ||
increase the number of nodes pruned. Here we only show the effect of | ||
``ccp_alpha`` on regularizing the trees and how to choose a ``ccp_alpha`` | ||
based on validation scores. | ||
|
||
See also `ref`:_minimal_cost_complexity_pruning` for details on pruning. | ||
""" | ||
|
||
print(__doc__) | ||
import matplotlib.pyplot as plt | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.datasets import load_breast_cancer | ||
from sklearn.tree import DecisionTreeClassifier | ||

###############################################################################
# Total impurity of leaves vs effective alphas of pruned tree
# ---------------------------------------------------------------
# Minimal cost complexity pruning recursively finds the node with the "weakest
# link". The weakest link is characterized by an effective alpha, where the
# nodes with the smallest effective alpha are pruned first. To get an idea of
# what values of ``ccp_alpha`` could be appropriate, scikit-learn provides
# :func:`DecisionTreeClassifier.cost_complexity_pruning_path` that returns the
# effective alphas and the corresponding total leaf impurities at each step of
# the pruning process. As alpha increases, more of the tree is pruned, which
# increases the total impurity of its leaves.
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = DecisionTreeClassifier(random_state=0)
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities

###############################################################################
# In the following plot, the maximum effective alpha value is removed, because
# it is the trivial tree with only one node.
fig, ax = plt.subplots()
ax.plot(ccp_alphas[:-1], impurities[:-1], marker='o', drawstyle="steps-post")
ax.set_xlabel("effective alpha")
ax.set_ylabel("total impurity of leaves")
ax.set_title("Total Impurity vs effective alpha for training set")

###############################################################################
# Next, we train a decision tree using the effective alphas. The last value
# in ``ccp_alphas`` is the alpha value that prunes the whole tree,
# leaving the tree, ``clfs[-1]``, with one node.
clfs = []
for ccp_alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)
print("Number of nodes in the last tree is: {} with ccp_alpha: {}".format(
      clfs[-1].tree_.node_count, ccp_alphas[-1]))

###############################################################################
# For the remainder of this example, we remove the last element in
# ``clfs`` and ``ccp_alphas``, because it is the trivial tree with only one
# node. Here we show that the number of nodes and tree depth decrease as alpha
# increases.
clfs = clfs[:-1]
ccp_alphas = ccp_alphas[:-1]

node_counts = [clf.tree_.node_count for clf in clfs]
depth = [clf.tree_.max_depth for clf in clfs]
fig, ax = plt.subplots(2, 1)
ax[0].plot(ccp_alphas, node_counts, marker='o', drawstyle="steps-post")
ax[0].set_xlabel("alpha")
ax[0].set_ylabel("number of nodes")
ax[0].set_title("Number of nodes vs alpha")
ax[1].plot(ccp_alphas, depth, marker='o', drawstyle="steps-post")
ax[1].set_xlabel("alpha")
ax[1].set_ylabel("depth of tree")
ax[1].set_title("Depth vs alpha")
fig.tight_layout()

###############################################################################
# Accuracy vs alpha for training and testing sets
# ----------------------------------------------------
# When ``ccp_alpha`` is set to zero and the other parameters of
# :class:`DecisionTreeClassifier` are kept at their defaults, the tree
# overfits, leading to a 100% training accuracy and 88% testing accuracy. As
# alpha increases, more of the tree is pruned, thus creating a decision tree
# that generalizes better. In this example, setting ``ccp_alpha=0.015``
# maximizes the testing accuracy.
train_scores = [clf.score(X_train, y_train) for clf in clfs]
test_scores = [clf.score(X_test, y_test) for clf in clfs]

fig, ax = plt.subplots()
ax.set_xlabel("alpha")
ax.set_ylabel("accuracy")
ax.set_title("Accuracy vs alpha for training and testing sets")
ax.plot(ccp_alphas, train_scores, marker='o', label="train",
        drawstyle="steps-post")
ax.plot(ccp_alphas, test_scores, marker='o', label="test",
        drawstyle="steps-post")
ax.legend()
plt.show()

Review comment:

Let's add a link to this section in every docstring for ``ccp_alpha``. Else,
there's no way users can know what ``ccp_alpha`` really is and how it works
when they look e.g. at the RandomForest docstring.
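
A possible shape for such a cross-reference in an estimator docstring
(hypothetical wording, not taken from this PR):

    ccp_alpha : non-negative float, optional (default=0.0)
        Complexity parameter used for Minimal Cost-Complexity Pruning. The
        subtree with the largest cost complexity that is smaller than
        ``ccp_alpha`` will be chosen. By default, no pruning is performed.
        See :ref:`minimal_cost_complexity_pruning` for details.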