Fix max_depth overshoot in BFS expansion of trees by adrinjalali · Pull Request #12344 · scikit-learn/scikit-learn

Fix max_depth overshoot in BFS expansion of trees #12344

Merged
13 changes: 13 additions & 0 deletions doc/whats_new/v0.21.rst
@@ -18,6 +18,8 @@ occurs due to changes in the modelling logic (bug fixes or enhancements), or in
random sampling procedures.

- please add class and reason here (see version 0.20 what's new)
+- Decision trees and derived ensembles when both `max_depth` and
+  `max_leaf_nodes` are set. (bug fix)

Details are listed in the changelog below.

@@ -116,6 +118,17 @@ Support for Python 3.4 and below has been officially dropped.
and :class:`tree.ExtraTreeRegressor`.
:issue:`12300` by :user:`Adrin Jalali <adrinjalali>`.

+- |Fix| Fixed an issue with :class:`tree.BaseDecisionTree`
+  and consequently all estimators based
+  on it, including :class:`tree.DecisionTreeClassifier`,
+  :class:`tree.DecisionTreeRegressor`, :class:`tree.ExtraTreeClassifier`,
+  and :class:`tree.ExtraTreeRegressor`, where they used to exceed the given
+  ``max_depth`` by 1 while expanding the tree if ``max_leaf_nodes`` and
+  ``max_depth`` were both specified by the user. Please note that this also
+  affects all ensemble methods using decision trees.
+  :pr:`12344` by :user:`Adrin Jalali <adrinjalali>`.


Multiple modules
................
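The behaviour described in the changelog entry above can be checked from user code. The following is a minimal sketch, assuming a scikit-learn release that includes this fix (0.21 or later); it reuses the synthetic data set from the PR's tree tests and compares a tree fitted with both limits against one fitted with `max_depth` alone. The variable names are illustrative only.

```python
from sklearn.datasets import make_hastie_10_2
from sklearn.tree import DecisionTreeClassifier

# Same synthetic data as the tree tests touched by this PR.
X, y = make_hastie_10_2(n_samples=100, random_state=1)

# Passing max_leaf_nodes makes the estimator grow the tree best-first,
# which is the code path this PR changes.
both_limits = DecisionTreeClassifier(
    max_depth=1, max_leaf_nodes=4, random_state=0
).fit(X, y)

# With max_depth alone the tree is grown depth-first.
depth_only = DecisionTreeClassifier(max_depth=1, random_state=0).fit(X, y)

depth_both = both_limits.get_depth()
depth_single = depth_only.get_depth()

# After the fix, neither tree should be deeper than max_depth.
print("max_depth + max_leaf_nodes:", depth_both)
print("max_depth only:", depth_single)
```

On a release that predates the fix, the first tree would be expected to report a depth of 2, which is what the old `assert_greater` checks in the tests encoded.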

4 changes: 2 additions & 2 deletions sklearn/ensemble/tests/test_forest.py
@@ -709,11 +709,11 @@ def check_max_leaf_nodes_max_depth(name):
ForestEstimator = FOREST_ESTIMATORS[name]
est = ForestEstimator(max_depth=1, max_leaf_nodes=4,
n_estimators=1, random_state=0).fit(X, y)
-    assert_greater(est.estimators_[0].tree_.max_depth, 1)
+    assert_equal(est.estimators_[0].get_depth(), 1)

est = ForestEstimator(max_depth=1, n_estimators=1,
random_state=0).fit(X, y)
-    assert_equal(est.estimators_[0].tree_.max_depth, 1)
+    assert_equal(est.estimators_[0].get_depth(), 1)


@pytest.mark.parametrize('name', FOREST_ESTIMATORS)
2 changes: 1 addition & 1 deletion sklearn/ensemble/tests/test_gradient_boosting.py
@@ -1104,7 +1104,7 @@ def test_max_leaf_nodes_max_depth(GBEstimator):

est = GBEstimator(max_depth=1, max_leaf_nodes=k).fit(X, y)
tree = est.estimators_[0, 0].tree_
-    assert_greater(tree.max_depth, 1)
+    assert_equal(tree.max_depth, 1)

est = GBEstimator(max_depth=1).fit(X, y)
tree = est.estimators_[0, 0].tree_
2 changes: 1 addition & 1 deletion sklearn/tree/_tree.pyx
@@ -450,7 +450,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
impurity = splitter.node_impurity()

n_node_samples = end - start
-        is_leaf = (depth > self.max_depth or
+        is_leaf = (depth >= self.max_depth or
n_node_samples < self.min_samples_split or
n_node_samples < 2 * self.min_samples_leaf or
weighted_n_node_samples < 2 * self.min_weight_leaf or
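To illustrate why the one-character change from `>` to `>=` matters, here is a toy pure-Python sketch of a best-first expansion loop. It is not the Cython implementation: the function `grow_best_first`, its `inclusive` flag, and the depth-based priority are hypothetical simplifications, and it assumes every non-leaf node can always be split. It only models the depth check and the split budget derived from `max_leaf_nodes`.

```python
import heapq
import itertools


def grow_best_first(max_depth, max_leaf_nodes, inclusive=True):
    """Return the depth reached by a toy best-first expansion.

    inclusive=True  models the check `depth >= max_depth` (after this PR).
    inclusive=False models the check `depth > max_depth` (before this PR).
    """
    counter = itertools.count()  # unique tie-breaker for heap entries

    def make_node(depth):
        # The leaf decision is taken when the node is created.
        if inclusive:
            is_leaf = depth >= max_depth
        else:
            is_leaf = depth > max_depth
        # Toy priority: shallower nodes first. A real builder would rank
        # frontier nodes by impurity improvement instead.
        return (depth, next(counter), is_leaf)

    frontier = [make_node(0)]
    splits_left = max_leaf_nodes - 1  # k leaves need k - 1 splits
    reached = 0
    while frontier:
        depth, _, is_leaf = heapq.heappop(frontier)
        reached = max(reached, depth)
        if is_leaf or splits_left <= 0:
            continue  # node stays a leaf
        splits_left -= 1
        # Splitting creates two children one level deeper.
        heapq.heappush(frontier, make_node(depth + 1))
        heapq.heappush(frontier, make_node(depth + 1))
    return reached


old_depth = grow_best_first(max_depth=1, max_leaf_nodes=4, inclusive=False)
new_depth = grow_best_first(max_depth=1, max_leaf_nodes=4, inclusive=True)
print(old_depth, new_depth)  # 2 1
```

With the strict comparison, a node created at `depth == max_depth` is not marked as a leaf, so it can still be split and the tree ends up one level too deep. With `>=`, nodes at `max_depth` become leaves immediately.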
3 changes: 1 addition & 2 deletions sklearn/tree/tests/test_tree.py
@@ -1210,7 +1210,6 @@ def test_class_weight_errors(name):

def test_max_leaf_nodes():
# Test greedy trees with max_depth + 1 leafs.
-    from sklearn.tree._tree import TREE_LEAF
X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)
k = 4
for name, TreeEstimator in ALL_TREES.items():
@@ -1232,7 +1231,7 @@ def test_max_leaf_nodes_max_depth():
k = 4
for name, TreeEstimator in ALL_TREES.items():
est = TreeEstimator(max_depth=1, max_leaf_nodes=k).fit(X, y)
-        assert_greater(est.get_depth(), 1)
+        assert_equal(est.get_depth(), 1)


def test_arrays_persist():