[MRG] Adds Minimal Cost-Complexity Pruning to Decision Trees by thomasjpfan · Pull Request #12887 · scikit-learn/scikit-learn


Merged: 94 commits, Aug 20, 2019

Commits
a5f295a
ENH: Adds cost complexity pruning
thomasjpfan Dec 28, 2018
9569e9f
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Dec 28, 2018
1a554f6
DOC: Update
thomasjpfan Dec 28, 2018
84dbc05
DOC: Adds comments to algorithm
thomasjpfan Dec 28, 2018
5e10962
RFC: Small
thomasjpfan Dec 28, 2018
745cd18
RFC: Moves some logic to cython
thomasjpfan Dec 28, 2018
c1cd149
DOC: More comments
thomasjpfan Dec 28, 2018
90c294e
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Dec 28, 2018
5c36185
DOC: Removes unused parameter
thomasjpfan Dec 29, 2018
4b277b9
DOC: Rewords
thomasjpfan Dec 29, 2018
b83b135
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Dec 29, 2018
ffece26
ENH: Adds support for extra trees
thomasjpfan Dec 29, 2018
b2e2a52
DOC: Updates whats_new
thomasjpfan Dec 29, 2018
e95829f
RFC: Makes prune_tree public
thomasjpfan Dec 29, 2018
c313151
RFC: Less diffs
thomasjpfan Dec 29, 2018
fd5be88
RFC: Moves prune_tree closer to the end of fit
thomasjpfan Jan 1, 2019
2e348db
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Jan 1, 2019
75709a0
BUG: Fix
thomasjpfan Jan 1, 2019
efe9793
BUG: Fix
thomasjpfan Jan 1, 2019
568eb04
RFC: Addresses code review
thomasjpfan Jan 29, 2019
eb28d50
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Jan 29, 2019
847e1f0
RFC: Minimize diffs
thomasjpfan Jan 29, 2019
fa9c83c
RFC: Uses memoryviews
thomasjpfan Jan 29, 2019
81c776e
RFC: Deterministic ordering
thomasjpfan Jan 29, 2019
0d85747
ENH: Returns tree with greatest CCP less than alpha
thomasjpfan Jan 30, 2019
57963d5
RFC: Rename alpha to ccp_alpha
thomasjpfan Jan 31, 2019
25910e0
DOC: Uses ccp_alpha
thomasjpfan Jan 31, 2019
e2cd686
ENH: Uses cython for pruning
thomasjpfan Feb 4, 2019
a43972a
ENH: Adds ccp_alpha to forest
thomasjpfan Feb 5, 2019
39dbccd
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Feb 5, 2019
1a347f8
BUG: Fixes doctest
thomasjpfan Feb 5, 2019
43a656b
ENH: Releases gil
thomasjpfan Feb 5, 2019
e59b662
BUG: Fix
thomasjpfan Feb 5, 2019
6465355
RFC Address comments
thomasjpfan Feb 7, 2019
bcfbfc3
STY Flake8
thomasjpfan Feb 7, 2019
b17433c
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Feb 7, 2019
71d0513
DOC adds raw for math
thomasjpfan Feb 7, 2019
7ee455e
RFC Address comments
thomasjpfan Feb 8, 2019
2e62490
DOC Adds pruning to user guide
thomasjpfan Feb 8, 2019
bba792d
DOC English
thomasjpfan Feb 8, 2019
ded8552
DOC Adds forests to whats_new
thomasjpfan Feb 8, 2019
3623657
ENH Adds pruning to gradient boosting
thomasjpfan Feb 8, 2019
2a3b554
DOC Fixes whats_new
thomasjpfan Feb 11, 2019
b0d76fc
DOC Show plt at the end
thomasjpfan Feb 12, 2019
97229ec
RFC Removes unneeded code
thomasjpfan Feb 13, 2019
013ca9e
STY pep257
thomasjpfan Feb 15, 2019
791077d
TST Adds prune all leaves test
thomasjpfan Feb 16, 2019
0fa13ed
RFC Address comments
thomasjpfan Feb 16, 2019
af54d21
DOC Adds more details
thomasjpfan Feb 27, 2019
88f0011
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Feb 28, 2019
ccd47d1
CLN Address comments
thomasjpfan Mar 12, 2019
ec1b9fc
DOC Fix
thomasjpfan Mar 12, 2019
4a4b2ac
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Apr 18, 2019
8132d2d
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Apr 26, 2019
3e5486d
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Apr 26, 2019
188ccb8
ENH Adds cost complexity pruning path
thomasjpfan Apr 26, 2019
218311f
DOC Adds docstring
thomasjpfan Apr 26, 2019
b8a2769
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan May 6, 2019
697a383
DOC Move whats_new
thomasjpfan May 6, 2019
2de7dfd
ENH Adds impurity tracking to pruning
thomasjpfan May 6, 2019
7452f1f
DOC New example using path function
thomasjpfan May 7, 2019
a199ce8
DOC Adjust titles
thomasjpfan May 7, 2019
dc6b6fd
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan May 7, 2019
45b5cdc
ENH Returns a bunch when calculating path
thomasjpfan May 20, 2019
abf41ca
BUG Uses bunch in tests
thomasjpfan May 21, 2019
8cc77ca
DOC Adds more details in example
thomasjpfan May 21, 2019
971f85a
CLN Adds more comments
thomasjpfan May 21, 2019
e81f2a3
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan May 21, 2019
bc956ca
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan May 31, 2019
7f620a8
DOC Removes last node in all plots
thomasjpfan Jun 3, 2019
d610101
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Jun 3, 2019
b9247fc
DOC Adjust layout
thomasjpfan Jun 4, 2019
cc5f1a9
CLN Address comments
thomasjpfan Jun 6, 2019
5e2ace3
CLN Adds error message to MemoryError
thomasjpfan Jun 17, 2019
5b50196
CLN Adds alpha dependency of t
thomasjpfan Jun 17, 2019
f612457
DOC Update wording
thomasjpfan Jun 17, 2019
86fdbc6
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Jul 17, 2019
dda0f5e
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Jul 30, 2019
40bab1a
CLN Remove file
thomasjpfan Jul 30, 2019
0a06e46
CLN Address comments
thomasjpfan Jul 30, 2019
9bf7d83
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Jul 30, 2019
31e7816
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Aug 16, 2019
7994897
CLN Address NicolasHug's comments
thomasjpfan Aug 16, 2019
2a42e0c
CLN Refactors tests to use pruning_path
thomasjpfan Aug 16, 2019
e8e3967
TST Adds single node tree test
thomasjpfan Aug 16, 2019
17b4112
STY flake8
thomasjpfan Aug 16, 2019
1a8f07e
TST Adds test on impurities from path
thomasjpfan Aug 16, 2019
17d3888
DOC Adds words
thomasjpfan Aug 16, 2019
9b01fc8
DOC Adds words
thomasjpfan Aug 16, 2019
073fd00
Merge remote-tracking branch 'upstream/master' into ccp_prune_tree
thomasjpfan Aug 16, 2019
1774b8c
DOC Better words
thomasjpfan Aug 16, 2019
73cdf1e
DOC Adds docstring to ccp_pruning_path
thomasjpfan Aug 16, 2019
a688f60
DOC Uses new standard
thomasjpfan Aug 16, 2019
82f3aa1
CLN Address joels comments
thomasjpfan Aug 18, 2019
53 changes: 47 additions & 6 deletions doc/modules/tree.rst
@@ -534,16 +534,57 @@ Mean Absolute Error:

where :math:`X_m` is the training data in node :math:`m`


.. _minimal_cost_complexity_pruning:

Minimal Cost-Complexity Pruning
===============================
Review comment (Member): Let's add a link to this section in every docstring for ccp_alpha. Otherwise, there's no way users can know what ccp_alpha really is and how it works when they look e.g. at the RandomForest docstring.


Minimal cost-complexity pruning is an algorithm used to prune a tree to avoid
over-fitting, described in Chapter 3 of [BRE]_. This algorithm is parameterized
by :math:`\alpha\ge0` known as the complexity parameter. The complexity
parameter is used to define the cost-complexity measure, :math:`R_\alpha(T)` of
a given tree :math:`T`:

.. math::

R_\alpha(T) = R(T) + \alpha|T|

where :math:`|T|` is the number of terminal nodes in :math:`T` and :math:`R(T)`
is traditionally defined as the total misclassification rate of the terminal
nodes. Alternatively, scikit-learn uses the total sample weighted impurity of
the terminal nodes for :math:`R(T)`. As shown above, the impurity of a node
depends on the criterion. Minimal cost-complexity pruning finds the subtree of
:math:`T` that minimizes :math:`R_\alpha(T)`.
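
As a quick numeric illustration (values chosen arbitrarily, not taken from the source): a tree with total leaf impurity :math:`R(T)=0.05` and :math:`|T|=10` terminal nodes, evaluated at :math:`\alpha=0.01`, has

.. math::

    R_\alpha(T) = 0.05 + 0.01 \times 10 = 0.15

so keeping extra leaves only pays off if they lower :math:`R(T)` by more than :math:`\alpha` per leaf.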

The cost complexity measure of a single node is
:math:`R_\alpha(t)=R(t)+\alpha`. The branch, :math:`T_t`, is defined to be a
tree where node :math:`t` is its root. In general, the impurity of a node
is greater than the sum of impurities of its terminal nodes,
:math:`R(T_t)<R(t)`. However, the cost complexity measure of a node,
:math:`t`, and its branch, :math:`T_t`, can be equal depending on
:math:`\alpha`. We define the effective :math:`\alpha` of a node to be the
value where they are equal, :math:`R_\alpha(T_t)=R_\alpha(t)` or
:math:`\alpha_{eff}(t)=\frac{R(t)-R(T_t)}{|T_t|-1}`. A non-terminal node
with the smallest value of :math:`\alpha_{eff}` is the weakest link and will
be pruned. This process stops when the pruned tree's minimal
:math:`\alpha_{eff}` is greater than the ``ccp_alpha`` parameter.
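
A minimal sketch of the workflow this enables (the dataset and printout are illustrative, not part of the patch): ``cost_complexity_pruning_path`` exposes the effective alphas, and refitting with ``ccp_alpha`` set to each one reproduces the corresponding pruned subtree.

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# Effective alphas and total leaf impurities at each pruning step.
path = clf.cost_complexity_pruning_path(X, y)

for alpha in path.ccp_alphas:
    # Every node whose effective alpha is <= ccp_alpha is pruned, so each
    # value on the path yields the next smaller subtree.
    pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(X, y)
    print("alpha={:.4f}  leaves={}".format(alpha, pruned.get_n_leaves()))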

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_tree_plot_cost_complexity_pruning.py`

.. topic:: References:
Review comment (Member): is there a reference for the pruning to add here?

Review comment (Member): and here

Reply (Member Author): The pruning comes from L. Breiman, J. Friedman, R. Olshen, and C. Stone. Classification and Regression Trees (Chapter 3).

Review comment (Member): Please link to it directly in the section. There are many references, and it's not obvious which one is used for what.


.. [BRE] L. Breiman, J. Friedman, R. Olshen, and C. Stone. Classification
and Regression Trees. Wadsworth, Belmont, CA, 1984.

* https://en.wikipedia.org/wiki/Decision_tree_learning

* https://en.wikipedia.org/wiki/Predictive_analytics

* J.R. Quinlan. C4.5: programs for machine learning. Morgan
  Kaufmann, 1993.

* T. Hastie, R. Tibshirani and J. Friedman. Elements of Statistical
  Learning, Springer, 2009.
15 changes: 15 additions & 0 deletions doc/whats_new/v0.22.rst
@@ -281,6 +281,21 @@ Changelog
- |Enhancement| SVM now throws more specific error when fit on non-square data
and kernel = precomputed. :class:`svm.BaseLibSVM`
:pr:`14336` by :user:`Gregory Dexter <gdex1>`.

:mod:`sklearn.tree`
...................

- |Feature| Adds minimal cost complexity pruning, controlled by ``ccp_alpha``,
to :class:`tree.DecisionTreeClassifier`, :class:`tree.DecisionTreeRegressor`,
:class:`tree.ExtraTreeClassifier`, :class:`tree.ExtraTreeRegressor`,
:class:`ensemble.RandomForestClassifier`,
:class:`ensemble.RandomForestRegressor`,
:class:`ensemble.ExtraTreesClassifier`,
:class:`ensemble.ExtraTreesRegressor`,
:class:`ensemble.RandomTreesEmbedding`,
:class:`ensemble.GradientBoostingClassifier`,
and :class:`ensemble.GradientBoostingRegressor`.
:pr:`12887` by `Thomas Fan`_.
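
A minimal usage sketch (the dataset and alpha value are illustrative, not from the changelog): the parameter is accepted by the tree-based ensembles as well, and pruning can only shrink each fitted tree.

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

X, y = load_breast_cancer(return_X_y=True)

# Same forest, with and without cost-complexity pruning of each tree.
unpruned = RandomForestClassifier(n_estimators=20, random_state=0).fit(X, y)
pruned = RandomForestClassifier(n_estimators=20, random_state=0,
                                ccp_alpha=0.01).fit(X, y)

# Each pruned tree has at most as many nodes as its unpruned counterpart.
print(sum(est.tree_.node_count for est in unpruned.estimators_))
print(sum(est.tree_.node_count for est in pruned.estimators_))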

:mod:`sklearn.preprocessing`
............................
106 changes: 106 additions & 0 deletions examples/tree/plot_cost_complexity_pruning.py
@@ -0,0 +1,106 @@
"""
========================================================
Post pruning decision trees with cost complexity pruning
========================================================

.. currentmodule:: sklearn.tree

The :class:`DecisionTreeClassifier` provides parameters such as
``min_samples_leaf`` and ``max_depth`` to prevent a tree from overfitting. Cost
complexity pruning provides another option to control the size of a tree. In
:class:`DecisionTreeClassifier`, this pruning technique is parameterized by the
cost complexity parameter, ``ccp_alpha``. Greater values of ``ccp_alpha``
increase the number of nodes pruned. Here we only show the effect of
``ccp_alpha`` on regularizing the trees and how to choose a ``ccp_alpha``
based on validation scores.

See also :ref:`minimal_cost_complexity_pruning` for details on pruning.
"""

print(__doc__)
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

###############################################################################
# Total impurity of leaves vs effective alphas of pruned tree
# ---------------------------------------------------------------
# Minimal cost complexity pruning recursively finds the node with the "weakest
# link". The weakest link is characterized by an effective alpha, where the
# nodes with the smallest effective alpha are pruned first. To get an idea of
# what values of ``ccp_alpha`` could be appropriate, scikit-learn provides
# :func:`DecisionTreeClassifier.cost_complexity_pruning_path` that returns the
# effective alphas and the corresponding total leaf impurities at each step of
# the pruning process. As alpha increases, more of the tree is pruned, which
# increases the total impurity of its leaves.
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = DecisionTreeClassifier(random_state=0)
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities

###############################################################################
# In the following plot, the maximum effective alpha value is removed, because
# it is the trivial tree with only one node.
fig, ax = plt.subplots()
ax.plot(ccp_alphas[:-1], impurities[:-1], marker='o', drawstyle="steps-post")
ax.set_xlabel("effective alpha")
ax.set_ylabel("total impurity of leaves")
ax.set_title("Total Impurity vs effective alpha for training set")

###############################################################################
# Next, we train a decision tree using the effective alphas. The last value
# in ``ccp_alphas`` is the alpha value that prunes the whole tree,
# leaving the tree, ``clfs[-1]``, with one node.
clfs = []
for ccp_alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)
print("Number of nodes in the last tree is: {} with ccp_alpha: {}".format(
      clfs[-1].tree_.node_count, ccp_alphas[-1]))

###############################################################################
# For the remainder of this example, we remove the last element in
# ``clfs`` and ``ccp_alphas``, because it is the trivial tree with only one
# node. Here we show that the number of nodes and tree depth decreases as alpha
# increases.
clfs = clfs[:-1]
ccp_alphas = ccp_alphas[:-1]

node_counts = [clf.tree_.node_count for clf in clfs]
depth = [clf.tree_.max_depth for clf in clfs]
fig, ax = plt.subplots(2, 1)
ax[0].plot(ccp_alphas, node_counts, marker='o', drawstyle="steps-post")
ax[0].set_xlabel("alpha")
ax[0].set_ylabel("number of nodes")
ax[0].set_title("Number of nodes vs alpha")
ax[1].plot(ccp_alphas, depth, marker='o', drawstyle="steps-post")
ax[1].set_xlabel("alpha")
ax[1].set_ylabel("depth of tree")
ax[1].set_title("Depth vs alpha")
fig.tight_layout()

###############################################################################
# Accuracy vs alpha for training and testing sets
# ----------------------------------------------------
# When ``ccp_alpha`` is set to zero and the other parameters of
# :class:`DecisionTreeClassifier` are kept at their defaults, the tree
# overfits, leading to 100% training accuracy and 88% testing accuracy. As
# alpha increases, more
# of the tree is pruned, thus creating a decision tree that generalizes better.
# In this example, setting ``ccp_alpha=0.015`` maximizes the testing accuracy.
train_scores = [clf.score(X_train, y_train) for clf in clfs]
test_scores = [clf.score(X_test, y_test) for clf in clfs]

fig, ax = plt.subplots()
ax.set_xlabel("alpha")
ax.set_ylabel("accuracy")
ax.set_title("Accuracy vs alpha for training and testing sets")
ax.plot(ccp_alphas, train_scores, marker='o', label="train",
        drawstyle="steps-post")
ax.plot(ccp_alphas, test_scores, marker='o', label="test",
        drawstyle="steps-post")
ax.legend()
plt.show()