Add Minimal Cost-Complexity Pruning to Decision Trees (#12887) · scikit-learn/scikit-learn@67c94c7 · GitHub
Commit 67c94c7

thomasjpfan authored and NicolasHug committed
Add Minimal Cost-Complexity Pruning to Decision Trees (#12887)
1 parent 8fe3c0a commit 67c94c7

File tree

8 files changed: +903, -33 lines changed

doc/modules/tree.rst

Lines changed: 47 additions & 6 deletions
@@ -534,16 +534,57 @@ Mean Absolute Error:
where :math:`X_m` is the training data in node :math:`m`

.. _minimal_cost_complexity_pruning:

Minimal Cost-Complexity Pruning
===============================

Minimal cost-complexity pruning is an algorithm used to prune a tree to avoid
over-fitting, described in Chapter 3 of [BRE]_. This algorithm is parameterized
by :math:`\alpha\ge0`, known as the complexity parameter. The complexity
parameter is used to define the cost-complexity measure, :math:`R_\alpha(T)`, of
a given tree :math:`T`:

.. math::

  R_\alpha(T) = R(T) + \alpha|T|

where :math:`|T|` is the number of terminal nodes in :math:`T` and :math:`R(T)`
is traditionally defined as the total misclassification rate of the terminal
nodes. Alternatively, scikit-learn uses the total sample-weighted impurity of
the terminal nodes for :math:`R(T)`. As shown above, the impurity of a node
depends on the criterion. Minimal cost-complexity pruning finds the subtree of
:math:`T` that minimizes :math:`R_\alpha(T)`.

The cost-complexity measure of a single node is
:math:`R_\alpha(t)=R(t)+\alpha`. The branch, :math:`T_t`, is defined to be the
subtree rooted at node :math:`t`. In general, the impurity of a node
is greater than the sum of the impurities of its terminal nodes,
:math:`R(T_t)<R(t)`. However, the cost-complexity measure of a node,
:math:`t`, and of its branch, :math:`T_t`, can be equal depending on
:math:`\alpha`. We define the effective :math:`\alpha` of a node to be the
value where they are equal, :math:`R_\alpha(T_t)=R_\alpha(t)`, that is,
:math:`\alpha_{eff}(t)=\frac{R(t)-R(T_t)}{|T_t|-1}`. A non-terminal node
with the smallest value of :math:`\alpha_{eff}` is the weakest link and is
pruned. This process stops when the pruned tree's minimal
:math:`\alpha_{eff}` is greater than the ``ccp_alpha`` parameter.
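The expression for :math:`\alpha_{eff}` follows directly from equating the two
cost-complexity measures and noting that a single node counts as exactly one
terminal node:

.. math::

  R(T_t) + \alpha|T_t| = R(t) + \alpha
  \quad\Longrightarrow\quad
  \alpha_{eff}(t) = \frac{R(t) - R(T_t)}{|T_t| - 1}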
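In scikit-learn this pruning is enabled by setting ``ccp_alpha`` on the tree
estimator, and the candidate effective alphas can be inspected with
``cost_complexity_pruning_path``. A minimal sketch, with the dataset and the
``ccp_alpha`` value chosen purely for illustration::

    from sklearn.datasets import load_breast_cancer
    from sklearn.tree import DecisionTreeClassifier

    X, y = load_breast_cancer(return_X_y=True)

    # Effective alphas and total leaf impurities along the pruning path.
    path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X, y)

    # Larger ccp_alpha prunes more of the tree.
    pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=0.02).fit(X, y)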
.. topic:: Examples:

  * :ref:`sphx_glr_auto_examples_tree_plot_cost_complexity_pruning.py`

.. topic:: References:

 .. [BRE] L. Breiman, J. Friedman, R. Olshen, and C. Stone. Classification
   and Regression Trees. Wadsworth, Belmont, CA, 1984.

 * https://en.wikipedia.org/wiki/Decision_tree_learning

 * https://en.wikipedia.org/wiki/Predictive_analytics

 * J.R. Quinlan. C4.5: Programs for Machine Learning. Morgan Kaufmann, 1993.

 * T. Hastie, R. Tibshirani and J. Friedman. Elements of Statistical
   Learning, Springer, 2009.

doc/whats_new/v0.22.rst

Lines changed: 15 additions & 0 deletions
@@ -281,6 +281,21 @@ Changelog
- |Enhancement| SVM now throws a more specific error when fit on non-square data
  and kernel = precomputed. :class:`svm.BaseLibSVM`
  :pr:`14336` by :user:`Gregory Dexter <gdex1>`.

:mod:`sklearn.tree`
...................

- |Feature| Adds minimal cost-complexity pruning, controlled by ``ccp_alpha``,
  to :class:`tree.DecisionTreeClassifier`, :class:`tree.DecisionTreeRegressor`,
  :class:`tree.ExtraTreeClassifier`, :class:`tree.ExtraTreeRegressor`,
  :class:`ensemble.RandomForestClassifier`,
  :class:`ensemble.RandomForestRegressor`,
  :class:`ensemble.ExtraTreesClassifier`,
  :class:`ensemble.ExtraTreesRegressor`,
  :class:`ensemble.RandomTreesEmbedding`,
  :class:`ensemble.GradientBoostingClassifier`,
  and :class:`ensemble.GradientBoostingRegressor`.
  :pr:`12887` by `Thomas Fan`_.
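The ``ccp_alpha`` parameter added by this entry is also accepted by the ensemble
estimators listed above, where it is applied to each tree in the ensemble. A
minimal sketch, with the dataset and value chosen only for illustration::

    from sklearn.datasets import load_breast_cancer
    from sklearn.ensemble import RandomForestClassifier

    X, y = load_breast_cancer(return_X_y=True)

    # Every tree in the forest is fit with the same ccp_alpha.
    forest = RandomForestClassifier(n_estimators=100, ccp_alpha=0.01,
                                    random_state=0).fit(X, y)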
:mod:`sklearn.preprocessing`
............................
examples/tree/plot_cost_complexity_pruning.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
"""
========================================================
Post pruning decision trees with cost complexity pruning
========================================================

.. currentmodule:: sklearn.tree

The :class:`DecisionTreeClassifier` provides parameters such as
``min_samples_leaf`` and ``max_depth`` to prevent a tree from overfitting. Cost
complexity pruning provides another option to control the size of a tree. In
:class:`DecisionTreeClassifier`, this pruning technique is parameterized by the
cost complexity parameter, ``ccp_alpha``. Greater values of ``ccp_alpha``
increase the number of nodes pruned. Here we only show the effect of
``ccp_alpha`` on regularizing the trees and how to choose a ``ccp_alpha``
based on validation scores.

See also :ref:`minimal_cost_complexity_pruning` for details on pruning.
"""

print(__doc__)
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

###############################################################################
# Total impurity of leaves vs effective alphas of pruned tree
# -----------------------------------------------------------
# Minimal cost complexity pruning recursively finds the node with the "weakest
# link". The weakest link is characterized by an effective alpha, where the
# nodes with the smallest effective alpha are pruned first. To get an idea of
# what values of ``ccp_alpha`` could be appropriate, scikit-learn provides
# :func:`DecisionTreeClassifier.cost_complexity_pruning_path`, which returns the
# effective alphas and the corresponding total leaf impurities at each step of
# the pruning process. As alpha increases, more of the tree is pruned, which
# increases the total impurity of its leaves.
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = DecisionTreeClassifier(random_state=0)
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities

###############################################################################
# In the following plot, the maximum effective alpha value is removed, because
# it corresponds to the trivial tree with only one node.
fig, ax = plt.subplots()
ax.plot(ccp_alphas[:-1], impurities[:-1], marker='o', drawstyle="steps-post")
ax.set_xlabel("effective alpha")
ax.set_ylabel("total impurity of leaves")
ax.set_title("Total Impurity vs effective alpha for training set")

###############################################################################
# Next, we train a decision tree using each of the effective alphas. The last
# value in ``ccp_alphas`` is the alpha value that prunes the whole tree,
# leaving the tree, ``clfs[-1]``, with one node.
clfs = []
for ccp_alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)
print("Number of nodes in the last tree is: {} with ccp_alpha: {}".format(
      clfs[-1].tree_.node_count, ccp_alphas[-1]))

###############################################################################
# For the remainder of this example, we remove the last element in
# ``clfs`` and ``ccp_alphas``, because it is the trivial tree with only one
# node. Here we show that the number of nodes and the tree depth decrease as
# alpha increases.
clfs = clfs[:-1]
ccp_alphas = ccp_alphas[:-1]

node_counts = [clf.tree_.node_count for clf in clfs]
depth = [clf.tree_.max_depth for clf in clfs]
fig, ax = plt.subplots(2, 1)
ax[0].plot(ccp_alphas, node_counts, marker='o', drawstyle="steps-post")
ax[0].set_xlabel("alpha")
ax[0].set_ylabel("number of nodes")
ax[0].set_title("Number of nodes vs alpha")
ax[1].plot(ccp_alphas, depth, marker='o', drawstyle="steps-post")
ax[1].set_xlabel("alpha")
ax[1].set_ylabel("depth of tree")
ax[1].set_title("Depth vs alpha")
fig.tight_layout()

###############################################################################
# Accuracy vs alpha for training and testing sets
# -----------------------------------------------
# When ``ccp_alpha`` is set to zero and the other parameters of
# :class:`DecisionTreeClassifier` are kept at their defaults, the tree
# overfits, leading to 100% training accuracy and 88% testing accuracy. As
# alpha increases, more of the tree is pruned, thus creating a decision tree
# that generalizes better. In this example, setting ``ccp_alpha=0.015``
# maximizes the testing accuracy.
train_scores = [clf.score(X_train, y_train) for clf in clfs]
test_scores = [clf.score(X_test, y_test) for clf in clfs]

fig, ax = plt.subplots()
ax.set_xlabel("alpha")
ax.set_ylabel("accuracy")
ax.set_title("Accuracy vs alpha for training and testing sets")
ax.plot(ccp_alphas, train_scores, marker='o', label="train",
        drawstyle="steps-post")
ax.plot(ccp_alphas, test_scores, marker='o', label="test",
        drawstyle="steps-post")
ax.legend()
plt.show()
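###############################################################################
# Choosing ``ccp_alpha`` with cross-validation
# --------------------------------------------
# A sketch of one way to pick ``ccp_alpha`` automatically: treat it as a
# hyperparameter and cross-validate over the candidate values returned by
# ``cost_complexity_pruning_path``. The grid and ``cv`` value here are
# illustrative.
from sklearn.model_selection import GridSearchCV

grid = GridSearchCV(DecisionTreeClassifier(random_state=0),
                    param_grid={"ccp_alpha": ccp_alphas},
                    cv=5)
grid.fit(X_train, y_train)
print("Best ccp_alpha: {}".format(grid.best_params_["ccp_alpha"]))
print("Test accuracy: {:.3f}".format(grid.score(X_test, y_test)))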
