[MRG+1] Add an example and a method to analyse the decision tree structure #5487
# First let's retrieve the decision path of each sample. The decision_paths
# method allows to retrieve the node indicator function. A non zero elements at
# position (i, j) indicates that the sample i goes # through the node j.
On Python 2, this yields an error.
➜ scikit-learn git:(9c2af10) python examples/tree/plot_structure.py
File "examples/tree/plot_structure.py", line 66
SyntaxError: Non-ASCII character '\xc2' in file examples/tree/plot_structure.py on line 66, but no encoding declared; see http://python.org/dev/peps/pep-0263/ for details
(Everything works fine under Python 3 though)
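For anyone hitting the same error, here is a minimal sketch for locating the offending bytes. The helper name `find_non_ascii` is hypothetical (it is not part of scikit-learn), and the sample bytes are made up for illustration; the actual character in the example file may differ. The alternative fix described in PEP 263 is to declare an encoding at the top of the file.

```python
def find_non_ascii(source_bytes):
    """Return (line_number, column, byte_value) for every non-ASCII byte.

    Hypothetical helper for illustration only. Line numbers are 1-based,
    columns are 0-based byte offsets within the line.
    """
    hits = []
    for line_no, line in enumerate(source_bytes.split(b"\n"), start=1):
        for col, byte in enumerate(bytearray(line)):
            if byte > 127:
                hits.append((line_no, col, byte))
    return hits


# b"\xc2\xa0" is the UTF-8 encoding of a non-breaking space, one common way
# for a stray '\xc2' byte to end up in a comment (assumed here, not confirmed).
sample = b"# sample i goes\xc2\xa0through node j\nx = 1\n"
hits = find_non_ascii(sample)
print(hits)
```

Replacing the reported bytes with plain ASCII keeps the example importable on Python 2 without an encoding declaration.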
It would be nice to include in the example a visualization of the tree, as done with |
Understanding the decision tree structure
=========================================

The decision tree structure could be analysed to gain further insight on the
could -> can
@glouppe I have taken your comments into account.
ping @jmschrei
ping @jnothman
What about this? I am not sure it is easily feasible, though, since the example gallery expects plots to be generated with matplotlib (as far as I know).
Other than that, the rest of the code looks good to me. Thanks for this!
@@ -867,6 +865,166 @@ cdef class Tree:

return out

cpdef object decision_paths(self, object X):
Might want to mention that it returns a sparse matrix, and the reason (because it has to have as many columns as the maximal path)
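A small sketch of what the return value looks like, assuming the method is exposed on the estimator as `decision_path` (as in the test further down). In the scikit-learn releases I am aware of, the matrix has one column per node of the tree, and each row only stores the nodes on that sample's root-to-leaf path, which is why a sparse format fits. The `max_depth=3` setting is an arbitrary choice for illustration.

```python
import numpy as np
from scipy.sparse import issparse
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
est = DecisionTreeClassifier(random_state=0, max_depth=3)
est.fit(iris.data, iris.target)

# Sparse indicator matrix: rows are samples, columns are node ids.
indicator = est.decision_path(iris.data)

n_samples = iris.data.shape[0]
n_nodes = est.tree_.node_count
print(issparse(indicator), indicator.shape == (n_samples, n_nodes))

# Stored entries per row = number of nodes visited by that sample,
# which is bounded by max_depth + 1.
nnz_per_row = np.diff(indicator.tocsr().indptr)
print(nnz_per_row.min(), nnz_per_row.max())
```

For deep trees most entries of a row are zero, so the dense equivalent would waste memory proportional to the total node count.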
Also, I think it should be decision_path, not decision_paths.
Do you think it might be worthwhile to wrap some of the attributes of _tree in the |
It would be nice to have an option where you can pass in a tree, and a path, and get a picture of the tree with the path highlighted. I understand this may be out of scope for this PR though. |
I don't have a strong opinion. However, I believe it will complicate the overall code by adding many properties. |
This is a good idea; however, I don't see how to do this with tools such as matplotlib.
I forgot this one. I can use export_graphviz, but I fear that I won't get anything better than a list of unreadable strings. |
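To make the concern concrete, here is a rough sketch of what `export_graphviz` hands back. It assumes a scikit-learn version in which `out_file=None` returns the DOT source as a string; the depth and the iris data are arbitrary choices for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
est = DecisionTreeClassifier(random_state=0, max_depth=2)
est.fit(iris.data, iris.target)

# With out_file=None (assumed supported by the installed version) the DOT
# description is returned as text instead of being written to a file.
dot_source = export_graphviz(
    est,
    out_file=None,
    feature_names=iris.feature_names,
)
print(dot_source[:200])
```

The result is plain Graphviz text, so turning it into a figure needs an external renderer; it is not something matplotlib draws directly.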
We could maybe convert it to an image and then "plot" it using |
I am thinking of adding a decision path to the forest. This could be useful to generate new features from the forest.
This should be ready for a new round of review.
Okay, fair enough. I'll see what I can do once this is merged. |
I have taken into account your comment @amueller. |
one of the five ;) |
Now it's good; |
thanks :) I didn't review the |
# First let's retrieve the decision path of each sample. The decision_path
# method allows to retrieve the node indicator function. A non zero elements at
# position (i, j) indicates that the sample i goes sthrough the node j.
elements -> element
sthrough -> through
I find the example very useful!
y = iris.target
est = DecisionTreeClassifier(random_state=0, max_depth=1).fit(X, y)
node_indicator = est.decision_path(X[:2]).toarray()
assert_array_equal(node_indicator, [[1, 1, 0], [1, 0, 1]])
I think I am misunderstanding these paths. How can the first point go to node 1, then node 1 again, then node 0? My understanding was that the decision path was an array of node IDs, ending in a leaf node.
It is the binary encoded version of this. If path[i] == 1, then the sample traverses the node with node_id i.
Okay. So then the later nodes always correspond to later nodes visited, so using it as a mask on the nodes gives you the path. Got it.
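The masking idea can be sketched as follows. The two rows picked here (the first and last sample of the unshuffled iris data) are my own choice for illustration, not necessarily the rows used in the test above.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data, iris.target
est = DecisionTreeClassifier(random_state=0, max_depth=1).fit(X, y)

# One sample from the first class and one from the last class, so that
# the two samples end up in different leaves of the depth-1 tree.
samples = X[[0, 149]]
node_indicator = est.decision_path(samples).toarray()

# Each row is a 0/1 mask over node ids; the non-zero positions, read in
# increasing order, give the visited nodes from the root to the leaf.
paths = [np.flatnonzero(row).tolist() for row in node_indicator]

# The last node id on each path should match the leaf reported by apply().
leaves = est.apply(samples).tolist()
print(node_indicator.tolist())
print(paths, leaves)
```

Reading the mask in increasing node-id order relies on children being created after their parent, so their ids are always larger.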
leave_indicator = [indicator[i, n_nodes_ptr[est_id] + j]
for i, j in enumerate(leaves[:, est_id])]
assert_array_almost_equal(leave_indicator, np.ones(shape=n_samples))
I don't understand the almost_equal
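As a sketch of the same check with exact equality: the snippet below assumes the forest-level `decision_path` returns the pair `(indicator, n_nodes_ptr)`, as it does in the scikit-learn releases I know. Since the indicator holds integer 0/1 flags, an exact comparison seems sufficient. The forest size and the dense conversion are arbitrary choices to keep the example short.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)
forest = RandomForestClassifier(n_estimators=5, random_state=0).fit(X, y)

# indicator stacks the per-tree indicator matrices column-wise;
# n_nodes_ptr[k] is the column offset of tree k inside that matrix.
indicator, n_nodes_ptr = forest.decision_path(X)
dense = indicator.toarray()  # small data, so a dense copy is fine here

leaves = forest.apply(X)  # shape (n_samples, n_estimators)
rows = np.arange(X.shape[0])

leaf_ok = []
for est_id in range(forest.n_estimators):
    # Column of the leaf reached by every sample in tree est_id.
    cols = n_nodes_ptr[est_id] + leaves[:, est_id]
    leaf_ok.append(bool(np.all(dense[rows, cols] == 1)))

print(leaf_ok)
```

Every leaf returned by `apply` should be flagged in the matching block of columns, which is what the test above verifies per estimator.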
From what I understood, it looks good to me, but I'm not fluent enough in Cython to give a +1.
Thanks |
This reverts commit 46ad44a.
example + benchmark explanation make some private functions + fix public API IForest using BaseForest base class for trees debug + plot_iforest classic anomaly detection datasets and benchmark small modif BaseBagging inheritance shuffle dataset before benchmarking BaseBagging inheritance remove class label 4 from shuttle dataset pep8 + rm shuttle.csv bench_IsolationForest.png + doc decision_function add tests remove comments fetching kddcup99 and shuttle datasets fetching kddcup99 and shuttle datasets pep8 fetching kddcup99 and shuttle datasets pep8 new files iforest.py and test_iforest.py sc alternative to pandas (but very slow) in kddcup99.py faster parser sc pep8 + cleanup + simplification example outlier detection clean and correct idem random_state added percent10=True in benchmark mc remove shuttle + minor changes sc undo modif on forest.py and recompile cython on _tree.c fix travis cosmit change bagging to fix travis Revert "change bagging to fix travis" This reverts commit 30ea500. add max_samples_ in BaseBagging.fit to fix travis mc API : don't add fit param but use a private _fit + update tests + examples to avoid warning adapt to the new structure of _tree.pyx cosmit add performance test for iforest add _tree.c _utils.c _criterion.c TST : pass on tests remove test relax roc-auc to fix AppVeyor add test on toy samples Handle depth averaging at python level plot example: rm html add png load_kddcup99 -> fetch_kddcup99 + doc Take into account arjoly comments sh -> shuffle add decision_path code from scikit-learn#5487 to bench Take into account arjoly comments Revert "add decision_path code from scikit-learn#5487 to bench" This reverts commit 46ad44a. fix bug with max_samples != int
Ping @glouppe, @pprett, @ogrisel, @amueller
Suggestions are welcome to improve the example.
It should fix #1105 and #5441.