diff --git a/doc/modules/tree.rst b/doc/modules/tree.rst index 1eeabb1e3afc8..2916e4b36f306 100644 --- a/doc/modules/tree.rst +++ b/doc/modules/tree.rst @@ -534,16 +534,57 @@ Mean Absolute Error: where :math:`X_m` is the training data in node :math:`m` + +.. _minimal_cost_complexity_pruning: + +Minimal Cost-Complexity Pruning +=============================== + +Minimal cost-complexity pruning is an algorithm used to prune a tree to avoid +over-fitting, described in Chapter 3 of [BRE]_. This algorithm is parameterized +by :math:`\alpha\ge0` known as the complexity parameter. The complexity +parameter is used to define the cost-complexity measure, :math:`R_\alpha(T)` of +a given tree :math:`T`: + +.. math:: + + R_\alpha(T) = R(T) + \alpha|T| + +where :math:`|T|` is the number of terminal nodes in :math:`T` and :math:`R(T)` +is traditionally defined as the total misclassification rate of the terminal +nodes. Alternatively, scikit-learn uses the total sample weighted impurity of +the terminal nodes for :math:`R(T)`. As shown above, the impurity of a node +depends on the criterion. Minimal cost-complexity pruning finds the subtree of +:math:`T` that minimizes :math:`R_\alpha(T)`. + +The cost complexity measure of a single node is +:math:`R_\alpha(t)=R(t)+\alpha`. The branch, :math:`T_t`, is defined to be a +tree where node :math:`t` is its root. In general, the impurity of a node +is greater than the sum of impurities of its terminal nodes, +:math:`R(T_t) < R(t)`. + +:mod:`sklearn.tree` +................... + +- |Feature| Adds minimal cost complexity pruning, controlled by ``ccp_alpha``, + to :class:`tree.DecisionTreeClassifier`, :class:`tree.DecisionTreeRegressor`, + :class:`tree.ExtraTreeClassifier`, :class:`tree.ExtraTreeRegressor`, + :class:`ensemble.RandomForestClassifier`, + :class:`ensemble.RandomForestRegressor`, + :class:`ensemble.ExtraTreesClassifier`, + :class:`ensemble.ExtraTreesRegressor`, + :class:`ensemble.RandomTreesEmbedding`, + :class:`ensemble.GradientBoostingClassifier`, + and :class:`ensemble.GradientBoostingRegressor`. + :pr:`12887` by `Thomas Fan`_. :mod:`sklearn.preprocessing` ............................ diff --git a/examples/tree/plot_cost_complexity_pruning.py b/examples/tree/plot_cost_complexity_pruning.py new file mode 100644 index 0000000000000..1a06ac3d18adc --- /dev/null +++ b/examples/tree/plot_cost_complexity_pruning.py @@ -0,0 +1,106 @@ +""" +======================================================== +Post pruning decision trees with cost complexity pruning +======================================================== + +.. currentmodule:: sklearn.tree + +The :class:`DecisionTreeClassifier` provides parameters such as +``min_samples_leaf`` and ``max_depth`` to prevent a tree from overfitting. Cost +complexity pruning provides another option to control the size of a tree. In +:class:`DecisionTreeClassifier`, this pruning technique is parameterized by the +cost complexity parameter, ``ccp_alpha``. Greater values of ``ccp_alpha`` +increase the number of nodes pruned. Here we only show the effect of +``ccp_alpha`` on regularizing the trees and how to choose a ``ccp_alpha`` +based on validation scores. + +See also :ref:`minimal_cost_complexity_pruning` for details on pruning.
+""" + +print(__doc__) +import matplotlib.pyplot as plt +from sklearn.model_selection import train_test_split +from sklearn.datasets import load_breast_cancer +from sklearn.tree import DecisionTreeClassifier + +############################################################################### +# Total impurity of leaves vs effective alphas of pruned tree +# --------------------------------------------------------------- +# Minimal cost complexity pruning recursively finds the node with the "weakest +# link". The weakest link is characterized by an effective alpha, where the +# nodes with the smallest effective alpha are pruned first. To get an idea of +# what values of ``ccp_alpha`` could be appropriate, scikit-learn provides +# :func:`DecisionTreeClassifier.cost_complexity_pruning_path` that returns the +# effective alphas and the corresponding total leaf impurities at each step of +# the pruning process. As alpha increases, more of the tree is pruned, which +# increases the total impurity of its leaves. +X, y = load_breast_cancer(return_X_y=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + +clf = DecisionTreeClassifier(random_state=0) +path = clf.cost_complexity_pruning_path(X_train, y_train) +ccp_alphas, impurities = path.ccp_alphas, path.impurities + +############################################################################### +# In the following plot, the maximum effective alpha value is removed, because +# it is the trivial tree with only one node. +fig, ax = plt.subplots() +ax.plot(ccp_alphas[:-1], impurities[:-1], marker='o', drawstyle="steps-post") +ax.set_xlabel("effective alpha") +ax.set_ylabel("total impurity of leaves") +ax.set_title("Total Impurity vs effective alpha for training set") + +############################################################################### +# Next, we train a decision tree using the effective alphas. The last value +# in ``ccp_alphas`` is the alpha value that prunes the whole tree, +# leaving the tree, ``clfs[-1]``, with one node. +clfs = [] +for ccp_alpha in ccp_alphas: + clf = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha) + clf.fit(X_train, y_train) + clfs.append(clf) +print("Number of nodes in the last tree is: {} with ccp_alpha: {}".format( + clfs[-1].tree_.node_count, ccp_alphas[-1])) + +############################################################################### +# For the remainder of this example, we remove the last element in +# ``clfs`` and ``ccp_alphas``, because it is the trivial tree with only one +# node. Here we show that the number of nodes and tree depth decreases as alpha +# increases. 
+clfs = clfs[:-1] +ccp_alphas = ccp_alphas[:-1] + +node_counts = [clf.tree_.node_count for clf in clfs] +depth = [clf.tree_.max_depth for clf in clfs] +fig, ax = plt.subplots(2, 1) +ax[0].plot(ccp_alphas, node_counts, marker='o', drawstyle="steps-post") +ax[0].set_xlabel("alpha") +ax[0].set_ylabel("number of nodes") +ax[0].set_title("Number of nodes vs alpha") +ax[1].plot(ccp_alphas, depth, marker='o', drawstyle="steps-post") +ax[1].set_xlabel("alpha") +ax[1].set_ylabel("depth of tree") +ax[1].set_title("Depth vs alpha") +fig.tight_layout() + +############################################################################### +# Accuracy vs alpha for training and testing sets +# ---------------------------------------------------- +# When ``ccp_alpha`` is set to zero and keeping the other default parameters +# of :class:`DecisionTreeClassifier`, the tree overfits, leading to +# a 100% training accuracy and 88% testing accuracy. As alpha increases, more +# of the tree is pruned, thus creating a decision tree that generalizes better. +# In this example, setting ``ccp_alpha=0.015`` maximizes the testing accuracy. +train_scores = [clf.score(X_train, y_train) for clf in clfs] +test_scores = [clf.score(X_test, y_test) for clf in clfs] + +fig, ax = plt.subplots() +ax.set_xlabel("alpha") +ax.set_ylabel("accuracy") +ax.set_title("Accuracy vs alpha for training and testing sets") +ax.plot(ccp_alphas, train_scores, marker='o', label="train", + drawstyle="steps-post") +ax.plot(ccp_alphas, test_scores, marker='o', label="test", + drawstyle="steps-post") +ax.legend() +plt.show() diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 911b16b67df5f..c5d3184187807 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -915,6 +915,14 @@ class RandomForestClassifier(ForestClassifier): Note that these weights will be multiplied with sample_weight (passed through the fit method) if sample_weight is specified. + ccp_alpha : non-negative float, optional (default=0.0) + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. versionadded:: 0.22 + Attributes ---------- base_estimator_ : DecisionTreeClassifier @@ -1007,7 +1015,8 @@ def __init__(self, random_state=None, verbose=0, warm_start=False, - class_weight=None): + class_weight=None, + ccp_alpha=0.0): super().__init__( base_estimator=DecisionTreeClassifier(), n_estimators=n_estimators, @@ -1015,7 +1024,7 @@ def __init__(self, "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "min_impurity_split", - "random_state"), + "random_state", "ccp_alpha"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1033,6 +1042,7 @@ def __init__(self, self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split + self.ccp_alpha = ccp_alpha class RandomForestRegressor(ForestRegressor): @@ -1180,6 +1190,14 @@ class RandomForestRegressor(ForestRegressor): and add more estimators to the ensemble, otherwise, just fit a whole new forest. See :term:`the Glossary `. + ccp_alpha : non-negative float, optional (default=0.0) + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. 
By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. versionadded:: 0.22 + Attributes ---------- base_estimator_ : DecisionTreeRegressor @@ -1266,7 +1284,8 @@ def __init__(self, n_jobs=None, random_state=None, verbose=0, - warm_start=False): + warm_start=False, + ccp_alpha=0.0): super().__init__( base_estimator=DecisionTreeRegressor(), n_estimators=n_estimators, @@ -1274,7 +1293,7 @@ def __init__(self, "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "min_impurity_split", - "random_state"), + "random_state", "ccp_alpha"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1291,6 +1310,7 @@ def __init__(self, self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split + self.ccp_alpha = ccp_alpha class ExtraTreesClassifier(ForestClassifier): @@ -1456,6 +1476,14 @@ class ExtraTreesClassifier(ForestClassifier): Note that these weights will be multiplied with sample_weight (passed through the fit method) if sample_weight is specified. + ccp_alpha : non-negative float, optional (default=0.0) + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. versionadded:: 0.22 + Attributes ---------- base_estimator_ : ExtraTreeClassifier @@ -1528,7 +1556,8 @@ def __init__(self, random_state=None, verbose=0, warm_start=False, - class_weight=None): + class_weight=None, + ccp_alpha=0.0): super().__init__( base_estimator=ExtraTreeClassifier(), n_estimators=n_estimators, @@ -1536,7 +1565,7 @@ def __init__(self, "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "min_impurity_split", - "random_state"), + "random_state", "ccp_alpha"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1554,6 +1583,7 @@ def __init__(self, self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split + self.ccp_alpha = ccp_alpha class ExtraTreesRegressor(ForestRegressor): @@ -1698,6 +1728,14 @@ class ExtraTreesRegressor(ForestRegressor): and add more estimators to the ensemble, otherwise, just fit a whole new forest. See :term:`the Glossary `. + ccp_alpha : non-negative float, optional (default=0.0) + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. 
versionadded:: 0.22 + Attributes ---------- base_estimator_ : ExtraTreeRegressor @@ -1757,7 +1795,8 @@ def __init__(self, n_jobs=None, random_state=None, verbose=0, - warm_start=False): + warm_start=False, + ccp_alpha=0.0): super().__init__( base_estimator=ExtraTreeRegressor(), n_estimators=n_estimators, @@ -1765,7 +1804,7 @@ def __init__(self, "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "min_impurity_split", - "random_state"), + "random_state", "ccp_alpha"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1782,6 +1821,7 @@ def __init__(self, self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split + self.ccp_alpha = ccp_alpha class RandomTreesEmbedding(BaseForest): @@ -1903,6 +1943,14 @@ class RandomTreesEmbedding(BaseForest): and add more estimators to the ensemble, otherwise, just fit a whole new forest. See :term:`the Glossary `. + ccp_alpha : non-negative float, optional (default=0.0) + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. versionadded:: 0.22 + Attributes ---------- estimators_ : list of DecisionTreeClassifier @@ -1934,7 +1982,8 @@ def __init__(self, n_jobs=None, random_state=None, verbose=0, - warm_start=False): + warm_start=False, + ccp_alpha=0.0): super().__init__( base_estimator=ExtraTreeRegressor(), n_estimators=n_estimators, @@ -1942,7 +1991,7 @@ def __init__(self, "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "min_impurity_split", - "random_state"), + "random_state", "ccp_alpha"), bootstrap=False, oob_score=False, n_jobs=n_jobs, @@ -1958,6 +2007,7 @@ def __init__(self, self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split self.sparse_output = sparse_output + self.ccp_alpha = ccp_alpha def _set_oob_score(self, X, y): raise NotImplementedError("OOB score not supported by tree embedding") diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index 43c4dae31f66e..ec5f9a111ccf1 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -1170,7 +1170,7 @@ class BaseGradientBoosting(BaseEnsemble, metaclass=ABCMeta): def __init__(self, loss, learning_rate, n_estimators, criterion, min_samples_split, min_samples_leaf, min_weight_fraction_leaf, max_depth, min_impurity_decrease, min_impurity_split, - init, subsample, max_features, + init, subsample, max_features, ccp_alpha, random_state, alpha=0.9, verbose=0, max_leaf_nodes=None, warm_start=False, presort='auto', validation_fraction=0.1, n_iter_no_change=None, @@ -1188,6 +1188,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion, self.max_depth = max_depth self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split + self.ccp_alpha = ccp_alpha self.init = init self.random_state = random_state self.alpha = alpha @@ -1233,7 +1234,8 @@ def _fit_stage(self, i, X, y, raw_predictions, sample_weight, sample_mask, max_features=self.max_features, max_leaf_nodes=self.max_leaf_nodes, random_state=random_state, - presort=self.presort) + presort=self.presort, + ccp_alpha=self.ccp_alpha) if self.subsample < 1.0: # no inplace multiplication! 
@@ -1999,6 +2001,14 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin): .. versionadded:: 0.20 + ccp_alpha : non-negative float, optional (default=0.0) + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. versionadded:: 0.22 + Attributes ---------- n_estimators_ : int @@ -2073,7 +2083,7 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100, random_state=None, max_features=None, verbose=0, max_leaf_nodes=None, warm_start=False, presort='auto', validation_fraction=0.1, - n_iter_no_change=None, tol=1e-4): + n_iter_no_change=None, tol=1e-4, ccp_alpha=0.0): super().__init__( loss=loss, learning_rate=learning_rate, n_estimators=n_estimators, @@ -2088,7 +2098,7 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100, min_impurity_split=min_impurity_split, warm_start=warm_start, presort=presort, validation_fraction=validation_fraction, - n_iter_no_change=n_iter_no_change, tol=tol) + n_iter_no_change=n_iter_no_change, tol=tol, ccp_alpha=ccp_alpha) def _validate_y(self, y, sample_weight): check_classification_targets(y) @@ -2471,6 +2481,13 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin): .. versionadded:: 0.20 + ccp_alpha : non-negative float, optional (default=0.0) + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. versionadded:: 0.22 Attributes ---------- @@ -2532,7 +2549,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100, min_impurity_split=None, init=None, random_state=None, max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None, warm_start=False, presort='auto', validation_fraction=0.1, - n_iter_no_change=None, tol=1e-4): + n_iter_no_change=None, tol=1e-4, ccp_alpha=0.0): super().__init__( loss=loss, learning_rate=learning_rate, n_estimators=n_estimators, @@ -2546,7 +2563,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100, random_state=random_state, alpha=alpha, verbose=verbose, max_leaf_nodes=max_leaf_nodes, warm_start=warm_start, presort=presort, validation_fraction=validation_fraction, - n_iter_no_change=n_iter_no_change, tol=tol) + n_iter_no_change=n_iter_no_change, tol=tol, ccp_alpha=ccp_alpha) def predict(self, X): """Predict regression target for X. diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 2fecf36c72da7..dd90611716f06 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -1125,7 +1125,6 @@ cdef class Tree: arr.base = self return arr - def compute_partial_dependence(self, DTYPE_t[:, ::1] X, int[::1] target_features, double[::1] out): @@ -1231,3 +1230,421 @@ cdef class Tree: if not (0.999 < total_weight < 1.001): raise ValueError("Total weight should be 1.0 but was %.9f" % total_weight) + + +# ============================================================================= +# Build Pruned Tree +# ============================================================================= + + +cdef class _CCPPruneController: + """Base class used by build_pruned_tree_ccp and ccp_pruning_path + to control pruning. 
+ """ + cdef bint stop_pruning(self, DOUBLE_t effective_alpha) nogil: + """Return 1 to stop pruning and 0 to continue pruning""" + return 0 + + cdef void save_metrics(self, DOUBLE_t effective_alpha, + DOUBLE_t subtree_impurities) nogil: + """Save metrics when pruning""" + pass + + cdef void after_pruning(self, unsigned char[:] in_subtree) nogil: + """Called after pruning""" + pass + + +cdef class _AlphaPruner(_CCPPruneController): + """Use alpha to control when to stop pruning.""" + cdef DOUBLE_t ccp_alpha + cdef SIZE_t capacity + + def __cinit__(self, DOUBLE_t ccp_alpha): + self.ccp_alpha = ccp_alpha + self.capacity = 0 + + cdef bint stop_pruning(self, DOUBLE_t effective_alpha) nogil: + # The subtree on the previous iteration has the greatest ccp_alpha + # less than or equal to self.ccp_alpha + return self.ccp_alpha < effective_alpha + + cdef void after_pruning(self, unsigned char[:] in_subtree) nogil: + """Updates the number of leaves in subtree""" + for i in range(in_subtree.shape[0]): + if in_subtree[i]: + self.capacity += 1 + + +cdef class _PathFinder(_CCPPruneController): + """Record metrics used to return the cost complexity path.""" + cdef DOUBLE_t[:] ccp_alphas + cdef DOUBLE_t[:] impurities + cdef UINT32_t count + + def __cinit__(self, int node_count): + self.ccp_alphas = np.zeros(shape=(node_count), dtype=np.float64) + self.impurities = np.zeros(shape=(node_count), dtype=np.float64) + self.count = 0 + + cdef void save_metrics(self, + DOUBLE_t effective_alpha, + DOUBLE_t subtree_impurities) nogil: + self.ccp_alphas[self.count] = effective_alpha + self.impurities[self.count] = subtree_impurities + self.count += 1 + + +cdef _cost_complexity_prune(unsigned char[:] leaves_in_subtree, # OUT + Tree orig_tree, + _CCPPruneController controller): + """Perform cost complexity pruning. + + This function takes an already grown tree, `orig_tree`, and outputs a + boolean mask `leaves_in_subtree` which marks the leaves of the pruned tree. + The controller signals when the pruning should stop and is passed the + metrics of the subtrees during the pruning process.
+ + Parameters + ---------- + leaves_in_subtree : unsigned char[:] + Output for leaves of subtree + orig_tree : Tree + Original tree + ccp_controller : _CCPPruneController + Cost complexity controller + """ + + cdef: + SIZE_t i + SIZE_t n_nodes = orig_tree.node_count + # prior probability using weighted samples + DOUBLE_t[:] weighted_n_node_samples = orig_tree.weighted_n_node_samples + DOUBLE_t total_sum_weights = weighted_n_node_samples[0] + DOUBLE_t[:] impurity = orig_tree.impurity + # weighted impurity of each node + DOUBLE_t[:] r_node = np.empty(shape=n_nodes, dtype=np.float64) + + SIZE_t[:] child_l = orig_tree.children_left + SIZE_t[:] child_r = orig_tree.children_right + SIZE_t[:] parent = np.zeros(shape=n_nodes, dtype=np.intp) + + # Only uses the start and parent variables + Stack stack = Stack(INITIAL_STACK_SIZE) + StackRecord stack_record + int rc = 0 + SIZE_t node_idx + + SIZE_t[:] n_leaves = np.zeros(shape=n_nodes, dtype=np.intp) + DOUBLE_t[:] r_branch = np.zeros(shape=n_nodes, dtype=np.float64) + DOUBLE_t current_r + SIZE_t leaf_idx + SIZE_t parent_idx + + # candidate nodes that can be pruned + unsigned char[:] candidate_nodes = np.zeros(shape=n_nodes, + dtype=np.uint8) + # nodes in subtree + unsigned char[:] in_subtree = np.ones(shape=n_nodes, dtype=np.uint8) + DOUBLE_t[:] g_node = np.zeros(shape=n_nodes, dtype=np.float64) + SIZE_t pruned_branch_node_idx + DOUBLE_t subtree_alpha + DOUBLE_t effective_alpha + SIZE_t child_l_idx + SIZE_t child_r_idx + SIZE_t n_pruned_leaves + DOUBLE_t r_diff + DOUBLE_t max_float64 = np.finfo(np.float64).max + + # find parent node ids and leaves + with nogil: + + for i in range(r_node.shape[0]): + r_node[i] = ( + weighted_n_node_samples[i] * impurity[i] / total_sum_weights) + + # Push root node, using StackRecord.start as node id + rc = stack.push(0, 0, 0, -1, 0, 0, 0) + if rc == -1: + with gil: + raise MemoryError("pruning tree") + + while not stack.is_empty(): + stack.pop(&stack_record) + node_idx = stack_record.start + parent[node_idx] = stack_record.parent + if child_l[node_idx] == _TREE_LEAF: + # ... and child_r[node_idx] == _TREE_LEAF: + leaves_in_subtree[node_idx] = 1 + else: + rc = stack.push(child_l[node_idx], 0, 0, node_idx, 0, 0, 0) + if rc == -1: + with gil: + raise MemoryError("pruning tree") + + rc = stack.push(child_r[node_idx], 0, 0, node_idx, 0, 0, 0) + if rc == -1: + with gil: + raise MemoryError("pruning tree") + + # computes number of leaves in all branches and the overall impurity of + # the branch. The overall impurity is the sum of r_node in its leaves. 
+ for leaf_idx in range(leaves_in_subtree.shape[0]): + if not leaves_in_subtree[leaf_idx]: + continue + r_branch[leaf_idx] = r_node[leaf_idx] + + # bubble up values to ancestor nodes + current_r = r_node[leaf_idx] + while leaf_idx != 0: + parent_idx = parent[leaf_idx] + r_branch[parent_idx] += current_r + n_leaves[parent_idx] += 1 + leaf_idx = parent_idx + + for i in range(leaves_in_subtree.shape[0]): + candidate_nodes[i] = not leaves_in_subtree[i] + + # save metrics before pruning + controller.save_metrics(0.0, r_branch[0]) + + # while root node is not a leaf + while candidate_nodes[0]: + + # computes ccp_alpha for subtrees and finds the minimal alpha + effective_alpha = max_float64 + for i in range(n_nodes): + if not candidate_nodes[i]: + continue + subtree_alpha = (r_node[i] - r_branch[i]) / (n_leaves[i] - 1) + if subtree_alpha < effective_alpha: + effective_alpha = subtree_alpha + pruned_branch_node_idx = i + + if controller.stop_pruning(effective_alpha): + break + + # stack uses only the start variable + rc = stack.push(pruned_branch_node_idx, 0, 0, 0, 0, 0, 0) + if rc == -1: + with gil: + raise MemoryError("pruning tree") + + # descendants of branch are not in subtree + while not stack.is_empty(): + stack.pop(&stack_record) + node_idx = stack_record.start + + if not in_subtree[node_idx]: + continue # branch has already been marked for pruning + candidate_nodes[node_idx] = 0 + leaves_in_subtree[node_idx] = 0 + in_subtree[node_idx] = 0 + + if child_l[node_idx] != _TREE_LEAF: + # ... and child_r[node_idx] != _TREE_LEAF: + rc = stack.push(child_l[node_idx], 0, 0, 0, 0, 0, 0) + if rc == -1: + with gil: + raise MemoryError("pruning tree") + rc = stack.push(child_r[node_idx], 0, 0, 0, 0, 0, 0) + if rc == -1: + with gil: + raise MemoryError("pruning tree") + leaves_in_subtree[pruned_branch_node_idx] = 1 + in_subtree[pruned_branch_node_idx] = 1 + + # updates number of leaves + n_pruned_leaves = n_leaves[pruned_branch_node_idx] - 1 + n_leaves[pruned_branch_node_idx] = 0 + + # computes the increase in r_branch to bubble up + r_diff = r_node[pruned_branch_node_idx] - r_branch[pruned_branch_node_idx] + r_branch[pruned_branch_node_idx] = r_node[pruned_branch_node_idx] + + # bubble up values to ancestors + node_idx = parent[pruned_branch_node_idx] + while node_idx != -1: + n_leaves[node_idx] -= n_pruned_leaves + r_branch[node_idx] += r_diff + node_idx = parent[node_idx] + + controller.save_metrics(effective_alpha, r_branch[0]) + + controller.after_pruning(in_subtree) + + +def _build_pruned_tree_ccp( + Tree tree, # OUT + Tree orig_tree, + DOUBLE_t ccp_alpha): + """Build a pruned tree from the original tree using cost complexity + pruning. + + The values and nodes from the original tree are copied into the pruned + tree. + + Parameters + ---------- + tree : Tree + Location to place the pruned tree + orig_tree : Tree + Original tree + ccp_alpha : positive double + Complexity parameter. The subtree with the largest cost complexity + that is smaller than ``ccp_alpha`` will be chosen. By default, + no pruning is performed. + """ + + cdef: + SIZE_t n_nodes = orig_tree.node_count + unsigned char[:] leaves_in_subtree = np.zeros( + shape=n_nodes, dtype=np.uint8) + + pruning_controller = _AlphaPruner(ccp_alpha=ccp_alpha) + + _cost_complexity_prune(leaves_in_subtree, orig_tree, pruning_controller) + + _build_pruned_tree(tree, orig_tree, leaves_in_subtree, + pruning_controller.capacity) + + +def ccp_pruning_path(Tree orig_tree): + """Computes the cost complexity pruning path. 
+ + Parameters + ---------- + tree : Tree + Original tree. + + Returns + ------- + path_info : dict + Information about pruning path with attributes: + + ccp_alphas : ndarray + Effective alphas of subtree during pruning. + + impurities : ndarray + Sum of the impurities of the subtree leaves for the + corresponding alpha value in ``ccp_alphas``. + """ + cdef: + unsigned char[:] leaves_in_subtree = np.zeros( + shape=orig_tree.node_count, dtype=np.uint8) + + path_finder = _PathFinder(orig_tree.node_count) + + _cost_complexity_prune(leaves_in_subtree, orig_tree, path_finder) + + cdef: + UINT32_t total_items = path_finder.count + np.ndarray ccp_alphas = np.empty(shape=total_items, + dtype=np.float64) + np.ndarray impurities = np.empty(shape=total_items, + dtype=np.float64) + UINT32_t count = 0 + + while count < total_items: + ccp_alphas[count] = path_finder.ccp_alphas[count] + impurities[count] = path_finder.impurities[count] + count += 1 + + return {'ccp_alphas': ccp_alphas, 'impurities': impurities} + + +cdef _build_pruned_tree( + Tree tree, # OUT + Tree orig_tree, + const unsigned char[:] leaves_in_subtree, + SIZE_t capacity): + """Build a pruned tree. + + Build a pruned tree from the original tree by transforming the nodes in + ``leaves_in_subtree`` into leaves. + + Parameters + ---------- + tree : Tree + Location to place the pruned tree + orig_tree : Tree + Original tree + leaves_in_subtree : unsigned char memoryview, shape=(node_count, ) + Boolean mask for leaves to include in subtree + capacity : SIZE_t + Number of nodes to initially allocate in pruned tree + """ + tree._resize(capacity) + + cdef: + SIZE_t orig_node_id + SIZE_t new_node_id + SIZE_t depth + SIZE_t parent + bint is_left + bint is_leaf + + # value_stride for original tree and new tree are the same + SIZE_t value_stride = orig_tree.value_stride + SIZE_t max_depth_seen = -1 + int rc = 0 + Node* node + double* orig_value_ptr + double* new_value_ptr + + # Only uses the start, depth, parent, and is_left variables + Stack stack = Stack(INITIAL_STACK_SIZE) + StackRecord stack_record + + with nogil: + # push root node onto stack + rc = stack.push(0, 0, 0, _TREE_UNDEFINED, 0, 0.0, 0) + if rc == -1: + with gil: + raise MemoryError("pruning tree") + + while not stack.is_empty(): + stack.pop(&stack_record) + + orig_node_id = stack_record.start + depth = stack_record.depth + parent = stack_record.parent + is_left = stack_record.is_left + + is_leaf = leaves_in_subtree[orig_node_id] + node = &orig_tree.nodes[orig_node_id] + + new_node_id = tree._add_node( + parent, is_left, is_leaf, node.feature, node.threshold, + node.impurity, node.n_node_samples, + node.weighted_n_node_samples) + + if new_node_id == (-1): + rc = -1 + break + + # copy value from original tree to new tree + orig_value_ptr = orig_tree.value + value_stride * orig_node_id + new_value_ptr = tree.value + value_stride * new_node_id + memcpy(new_value_ptr, orig_value_ptr, sizeof(double) * value_stride) + + if not is_leaf: + # Push right child on stack + rc = stack.push( + node.right_child, 0, depth + 1, new_node_id, 0, 0.0, 0) + if rc == -1: + break + + # push left child on stack + rc = stack.push( + node.left_child, 0, depth + 1, new_node_id, 1, 0.0, 0) + if rc == -1: + break + + if depth > max_depth_seen: + max_depth_seen = depth + + if rc >= 0: + tree.max_depth = max_depth_seen + if rc == -1: + raise MemoryError("pruning tree") diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 12b424b9bf3b7..01b07e76345d4 100644 --- 
a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -39,7 +39,7 @@ from sklearn.tree import ExtraTreeRegressor from sklearn import tree -from sklearn.tree._tree import TREE_LEAF +from sklearn.tree._tree import TREE_LEAF, TREE_UNDEFINED from sklearn.tree.tree import CRITERIA_CLF from sklearn.tree.tree import CRITERIA_REG from sklearn import datasets @@ -1838,3 +1838,119 @@ def test_decision_tree_memmap(): with TempMemmap((X, y)) as (X_read_only, y_read_only): DecisionTreeClassifier().fit(X_read_only, y_read_only) + + +@pytest.mark.parametrize("criterion", CLF_CRITERIONS) +@pytest.mark.parametrize( + "dataset", sorted(set(DATASETS.keys()) - {"reg_small", "boston"})) +@pytest.mark.parametrize( + "tree_cls", [DecisionTreeClassifier, ExtraTreeClassifier]) +def test_prune_tree_classifier_are_subtrees(criterion, dataset, tree_cls): + dataset = DATASETS[dataset] + X, y = dataset["X"], dataset["y"] + est = tree_cls(max_leaf_nodes=20, random_state=0) + info = est.cost_complexity_pruning_path(X, y) + + pruning_path = info.ccp_alphas + impurities = info.impurities + assert np.all(np.diff(pruning_path) >= 0) + assert np.all(np.diff(impurities) >= 0) + + assert_pruning_creates_subtree(tree_cls, X, y, pruning_path) + + +@pytest.mark.parametrize("criterion", REG_CRITERIONS) +@pytest.mark.parametrize("dataset", DATASETS.keys()) +@pytest.mark.parametrize( + "tree_cls", [DecisionTreeRegressor, ExtraTreeRegressor]) +def test_prune_tree_regression_are_subtrees(criterion, dataset, tree_cls): + dataset = DATASETS[dataset] + X, y = dataset["X"], dataset["y"] + + est = tree_cls(max_leaf_nodes=20, random_state=0) + info = est.cost_complexity_pruning_path(X, y) + + pruning_path = info.ccp_alphas + impurities = info.impurities + assert np.all(np.diff(pruning_path) >= 0) + assert np.all(np.diff(impurities) >= 0) + + assert_pruning_creates_subtree(tree_cls, X, y, pruning_path) + + +def test_prune_single_node_tree(): + # single node tree + clf1 = DecisionTreeClassifier(random_state=0) + clf1.fit([[0], [1]], [0, 0]) + + # pruned single node tree + clf2 = DecisionTreeClassifier(random_state=0, ccp_alpha=10) + clf2.fit([[0], [1]], [0, 0]) + + assert_is_subtree(clf1.tree_, clf2.tree_) + + +def assert_pruning_creates_subtree(estimator_cls, X, y, pruning_path): + # generate trees with increasing alphas + estimators = [] + for ccp_alpha in pruning_path: + est = estimator_cls( + max_leaf_nodes=20, ccp_alpha=ccp_alpha, random_state=0).fit(X, y) + estimators.append(est) + + # A pruned tree must be a subtree of the previous tree (which had a + # smaller ccp_alpha) + for prev_est, next_est in zip(estimators, estimators[1:]): + assert_is_subtree(prev_est.tree_, next_est.tree_) + + +def assert_is_subtree(tree, subtree): + assert tree.node_count >= subtree.node_count + assert tree.max_depth >= subtree.max_depth + + tree_c_left = tree.children_left + tree_c_right = tree.children_right + subtree_c_left = subtree.children_left + subtree_c_right = subtree.children_right + + stack = [(0, 0)] + while stack: + tree_node_idx, subtree_node_idx = stack.pop() + assert_array_almost_equal(tree.value[tree_node_idx], + subtree.value[subtree_node_idx]) + assert_almost_equal(tree.impurity[tree_node_idx], + subtree.impurity[subtree_node_idx]) + assert_almost_equal(tree.n_node_samples[tree_node_idx], + subtree.n_node_samples[subtree_node_idx]) + assert_almost_equal(tree.weighted_n_node_samples[tree_node_idx], + subtree.weighted_n_node_samples[subtree_node_idx]) + + if (subtree_c_left[subtree_node_idx] == + 
subtree_c_right[subtree_node_idx]): + # is a leaf + assert_almost_equal(TREE_UNDEFINED, + subtree.threshold[subtree_node_idx]) + else: + # not a leaf + assert_almost_equal(tree.threshold[tree_node_idx], + subtree.threshold[subtree_node_idx]) + stack.append((tree_c_left[tree_node_idx], + subtree_c_left[subtree_node_idx])) + stack.append((tree_c_right[tree_node_idx], + subtree_c_right[subtree_node_idx])) + + +def test_prune_tree_raises_negative_ccp_alpha(): + clf = DecisionTreeClassifier() + msg = "ccp_alpha must be greater than or equal to 0" + + with pytest.raises(ValueError, match=msg): + clf.set_params(ccp_alpha=-1.0) + clf.fit(X, y) + + clf.set_params(ccp_alpha=0.0) + clf.fit(X, y) + + with pytest.raises(ValueError, match=msg): + clf.set_params(ccp_alpha=-1.0) + clf._prune_tree() diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index 9f6bf979717cf..0c6240ae71d5a 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -25,9 +25,11 @@ from ..base import BaseEstimator from ..base import ClassifierMixin +from ..base import clone from ..base import RegressorMixin from ..base import is_classifier from ..base import MultiOutputMixin +from ..utils import Bunch from ..utils import check_array from ..utils import check_random_state from ..utils import compute_sample_weight @@ -39,6 +41,8 @@ from ._tree import DepthFirstTreeBuilder from ._tree import BestFirstTreeBuilder from ._tree import Tree +from ._tree import _build_pruned_tree_ccp +from ._tree import ccp_pruning_path from . import _tree, _splitter, _criterion __all__ = ["DecisionTreeClassifier", @@ -90,7 +94,8 @@ def __init__(self, min_impurity_decrease, min_impurity_split, class_weight=None, - presort=False): + presort=False, + ccp_alpha=0.0): self.criterion = criterion self.splitter = splitter self.max_depth = max_depth @@ -104,6 +109,7 @@ def __init__(self, self.min_impurity_split = min_impurity_split self.class_weight = class_weight self.presort = presort + self.ccp_alpha = ccp_alpha def get_depth(self): """Returns the depth of the decision tree. @@ -124,6 +130,10 @@ def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): random_state = check_random_state(self.random_state) + + if self.ccp_alpha < 0.0: + raise ValueError("ccp_alpha must be greater than or equal to 0") + if check_input: X = check_array(X, dtype=DTYPE, accept_sparse="csc") y = check_array(y, ensure_2d=False, dtype=None) @@ -381,6 +391,8 @@ def fit(self, X, y, sample_weight=None, check_input=True, self.n_classes_ = self.n_classes_[0] self.classes_ = self.classes_[0] + self._prune_tree() + return self def _validate_X_predict(self, X, check_input): @@ -508,6 +520,62 @@ def decision_path(self, X, check_input=True): X = self._validate_X_predict(X, check_input) return self.tree_.decision_path(X) + def _prune_tree(self): + """Prune tree using Minimal Cost-Complexity Pruning.""" + check_is_fitted(self) + + if self.ccp_alpha < 0.0: + raise ValueError("ccp_alpha must be greater than or equal to 0") + + if self.ccp_alpha == 0.0: + return + + # build pruned tree + n_classes = np.atleast_1d(self.n_classes_) + pruned_tree = Tree(self.n_features_, n_classes, self.n_outputs_) + _build_pruned_tree_ccp(pruned_tree, self.tree_, self.ccp_alpha) + + self.tree_ = pruned_tree + + def cost_complexity_pruning_path(self, X, y, sample_weight=None): + """Compute the pruning path during Minimal Cost-Complexity Pruning. + + See :ref:`minimal_cost_complexity_pruning` for details on the pruning + process.
+ + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The training input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + The target values (class labels) as integers or strings. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. If None, then samples are equally weighted. Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. Splits are also + ignored if they would result in any single class carrying a + negative weight in either child node. + + Returns + ------- + ccp_path : Bunch + Dictionary-like object, with attributes: + + ccp_alphas : ndarray + Effective alphas of subtree during pruning. + + impurities : ndarray + Sum of the impurities of the subtree leaves for the + corresponding alpha value in ``ccp_alphas``. + """ + est = clone(self).set_params(ccp_alpha=0.0) + est.fit(X, y, sample_weight=sample_weight) + return Bunch(**ccp_pruning_path(est.tree_)) + @property def feature_importances_(self): """Return the feature importances. @@ -664,6 +732,14 @@ class DecisionTreeClassifier(BaseDecisionTree, ClassifierMixin): When using either a smaller dataset or a restricted depth, this may speed up the training. + ccp_alpha : non-negative float, optional (default=0.0) + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. versionadded:: 0.22 + Attributes ---------- classes_ : array of shape = [n_classes] or a list of such arrays @@ -755,7 +831,8 @@ def __init__(self, min_impurity_decrease=0., min_impurity_split=None, class_weight=None, - presort=False): + presort=False, + ccp_alpha=0.0): super().__init__( criterion=criterion, splitter=splitter, @@ -769,7 +846,8 @@ def __init__(self, random_state=random_state, min_impurity_decrease=min_impurity_decrease, min_impurity_split=min_impurity_split, - presort=presort) + presort=presort, + ccp_alpha=ccp_alpha) def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): @@ -1016,6 +1094,14 @@ class DecisionTreeRegressor(BaseDecisionTree, RegressorMixin): When using either a smaller dataset or a restricted depth, this may speed up the training. + ccp_alpha : non-negative float, optional (default=0.0) + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. 
versionadded:: 0.22 + Attributes ---------- feature_importances_ : array of shape = [n_features] @@ -1098,7 +1184,8 @@ def __init__(self, max_leaf_nodes=None, min_impurity_decrease=0., min_impurity_split=None, - presort=False): + presort=False, + ccp_alpha=0.0): super().__init__( criterion=criterion, splitter=splitter, @@ -1111,7 +1198,8 @@ def __init__(self, random_state=random_state, min_impurity_decrease=min_impurity_decrease, min_impurity_split=min_impurity_split, - presort=presort) + presort=presort, + ccp_alpha=ccp_alpha) def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): @@ -1293,6 +1381,14 @@ class ExtraTreeClassifier(DecisionTreeClassifier): Note that these weights will be multiplied with sample_weight (passed through the fit method) if sample_weight is specified. + ccp_alpha : non-negative float, optional (default=0.0) + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. versionadded:: 0.22 + Attributes ---------- classes_ : array of shape = [n_classes] or a list of such arrays @@ -1350,7 +1446,8 @@ def __init__(self, max_leaf_nodes=None, min_impurity_decrease=0., min_impurity_split=None, - class_weight=None): + class_weight=None, + ccp_alpha=0.0): super().__init__( criterion=criterion, splitter=splitter, @@ -1363,7 +1460,8 @@ def __init__(self, class_weight=class_weight, min_impurity_decrease=min_impurity_decrease, min_impurity_split=min_impurity_split, - random_state=random_state) + random_state=random_state, + ccp_alpha=ccp_alpha) class ExtraTreeRegressor(DecisionTreeRegressor): @@ -1487,6 +1585,14 @@ class ExtraTreeRegressor(DecisionTreeRegressor): Best nodes are defined as relative reduction in impurity. If None then unlimited number of leaf nodes. + ccp_alpha : non-negative float, optional (default=0.0) + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. versionadded:: 0.22 + Attributes ---------- max_features_ : int, @@ -1534,7 +1640,8 @@ def __init__(self, random_state=None, min_impurity_decrease=0., min_impurity_split=None, - max_leaf_nodes=None): + max_leaf_nodes=None, + ccp_alpha=0.0): super().__init__( criterion=criterion, splitter=splitter, @@ -1546,4 +1653,5 @@ def __init__(self, max_leaf_nodes=max_leaf_nodes, min_impurity_decrease=min_impurity_decrease, min_impurity_split=min_impurity_split, - random_state=random_state) + random_state=random_state, + ccp_alpha=ccp_alpha)
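For reviewers who want to exercise the new API end to end, the sketch below shows how ``ccp_alpha`` and ``cost_complexity_pruning_path`` (both introduced by this patch) fit together. It is illustrative only and not part of the diff: it assumes this branch is installed, and the breast-cancer dataset, the 5-fold cross-validation, and the variable names are choices made for the demonstration.

# Sketch only: pick ccp_alpha by cross-validating over the pruning path
# computed on the training set (assumes this branch is installed).
import numpy as np

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Effective alphas of the fully grown tree, from the new method added above.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(
    X_train, y_train)

# Drop the last alpha: it prunes the tree down to a single root node.
candidate_alphas = path.ccp_alphas[:-1]

# Mean cross-validated accuracy of a pruned tree for each candidate alpha.
cv_scores = [
    cross_val_score(
        DecisionTreeClassifier(random_state=0, ccp_alpha=alpha),
        X_train, y_train, cv=5).mean()
    for alpha in candidate_alphas
]

best_alpha = candidate_alphas[int(np.argmax(cv_scores))]
clf = DecisionTreeClassifier(random_state=0, ccp_alpha=best_alpha)
clf.fit(X_train, y_train)
print("chosen ccp_alpha: {:.5f}".format(best_alpha))
print("test accuracy: {:.3f}".format(clf.score(X_test, y_test)))

This mirrors the train/test comparison in the new example, but selects ``ccp_alpha`` by cross-validation on the training set, which is how the parameter would typically be tuned in practice (for instance via a grid search over ``path.ccp_alphas``).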