8000 DOC add an example highlighting the tree structure · scikit-learn/scikit-learn@7ca3b2d · GitHub
[go: up one dir, main page]

Skip to content

Commit 7ca3b2d

committed
DOC add an example highlighting the tree structure
1 parent c6eac2b commit 7ca3b2d

File tree

1 file changed

+110
-0
lines changed

1 file changed

+110
-0
lines changed

examples/tree/plot_structure.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
"""
=========================================
Understanding the decision tree structure
=========================================

The decision tree structure can be analysed to gain further insight on the
relation between the features and the target to predict. In this example, we
show how to retrieve:

- the binary tree structure;
- the nodes that were reached by a sample using the ``decision_path`` method;
- the leaf that was reached by a sample using the ``apply`` method;
- the rules that were used to predict a sample;
- the decision path shared by a group of samples.

"""
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeRegressor


iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

estimator = DecisionTreeRegressor(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)

# The decision estimator has an attribute called tree_ which stores the entire
# tree structure and allows access to low-level attributes. The binary tree
# tree_ is represented as a number of parallel arrays. The i-th element of
# each array holds information about the node ``i``. Node 0 is the tree's
# root. NOTE: Some of the arrays only apply to either leaves or split nodes.
# In this case the values of the nodes of the other type are arbitrary!
#
# Among these arrays, we have:
#   - children_left, id of the left child of the node
#   - children_right, id of the right child of the node
#   - feature, index of the feature used for splitting the node
#   - threshold, threshold value at the node

# Hoist the tree_ arrays into locals for readability.
n_nodes = estimator.tree_.node_count
children_left = estimator.tree_.children_left
children_right = estimator.tree_.children_right
feature = estimator.tree_.feature
threshold = estimator.tree_.threshold

# Using those arrays, we can traverse the tree structure:

print("The binary tree structure has %s nodes and has "
      "the following tree structure:"
      % n_nodes)

for i in range(n_nodes):
    # In a scikit-learn tree a node is a leaf iff both child ids coincide
    # (internally both are set to TREE_LEAF).
    if children_left[i] == children_right[i]:
        print("node=%s leaf node." % i)
    else:
        print("node=%s test node: go to node %s if X[:, %s] <= %s else %s."
              % (i,
                 children_left[i],
                 feature[i],
                 threshold[i],
                 children_right[i],
                 ))
print()

# First let's retrieve the decision path of each sample. The decision_path
# method returns a node-indicator CSR matrix: a non-zero element at
# position (i, j) indicates that sample i goes through node j.

node_indicator = estimator.decision_path(X_test)

# Similarly, we can also have the leaf id reached by each sample.

leave_id = estimator.apply(X_test)

# Now, it's possible to get the tests that were used to predict a sample or
# a group of samples. First, let's do it for a single sample.

sample_id = 0
# The node ids on sample_id's path are stored in the CSR indices between
# indptr[sample_id] and indptr[sample_id + 1].
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for i, node_id in enumerate(node_index):
    if leave_id[sample_id] == node_id:
        # The path ends at the sample's leaf; no test is applied there.
        continue

    # Index the sample's own row (sample_id), not the loop counter.
    if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
        threshold_sign = "<="
    else:
        threshold_sign = ">"

    print("rule %s : (X[%s, %s] (= %s) %s %s)"
          % (i,
             sample_id,
             feature[node_id],
             X_test[sample_id, feature[node_id]],
             threshold_sign,
             threshold[node_id]))

# For a group of samples, we have the following common nodes.
sample_ids = [0, 1]
# A node is common iff every sample of the group goes through it, i.e. its
# indicator column sums to the group size.
common_nodes = (node_indicator.toarray()[sample_ids].sum(axis=0) ==
                len(sample_ids))

common_node_id = np.arange(n_nodes)[common_nodes]

print("\nThe following samples %s share the following path %s in the tree"
      % (sample_ids, common_node_id))
print("It is %s %% of all nodes."
      % (100 * len(common_node_id) / n_nodes,))

0 commit comments

Comments (0)