Adding a pruning method to the tree #941
Conversation
        ----------
        tree : binary tree object
            The binary tree for which to compute the complexity costs.
The object doesn't need to be listed as a parameter here
sorry, I forgot to remove that when I refactored my code.
Nice addition. I haven't had a chance to try it out in detail, but I read over the code and it looks good.
@@ -266,6 +291,105 @@ def _add_leaf(self, parent, is_left_child, value, error, n_samples):

        return node_id

    def _copy(self):
Is there a reason not to use clone here? Not sure if that copies the arrays, though. But I'm sure there is a method that does (for serialization, for example).
clone clones the parameters of the Estimator object. A Tree results from the fit of a DecisionTree -- it is not copied. Also, a DecisionTree has a Tree and is an Estimator, but a Tree is not an estimator (it inherits directly from object).

If someone knows a copying function, I will gladly use it.
copy or deepcopy from the copy module.
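For reference, a minimal sketch of the deepcopy route against today's estimator API (the PR's internal Tree._copy is not reproduced here; this only illustrates the general point that deepcopy duplicates the underlying arrays as well):

    from copy import deepcopy

    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier

    # Fit a tree, then duplicate it with deepcopy: the copy gets its own
    # tree structure and arrays, so it can be modified (e.g. pruned)
    # without touching the original.
    X, y = load_iris(return_X_y=True)
    clf = DecisionTreeClassifier(random_state=0).fit(X, y)

    clf_copy = deepcopy(clf)
    assert clf_copy is not clf
    assert clf_copy.tree_ is not clf.tree_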
Thanks for your contribution. It would be nice to have a mention in the narrative docs and maybe extend an already existing example to show how the pruning works and what effect it has. I'm not sure what would be a good test for the pruning, but maybe you can come up with something, for example testing on some toy data with a known expected outcome. Looking forward to seeing this in action :)
    t_nodes = _get_terminal_nodes(children)
    g_i = tree.init_error[t_nodes] - tree.best_error[t_nodes]

    #alpha_i = np.min(g_i) / len(t_nodes)
If this isn't needed, we should remove it...
I agree narrative docs and an example would be very helpful. You could modify these:

You could add a fit to the plot which uses a pruned tree. Hopefully it would do better than the green line, but not over-fit as much as the red line. Also, a new example with your cv_scores_vs_n_leaves helper would be nice.

Sorry to ask for more work when you've already done so much! Let us know if you need any help. I'm excited to see the results! 😁
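A rough sketch of what such an extended example could look like. It reuses toy data in the spirit of the decision tree regression example and, since the PR's prune(n_leaves) never merged, stands in the pruned fit with the ccp_alpha cost-complexity parameter available in current scikit-learn; all values are just for illustration:

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.tree import DecisionTreeRegressor

    # Noisy sine data, similar to the existing regression example.
    rng = np.random.RandomState(0)
    X = np.sort(5 * rng.rand(80, 1), axis=0)
    y = np.sin(X).ravel()
    y[::5] += 3 * (0.5 - rng.rand(16))
    X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]

    shallow = DecisionTreeRegressor(max_depth=2).fit(X, y)   # coarse fit
    deep = DecisionTreeRegressor(max_depth=5).fit(X, y)      # chases the noise
    # Post-pruned fit: ccp_alpha is the cost-complexity pruning parameter that
    # eventually landed in scikit-learn; the PR's n_leaves-based prune() did not.
    pruned = DecisionTreeRegressor(max_depth=5, ccp_alpha=0.02).fit(X, y)

    plt.scatter(X, y, s=20, c="darkorange", label="data")
    plt.plot(X_test, shallow.predict(X_test), label="max_depth=2")
    plt.plot(X_test, deep.predict(X_test), label="max_depth=5")
    plt.plot(X_test, pruned.predict(X_test), label="max_depth=5, pruned")
    plt.legend()
    plt.show()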
        Parameters
        ----------
        n_leaves : binary tree object
Isn't this an int? What does it mean to be a binary tree object?
Thanks for all your feedback! I will for sure write some docs, but I wanted to see if what I did was worth pulling before doing more work.

About my helper functions, one little detail: I use _get_leaves in the new leaves property that I added to the tree (and use it to compute the number of nodes to prune). I usually like to have small helper functions; it makes things easier for me to read afterwards. Nevertheless, if merging them in seems a better design to most of you, I'll do it.

Also, it would be nice if one of you played a bit with the feature to have an external "test" of the function. I have used it for my own needs (and compared it in one case with an equivalent R function), but external confirmation is usually a good thing.
Correction of small details (documentation, commented code and return value)
    def prune(self, n_leaves):
        """
        Prunes the tree in order to obtain the optimal tree with n_leaves
I would rather say the optimal subtree.
+1
I just had a quick look at the

The synthetic example is pretty cool on the other hand and nicely illustrates the effect of pruning. It leaves me with one question, though: are there cases where pruning is better than regularizing via "min_leaf" or "max_depth"?
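One way to probe that question with current scikit-learn (dataset and parameter values are chosen only for illustration; ccp_alpha is the post-pruning parameter that eventually landed, not this PR's prune):

    from sklearn.datasets import load_diabetes
    from sklearn.model_selection import cross_val_score
    from sklearn.tree import DecisionTreeRegressor

    X, y = load_diabetes(return_X_y=True)

    candidates = {
        "max_depth=3": DecisionTreeRegressor(max_depth=3, random_state=0),
        "min_samples_leaf=10": DecisionTreeRegressor(min_samples_leaf=10, random_state=0),
        # Post-pruning as it eventually landed in scikit-learn (cost-complexity
        # pruning via ccp_alpha); this PR's n_leaves-based prune() never merged.
        "ccp_alpha=10": DecisionTreeRegressor(ccp_alpha=10.0, random_state=0),
    }

    for name, tree in candidates.items():
        print("%-20s mean CV R^2: %.3f" % (name, cross_val_score(tree, X, y, cv=5).mean()))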
I don't know what the other tree-growers think, but I feel having a separate "prune" method breaks the API pretty badly. I think
I just realized that having a separate
I think the heuristic in this case is to go for the smallest tree at the value of the plateau. Actually, R automatically uses such a heuristic. As you said, I should explain better how to read the graphs.
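For the record, the plateau heuristic alluded to here is usually implemented as the 1-SE rule from the CART literature; a tiny sketch with made-up scores:

    import numpy as np

    # Made-up cross-validation results: mean score and standard error per tree size.
    sizes = np.array([2, 3, 4, 5, 6, 7, 8, 9, 10])
    means = np.array([0.55, 0.62, 0.70, 0.74, 0.75, 0.75, 0.75, 0.74, 0.74])
    stderr = np.full_like(means, 0.02)

    # 1-SE rule: take the smallest tree whose mean score is within one
    # standard error of the best mean score.
    best = np.argmax(means)
    chosen = sizes[np.argmax(means >= means[best] - stderr[best])]
    print(chosen)  # 5 with these made-up numbers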
I should probably show how the trees can differ depending on whether we prune or not. Via pruning, the resulting tree is "more optimal". It is possible that, while growing, we do not go into a branch because the next step does not improve the score much, even though the step after that would improve it greatly -- but it never will, because that path is not chosen.
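A small illustration of that failure mode with plain scikit-learn (threshold chosen only for the sketch): on XOR-like data the immediate gain of any first split looks negligible, so a criterion that stops on immediate gain never splits at all, while a fully grown tree, the starting point for post-pruning, captures the interaction.

    import numpy as np
    from sklearn.tree import DecisionTreeClassifier

    rng = np.random.RandomState(0)
    X = rng.rand(400, 2)
    # XOR-style labels: no single axis-aligned split reduces impurity much,
    # but the depth-2 splits separate the classes perfectly.
    y = (X[:, 0] > 0.5) ^ (X[:, 1] > 0.5)

    # Stopping on immediate gain: the first split never looks worthwhile,
    # so the tree should stay a single leaf (~0.5 accuracy).
    early = DecisionTreeClassifier(min_impurity_decrease=0.1, random_state=0).fit(X, y)
    # A fully grown tree finds the structure (and memorizes the training data).
    full = DecisionTreeClassifier(random_state=0).fit(X, y)

    print(early.get_n_leaves(), early.score(X, y))
    print(full.get_n_leaves(), full.score(X, y))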
What do you think of a mixed approach? We could add an

Note that this would change the interface of this particular method (by that I mean that the default behaviour would become linked with pruning, while it was not so far). We would need to discuss a bit further what the default behaviour is (do we ask only for

Ok, if I sum it up, to better follow the API we need to change
@sgenoud Thanks for your quick feedback. Before you get to work, I would prefer that someone else also voice their opinion, so that you don't do any unnecessary work regarding refactoring.

If you could find an example that illustrates the benefits of pruning compared to other regularization methods, that would be great. For the synthetic example that you made, I fear that "more optimal" means "more overfitted" and that, for example, "min_samples_leaf" does better. Maybe you should also mention somewhere that pruning fits more strongly to the data.

As a default behavior, I would have imagined no pruning, to be backward compatible and also because the ensemble module doesn't need pruning.
@amueller Not that much, basically instead of stopping when the tree has exactly
@sgenoud Yes, I fully agree with you. However, don't you think it would be better to implement the method that textbooks describe? If only not to confuse those who may have some background knowledge.
@glouppe It is probably a question of taste in the end. Is there a scikit-learn policy for this kind of choice? I would propose to make the point of

Finally, we could also argue that it would make pre- and post-pruning more similar.
I would favor
Okay then, I am fine with that.
Also renamed the helper function that plots the result of the function
I have tried to integrate your feedback; there are still two things to do:
@glouppe would you mind merging my modifications of the Tree object with your refactoring? I am no Cython expert and you are the one who refactored the object (and therefore know it well).
@sgenoud I can do it, but it'll unfortunately have to wait until mid-August. I am soon leaving for holidays and have a few remaining things to handle before then.
@sgenoud maybe you can work on the doc in the meantime?
Anything new?
Need help with this? @sgenoud, what is the status of the patch?
Hi guys, sorry, I have started a new job that takes up a lot of my time. I should normally have more time for this in December. I'll keep you posted.
This patch is basically completely busted now because it works on the old tree, before things were ported to Cython trees.
I would like to do some pruning, so I've made an attempt to update the code in this pull request. I'm not sure if it is the best way to do things now, but at least it brings us close to where we were before. The updated code is here: https://github.com/aflaxman/scikit-learn/tree/tree_pruning

Is there still interest in this?
Yes, I think there is interest as long as it doesn't add overhead that can't be avoided for the forests.
This seems to have gone stale. As a user of sklearn I'd love to see pruning for decision trees. Are there any updates?
Echoing @jessrosenfield on this (I'm guessing she's working on the same machine learning project I am...) - any updates?
The tree and forest code has been changed considerably since this work was started; any new effort toward pruning would essentially have to start from scratch. Perhaps we should close this PR?
Closing in favor of #6557
I have added a pruning method to the decision trees. The idea with decision trees is to build a huge one, then prune it via a weakest-link algorithm until it reaches a reasonable size (neither overfitting nor underfitting).
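To make the idea concrete, here is a minimal, self-contained sketch of weakest-link (cost-complexity) pruning on a toy tree structure. It illustrates the general algorithm the description refers to, not the PR's actual implementation; the node layout and error values are invented for the example.

    # Each node is a dict: "error" is R(t), the error the node would have as a
    # leaf; leaves have left = right = None.

    def leaves(node):
        if node["left"] is None:
            return [node]
        return leaves(node["left"]) + leaves(node["right"])

    def weakest_link(node):
        """Return (g, node) for the internal node with the smallest
        g(t) = (R(t) - sum of leaf errors under t) / (number of leaves under t - 1)."""
        if node["left"] is None:
            return (float("inf"), None)
        lv = leaves(node)
        g = (node["error"] - sum(l["error"] for l in lv)) / (len(lv) - 1)
        best = (g, node)
        for child in (node["left"], node["right"]):
            best = min(best, weakest_link(child), key=lambda pair: pair[0])
        return best

    def prune(root, n_leaves):
        """Collapse the weakest link repeatedly until at most n_leaves leaves remain."""
        while len(leaves(root)) > n_leaves:
            _, node = weakest_link(root)
            node["left"] = node["right"] = None  # the subtree becomes a leaf
        return root

    # Invented numbers: the right subtree barely improves on its parent,
    # so it is the weakest link and gets collapsed first.
    tree = {"error": 10.0,
            "left": {"error": 2.0, "left": None, "right": None},
            "right": {"error": 7.85,
                      "left": {"error": 4.0, "left": None, "right": None},
                      "right": {"error": 3.8, "left": None, "right": None}}}
    prune(tree, 2)
    print(len(leaves(tree)))  # 2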
I have also built a helper function, cv_scores_vs_n_leaves, that computes the cross-validated scores for different sizes of the tree. This can be plotted with a function such as
Then we choose which size is the best for the data.
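The PR's actual cv_scores_vs_n_leaves is not reproduced here, but a rough, runnable approximation with today's API might look like this (it grows each size directly with max_leaf_nodes instead of post-pruning one large tree, and the dataset is only for illustration):

    import numpy as np
    from sklearn.datasets import load_diabetes
    from sklearn.model_selection import cross_val_score
    from sklearn.tree import DecisionTreeRegressor

    def cv_scores_vs_n_leaves(X, y, max_leaves=30, cv=5):
        """Mean cross-validated score for trees limited to 2..max_leaves leaves."""
        sizes = np.arange(2, max_leaves + 1)
        scores = [cross_val_score(DecisionTreeRegressor(max_leaf_nodes=n, random_state=0),
                                  X, y, cv=cv).mean()
                  for n in sizes]
        return sizes, np.asarray(scores)

    X, y = load_diabetes(return_X_y=True)
    sizes, scores = cv_scores_vs_n_leaves(X, y)
    print("best size:", sizes[np.argmax(scores)])

One would then plot the scores against the sizes and pick the smallest size near the plateau, as discussed above.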
Just a couple of notes:
Before doing that work, I would gladly have some feedback on these modifications.