NelleV/scikit-learn · Commit dc8e65a

Merge pull request #3 from bdholt1/glouppe-tree-mo

Glouppe tree mo

2 parents: e5a61dc + 94a5f3f
File tree: 2 files changed (+68, -3 lines)


doc/modules/tree.rst (13 additions, 3 deletions)
@@ -191,7 +191,7 @@ A multi-output problem is a supervised learning problem with several outputs
 to predict, that is when Y is a 2d array of size ``[n_samples, n_outputs]``.

 When there is no correlation between the outputs, a very simple way to solve
-this kind of problems is to build n independent models, i.e. one for each
+this kind of problem is to build n independent models, i.e. one for each
 output, and then to use those models to independently predict each one of the n
 outputs. However, because it is likely that the output values related to the
 same input are themselves correlated, an often better way is to build a single
@@ -200,7 +200,7 @@ lower training time since only a single estimator is built. Second, the
 generalization accuracy of the resulting estimator may often be increased.

 With regard to decision trees, this strategy can readily be used to support
-multi-output problems. This indeed amounts to:
+multi-output problems. This requires the following changes:

 - Store n output values in leaves, instead of 1;
 - Use splitting criteria that compute the average reduction across all
@@ -215,7 +215,16 @@ of size ``[n_samples, n_outputs]`` then the resulting estimator will:
 - Output a list of n_output arrays of class probabilities upon
   ``predict_proba``.

-The use of multi-output trees is demonstrated in
+The use of multi-output trees for regression is demonstrated in
+:ref:`example_tree_plot_tree_regression_multioutput.py`. In this example, the input
+X is a single real value and the outputs Y are the sine and cosine of X.
+
+.. figure:: ../auto_examples/tree/images/plot_tree_regression_multioutput_1.png
+   :target: ../auto_examples/tree/plot_tree_regression_multioutput.html
+   :scale: 75
+   :align: center
+
+The use of multi-output trees for classification is demonstrated in
 :ref:`example_ensemble_plot_forest_multioutput.py`. In this example, the inputs
 X are the pixels of the upper half of faces and the outputs Y are the pixels of
 the lower half of those faces.
@@ -227,6 +236,7 @@ the lower half of those faces.

 .. topic:: Examples:

+ * :ref:`example_tree_plot_tree_regression_multioutput.py`
  * :ref:`example_ensemble_plot_forest_multioutput.py`

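The doc change above contrasts building n independent models with fitting a single estimator on a 2d Y. A minimal sketch of both strategies, assuming the present-day scikit-learn API rather than the 2012-era code in this commit; the toy data and the max_depth value are illustrative, not part of the patch:

import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(0)
X = rng.rand(100, 1)                                # one input feature
Y = np.hstack([np.sin(3 * X), np.cos(3 * X)])       # two correlated outputs

# Strategy 1: n independent models, one per output column.
per_output = [DecisionTreeRegressor(max_depth=5).fit(X, Y[:, k])
              for k in range(Y.shape[1])]
pred_indep = np.column_stack([t.predict(X) for t in per_output])

# Strategy 2: a single multi-output tree fit on the full 2d Y,
# as the docs above describe.
joint = DecisionTreeRegressor(max_depth=5).fit(X, Y)
pred_joint = joint.predict(X)                       # shape (n_samples, n_outputs)

print(pred_indep.shape, pred_joint.shape)           # (100, 2) (100, 2)

The single tree shares its split decisions across outputs, which is where the lower training time and possible accuracy gain mentioned in the docs come from.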
examples/tree/plot_tree_regression_multioutput.py (new file: 55 additions, 0 deletions)

@@ -0,0 +1,55 @@
"""
===================================================================
Multi-output Decision Tree Regression
===================================================================

Multi-output regression with :ref:`decision trees <tree>`: the decision tree
is used to predict simultaneously the noisy x and y observations of a circle
given a single underlying feature. As a result, it learns local linear
regressions approximating the circle.

We can see that if the maximum depth of the tree (controlled by the
`max_depth` parameter) is set too high, the decision trees learn too fine
details of the training data and learn from the noise, i.e. they overfit.
"""
print __doc__

import numpy as np

# Create a random dataset
rng = np.random.RandomState(1)
X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T
y[::5, :] += (0.5 - rng.rand(20, 2))

# Fit regression model
from sklearn.tree import DecisionTreeRegressor

clf_1 = DecisionTreeRegressor(max_depth=2)
clf_2 = DecisionTreeRegressor(max_depth=5)
clf_3 = DecisionTreeRegressor(max_depth=8)
clf_1.fit(X, y)
clf_2.fit(X, y)
clf_3.fit(X, y)

# Predict
X_test = np.arange(-100.0, 100.0, 0.01)[:, np.newaxis]
y_1 = clf_1.predict(X_test)
y_2 = clf_2.predict(X_test)
y_3 = clf_3.predict(X_test)

# Plot the results
import pylab as pl

pl.figure()
pl.scatter(y[:, 0], y[:, 1], c="k", label="data")
pl.scatter(y_1[:, 0], y_1[:, 1], c="g", label="max_depth=2")
pl.scatter(y_2[:, 0], y_2[:, 1], c="r", label="max_depth=5")
pl.scatter(y_3[:, 0], y_3[:, 1], c="b", label="max_depth=8")
pl.xlim([-6, 6])
pl.ylim([-6, 6])
pl.xlabel("data")
pl.ylabel("target")
pl.title("Multi-output Decision Tree Regression")
pl.legend()
pl.show()
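
The tree.rst diff above also states that a multi-output classifier will output a list of n_output arrays of class probabilities upon ``predict_proba``. A minimal sketch of that return shape, again assuming the present-day scikit-learn API and made-up toy labels rather than anything in this commit:

import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.RandomState(0)
X = rng.rand(100, 2)
Y = (X > 0.5).astype(int)        # two binary label columns, shape (100, 2)

clf = DecisionTreeClassifier(max_depth=3).fit(X, Y)
proba = clf.predict_proba(X[:5])

# One (n_samples, n_classes) array per output, returned as a list.
print(len(proba), proba[0].shape, proba[1].shape)   # 2 (5, 2) (5, 2)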

0 commit comments