@@ -2,16 +2,18 @@
========================
Plotting Learning Curves
========================
-
- On the left side the learning curve of a naive Bayes classifier is shown for
- the digits dataset. Note that the training score and the cross-validation score
- are both not very good at the end. However, the shape of the curve can be found
- in more complex datasets very often: the training score is very high at the
- beginning and decreases and the cross-validation score is very low at the
- beginning and increases. On the right side we see the learning curve of an SVM
- with RBF kernel. We can see clearly that the training score is still around
- the maximum and the validation score could be increased with more training
- samples.
+ In the first column, first row, the learning curve of a naive Bayes classifier
+ is shown for the digits dataset. Note that the training score and the
+ cross-validation score are both not very good at the end. However, the shape
+ of the curve is found very often in more complex datasets: the training
+ score is very high at the beginning and decreases, while the cross-validation
+ score is very low at the beginning and increases. In the second column, first
+ row, we see the learning curve of an SVM with RBF kernel. We can see clearly
+ that the training score is still around the maximum and the validation score
+ could be increased with more training samples. The plots in the second row
+ show the times required by the models to train with various sizes of the
+ training dataset. The plots in the third row show how the test score varies
+ with the time required to train the model.
"""
print(__doc__)

@@ -24,10 +26,11 @@
from sklearn.model_selection import ShuffleSplit


- def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,
+ def plot_learning_curve(estimator, title, X, y, axes=None, ylim=None, cv=None,
                          n_jobs=None, train_sizes=np.linspace(.1, 1.0, 5)):
    """
-     Generate a simple plot of the test and training learning curve.
+     Generate 3 plots: the test and training learning curve, the training
+     samples vs fit times curve, the fit times vs score curve.

    Parameters
    ----------
@@ -45,6 +48,9 @@ def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,
        Target relative to X for classification or regression;
        None for unsupervised learning.

+     axes : array of 3 axes, optional (default=None)
+         Axes to use for plotting the curves.
+
    ylim : tuple, shape (ymin, ymax), optional
        Defines minimum and maximum y-values plotted.

@@ -79,34 +85,63 @@ def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,
        be big enough to contain at least one sample from each class.
        (default: np.linspace(0.1, 1.0, 5))
    """
-     plt.figure()
-     plt.title(title)
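+     # Create three side-by-side axes when the caller does not provide any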
+     if axes is None:
+         _, axes = plt.subplots(1, 3, figsize=(20, 5))
+
+     axes[0].set_title(title)
    if ylim is not None:
-         plt.ylim(*ylim)
-     plt.xlabel("Training examples")
-     plt.ylabel("Score")
-     train_sizes, train_scores, test_scores = learning_curve(
-         estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)
+         axes[0].set_ylim(*ylim)
+     axes[0].set_xlabel("Training examples")
+     axes[0].set_ylabel("Score")
+
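+     # return_times=True also returns the fit and score times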
+     train_sizes, train_scores, test_scores, fit_times, _ = \
+         learning_curve(estimator, X, y, cv=cv, n_jobs=n_jobs,
+                        train_sizes=train_sizes,
+                        return_times=True)
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)
-     plt.grid()
-
-     plt.fill_between(train_sizes, train_scores_mean - train_scores_std,
-                      train_scores_mean + train_scores_std, alpha=0.1,
-                      color="r")
-     plt.fill_between(train_sizes, test_scores_mean - test_scores_std,
-                      test_scores_mean + test_scores_std, alpha=0.1, color="g")
-     plt.plot(train_sizes, train_scores_mean, 'o-', color="r",
-              label="Training score")
-     plt.plot(train_sizes, test_scores_mean, 'o-', color="g",
-              label="Cross-validation score")
-
-     plt.legend(loc="best")
+     fit_times_mean = np.mean(fit_times, axis=1)
+     fit_times_std = np.std(fit_times, axis=1)
+
+     # Plot learning curve
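+     # (shaded bands show the mean score +/- one standard deviation
+     #  across the CV folds)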
+     axes[0].grid()
+     axes[0].fill_between(train_sizes, train_scores_mean - train_scores_std,
+                          train_scores_mean + train_scores_std, alpha=0.1,
+                          color="r")
+     axes[0].fill_between(train_sizes, test_scores_mean - test_scores_std,
+                          test_scores_mean + test_scores_std, alpha=0.1,
+                          color="g")
+     axes[0].plot(train_sizes, train_scores_mean, 'o-', color="r",
+                  label="Training score")
+     axes[0].plot(train_sizes, test_scores_mean, 'o-', color="g",
+                  label="Cross-validation score")
+     axes[0].legend(loc="best")
+
+     # Plot n_samples vs fit_times
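+     # (how long fitting takes as the training set grows)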
+     axes[1].grid()
+     axes[1].plot(train_sizes, fit_times_mean, 'o-')
+     axes[1].fill_between(train_sizes, fit_times_mean - fit_times_std,
+                          fit_times_mean + fit_times_std, alpha=0.1)
+     axes[1].set_xlabel("Training examples")
+     axes[1].set_ylabel("fit_times")
+     axes[1].set_title("Scalability of the model")
+
+     # Plot fit_time vs score
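+     # (points further right took longer to fit; higher points score better)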
+     axes[2].grid()
+     axes[2].plot(fit_times_mean, test_scores_mean, 'o-')
+     axes[2].fill_between(fit_times_mean, test_scores_mean - test_scores_std,
+                          test_scores_mean + test_scores_std, alpha=0.1)
+     axes[2].set_xlabel("fit_times")
+     axes[2].set_ylabel("Score")
+     axes[2].set_title("Performance of the model")
+
    return plt


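+ # One column of plots per estimator: rows show the learning curve,
+ # scalability, and performance (see plot_learning_curve above)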
+ fig, axes = plt.subplots(3, 2, figsize=(10, 15))
+
digits = load_digits()
X, y = digits.data, digits.target

@@ -117,12 +152,14 @@ def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,
cv = ShuffleSplit(n_splits=100, test_size=0.2, random_state=0)

estimator = GaussianNB()
- plot_learning_curve(estimator, title, X, y, ylim=(0.7, 1.01), cv=cv, n_jobs=4)
+ plot_learning_curve(estimator, title, X, y, axes=axes[:, 0], ylim=(0.7, 1.01),
+                     cv=cv, n_jobs=4)

title = r"Learning Curves (SVM, RBF kernel, $\gamma=0.001$)"
# SVC is more expensive so we do a lower number of CV iterations:
cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
estimator = SVC(gamma=0.001)
- plot_learning_curve(estimator, title, X, y, (0.7, 1.01), cv=cv, n_jobs=4)
+ plot_learning_curve(estimator, title, X, y, axes=axes[:, 1], ylim=(0.7, 1.01),
+                     cv=cv, n_jobs=4)

plt.show()