[MRG+2] ENH/FIX Introduce min_impurity_decrease param for early stopping based on impurity; Deprecate min_impurity_split by raghavrv · Pull Request #8449 · scikit-learn/scikit-learn · GitHub

[MRG+2] ENH/FIX Introduce min_impurity_decrease param for early stopping based on impurity; Deprecate min_impurity_split #8449


Merged
merged 18 commits into scikit-learn:master on Apr 3, 2017

Conversation

raghavrv
Member
@raghavrv raghavrv commented Feb 24, 2017

Fixes #8400

Also ref Gilles' comment

This PR tries to stop splitting if the weighted impurity gain after a potential split is not above a user-given threshold...

@amueller Can you try this on your use cases and see if it gives better control than min_impurity_split?

@jnothman @glouppe @nelson-liu @glemaitre @jmschrei
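A minimal usage sketch of the new parameter (the threshold value here is illustrative): with min_impurity_decrease, a node is split only if the best split's weighted impurity decrease is at least the given value.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, random_state=0)

# Default (0.) grows the tree fully, as on master.
full = DecisionTreeClassifier(random_state=0).fit(X, y)

# A positive threshold suppresses low-gain splits near the leaves.
pruned = DecisionTreeClassifier(min_impurity_decrease=0.01,
                                random_state=0).fit(X, y)

print(full.tree_.node_count, pruned.tree_.node_count)
```

The deprecated min_impurity_split instead compared a node's own impurity against the threshold, which is harder to reason about as the tree gets deeper.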

imp_right = est.tree_.impurity[right]
weighted_n_right = est.tree_.weighted_n_node_samples[right]

actual_decrease = (est.tree_.impurity[node] -
                   (weighted_n_left * imp_left +
                    weighted_n_right * imp_right) /
                   (weighted_n_left + weighted_n_right))
Member Author

#TODO this is an incorrect comparison. The actual decrease should again be multiplied by the fractional weight of the parent node...

@@ -446,7 +454,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder):

if not is_leaf:
splitter.node_split(impurity, &split, &n_constant_features)
is_leaf = is_leaf or (split.pos >= end)
is_leaf = (is_leaf or split.pos >= end or
split.improvement + EPSILON < min_impurity_decrease)
Member

What's the need for epsilon here?

Member Author

I did this to avoid floating precision inconsistencies affecting the split... I'll explain clearly in a subsequent comment...

Member Author
@raghavrv raghavrv Mar 7, 2017

So I did this to avoid not splitting when split.improvement is almost equal to min_impurity_decrease within machine precision. For instance, if you set min_impurity_decrease to 1e-7, the tree is not built completely, because the improvement is sometimes almost equal to 1e-7...

And I added it to the left-hand side and not the right, as that gives splitting the benefit of the doubt (as opposed to not splitting)...

Member Author

To clarify further: setting it to 1e-7, as done for other stopping params to denote eps, will not let the tree grow fully and will produce trees dissimilar to master...
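The tolerance logic can be sketched in pure Python. Note that EPSILON here is an assumed stand-in for the small constant defined in the Cython tree builder, not its actual value:

```python
# Illustrative sketch of the tolerance in the is_leaf condition.
EPSILON = 1e-10  # assumed stand-in for the builder's small constant

def is_leaf(improvement, min_impurity_decrease):
    # A strict `improvement < threshold` would mark the node a leaf when
    # the improvement equals the threshold only up to rounding error.
    return improvement + EPSILON < min_impurity_decrease

threshold = 1e-7
# An improvement that is "equal" to the threshold after floating-point
# round-off should still allow the split:
almost_equal = 1e-7 - 1e-12
print(is_leaf(almost_equal, threshold))       # split is allowed
print(is_leaf(5e-8, threshold))               # clearly below: leaf
```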

Member

Add this as an inline comment, then.

@@ -272,10 +275,23 @@ def fit(self, X, y, sample_weight=None, check_input=True,
min_weight_leaf = (self.min_weight_fraction_leaf *
np.sum(sample_weight))

if self.min_impurity_split < 0.:
if self.min_impurity_split is not None:
Member

Is there a deprecation decorator which can be used? I know there is one for deprecated functions, but I'm not sure about parameters.

Member Author

I think we typically use our deprecated decorator for attributes not parameters... But I'm unsure... @amueller thoughts?
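For reference, the pattern the diff above uses needs no decorator: the old parameter keeps a None default, and fit warns when the user set it explicitly. A hedged sketch with a made-up TreeLike class (not the actual estimator):

```python
import warnings

class TreeLike:
    """Made-up stand-in illustrating the parameter-deprecation pattern."""
    def __init__(self, min_impurity_decrease=0., min_impurity_split=None):
        self.min_impurity_decrease = min_impurity_decrease
        self.min_impurity_split = min_impurity_split

    def fit(self):
        # Warn only if the deprecated parameter was explicitly set.
        if self.min_impurity_split is not None:
            warnings.warn("min_impurity_split has been deprecated in favor "
                          "of min_impurity_decrease.", DeprecationWarning)
        return self

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    TreeLike(min_impurity_split=1e-7).fit()
print(len(caught))  # one deprecation warning recorded
```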

@jmschrei
Member

In general this looks good. I didn't check your test though to make sure it was correct.

@raghavrv
Member Author

Thanks a lot @jmschrei for the review!

@raghavrv
Member Author
raghavrv commented Feb 27, 2017

Others @glouppe @amueller Reviews please :)

@nelson-liu
Contributor

Functionality wise this looks good to me, pending that comment about the deprecation decorator. Good work @raghavrv

@raghavrv
Member Author
raghavrv commented Mar 7, 2017

Thanks @nelson-liu and @jmschrei. Andy or Gilles??

@raghavrv
Member Author

Or maybe @glemaitre / @ogrisel have some time for reviews?

@glemaitre
Member

Should you mention in the docstring that min_impurity_split will be deprecated?

Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=0.)
Threshold for early stopping in tree growth. A node will be split
Member

I would change with:

A node will be split if this split induces a decrease of the impurity
greater than or equal to this value.


.. versionadded:: 0.18
The impurity decrease due to a potential split is the difference in the
Member

I would remove "due to a potential split"

Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=0.)
Threshold for early stopping in tree growth. A node will be split
Member

Same changes as in RandomForestClassifier

Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=0.)
Threshold for early stopping in tree growth. A node will be split
Member

Same changes as in RandomForestClassifier

Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=0.)
Threshold for early stopping in tree growth. A node will be split
Member

Same changes as in RandomForestClassifier

@@ -1406,7 +1417,8 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
subsample=1.0, criterion='friedman_mse', min_samples_split=2,
min_samples_leaf=1, min_weight_fraction_leaf=0.,
max_depth=3, min_impurity_split=1e-7, init=None,
max_depth=3, min_impurity_decrease=0.,
Member

min_impurity_decrease is defined as 1e-7 in the above docstring.

Member Author

Thanks for the catch. I changed the doc to 0... I'm using 0 because of the EPSILON added as described here...

min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=1e-7)
Member

Check the default value

@@ -1790,7 +1811,8 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
subsample=1.0, criterion='friedman_mse', min_samples_split=2,
min_samples_leaf=1, min_weight_fraction_leaf=0.,
max_depth=3, min_impurity_split=1e-7, init=None, random_state=None,
max_depth=3, min_impurity_decrease=0.,
Member

check the default value

Member Author

(Same as above)

Threshold for early stopping in tree growth. If the impurity
of a node is below the threshold, the node is a leaf.
min_impurity_decrease : float, optional (default=0.)
Threshold for early stopping in tree growth. A node will be split
Member

Same changes as in RandomForestClassifier

Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.
min_impurity_decrease : float, optional (default=0.)
Threshold for early stopping in tree growth. A node will be split
Member

Same changes as in RandomForestClassifier

@raghavrv
Member Author

Should you mention in the docstring that min_impurity_split will be deprecated?

Generally we don't mention that in docstring. We deprecate it and remove the doc for that param...

Thanks for the review. Have addressed it :) Another round?

@jnothman Could you take a look this too?

@raghavrv raghavrv force-pushed the min_impurity_decrease branch from 4775b93 to 0ca3a4e on March 24, 2017 20:44
Member
@MechCoder MechCoder left a comment

Some minor comments, looks fine otherwise.


.. versionadded:: 0.18
The impurity decrease is the difference in the parent node's impurity
Member

I would prefer the easier-to-follow definition over here (https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_criterion.pyx#L177).

Also, there seems to be an extra term outside the bracket (N_parent / N_total) from your tests here. (https://github.com/scikit-learn/scikit-learn/pull/8449/files#diff-c3874016cfa1f9bc378d573240ff0502R890)


fractional_node_weight = (
est.tree_.weighted_n_node_samples[node] /
est.tree_.weighted_n_node_samples[0])
Member

nitpick: Can you replace the denominator by just X.shape[0]?

est.tree_.impurity[node] -
(weighted_n_left * imp_left +
weighted_n_right * imp_right) /
(weighted_n_left + weighted_n_right)))
Member

It might be simpler to write (N_parent * Imp_parent - N_left * imp_left - N_right * imp_right) / N
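The two forms are algebraically identical; a quick numeric check with made-up node statistics:

```python
# Numeric sanity check (made-up node statistics) that the two ways of
# writing the weighted impurity decrease agree:
#   N_t / N * (imp_parent - (N_t_L * imp_left + N_t_R * imp_right) / N_t)
# equals
#   (N_t * imp_parent - N_t_L * imp_left - N_t_R * imp_right) / N
N = 100.0                        # total weighted samples in the tree
N_t = 40.0                       # weighted samples at the parent node
N_t_L, N_t_R = 25.0, 15.0        # weighted samples in the children
imp_parent, imp_left, imp_right = 0.48, 0.30, 0.20

form_a = N_t / N * (imp_parent - (N_t_L * imp_left + N_t_R * imp_right) / N_t)
form_b = (N_t * imp_parent - N_t_L * imp_left - N_t_R * imp_right) / N

print(form_a, form_b)
```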

def test_min_impurity_decrease():
# test that min_impurity_decrease ensures a split is made only
# if the impurity decrease is at least that value
X, y = datasets.make_classification(n_samples=10000, random_state=42)
Member

You should test regressors also no?

Member Author

Yes! The ALL_TREES[...] contains regressors too... Just that I use the same classification data to test the regressors too...
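A sketch of that test idea using the public classes directly (ALL_TREES is the test-suite mapping; the tolerance and class list here are assumptions): every actual split's weighted impurity decrease should be at least the configured threshold.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

X, y = make_classification(n_samples=500, random_state=42)
threshold = 0.01
violations = 0

# Cover both classifiers and regressors, as ALL_TREES does.
for Tree in (DecisionTreeClassifier, DecisionTreeRegressor):
    est = Tree(min_impurity_decrease=threshold, random_state=0).fit(X, y)
    t = est.tree_
    N = t.weighted_n_node_samples[0]
    for node in range(t.node_count):
        left, right = t.children_left[node], t.children_right[node]
        if left == -1:  # leaf node
            continue
        # weighted decrease: (N_t*imp - N_tL*imp_L - N_tR*imp_R) / N
        decrease = (t.weighted_n_node_samples[node] * t.impurity[node]
                    - t.weighted_n_node_samples[left] * t.impurity[left]
                    - t.weighted_n_node_samples[right] * t.impurity[right]) / N
        if decrease < threshold - 1e-7:  # small tolerance for round-off
            violations += 1

print(violations)
```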


# Test if min_impurity_split of base estimators is set
# Regression test for #8006
X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)
all_estimators = [GradientBoostingRegressor,
Member

You need to test for random forests also?

Member Author

Thanks! done in the latest commit..

@MechCoder
Member

I agree that the behaviour of min_impurity_decrease is much more intuitive than min_impurity_split.

@MechCoder
Member
MechCoder commented Mar 31, 2017

It's the same expression, both your version with the "fractional_weight" and the one documented in the criterion file. It is just that I find the latter easier to read, but it's fine. (I meant that having the extra term is right and it wasn't reflected in the documentation.)

@MechCoder
Member

LGTM!

@MechCoder MechCoder changed the title [MRG] ENH/FIX Introduce min_impurity_decrease param for early stopping based on impurity; Deprecate min_impurity_split [MRG+1] ENH/FIX Introduce min_impurity_decrease param for early stopping based on impurity; Deprecate min_impurity_split Mar 31, 2017

.. versionadded:: 0.18
The weighted impurity decrease equation is the following:
Member

Are we using the ::math environment in the docstring?

Member

@raghavrv will the math display correctly from lines 815-816? The `` tag will work properly, but does indenting alone work as intended?

Member
@jmschrei jmschrei left a comment

LGTM. If you can address the one typesetting comment I'll go ahead and merge it.



@raghavrv
Member Author
raghavrv commented Apr 3, 2017

@jmschrei @glemaitre Thanks for pointing that out! It was not displaying correctly before but after the latest commit it should look like this

[screenshot of the rendered impurity-decrease equation in the docs]

@jmschrei jmschrei changed the title [MRG+1] ENH/FIX Introduce min_impurity_decrease param for early stopping based on impurity; Deprecate min_impurity_split [MRG+2] ENH/FIX Introduce min_impurity_decrease param for early stopping based on impurity; Deprecate min_impurity_split Apr 3, 2017
@jmschrei jmschrei merged commit fc2f249 into scikit-learn:master Apr 3, 2017
@raghavrv
Member Author
raghavrv commented Apr 3, 2017

Yohoo!! Thanks for the reviews and merge @jmschrei @MechCoder and @glemaitre :)

@raghavrv raghavrv deleted the min_impurity_decrease branch April 3, 2017 16:44
@glouppe
Contributor
glouppe commented Apr 4, 2017

Nice :)

@amueller
Member
amueller commented Apr 5, 2017

Sweet, thanks!
Can I haz example?

massich pushed a commit to massich/scikit-learn that referenced this pull request Apr 26, 2017
Sundrique pushed a commit to Sundrique/scikit-learn that referenced this pull request Jun 14, 2017
NelleV pushed a commit to NelleV/scikit-learn that referenced this pull request Aug 11, 2017
paulha pushed a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
sebp added commits to sebp/scikit-survival that referenced this pull request Oct 16, 2017 (twice) and Oct 30, 2017
maskani-moh pushed a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017
sebp added a commit to sebp/scikit-survival that referenced this pull request Nov 18, 2017
jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017
7 participants