Add learning curve · scikit-learn/scikit-learn@a876682 · GitHub

Commit a876682

AlexanderFabisch authored and amueller committed
Add learning curve
1 parent 8f2d8b9 commit a876682

File tree

5 files changed: +543 −53 lines changed

examples/plot_learning_curve.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+"""
+========================
+Plotting Learning Curves
+========================
+
+A learning curve shows the validation and training score of a learning
+algorithm for varying numbers of training samples. It is a tool to
+find out how much we benefit from adding more training data. If both
+the validation score and the training score converge to a value that is
+too low, we will not benefit much from more training data and we will
+probably have to use a learning algorithm or a parametrization of the
+current learning algorithm that can learn more complex concepts (i.e.
+has a lower bias).
+
+In this example, on the left side the learning curve of a naive Bayes
+classifier is shown for the digits dataset. Note that the training score
+and the cross-validation score are both not very good at the end. However,
+the shape of the curve can be found in more complex datasets very often:
+the training score is very high at the beginning and decreases and the
+cross-validation score is very low at the beginning and increases. On the
+right side we see the learning curve of an SVM with RBF kernel. We can
+see clearly that the training score is still around the maximum and the
+validation score could be increased with more training samples.
+"""
+print(__doc__)
+
+import matplotlib.pyplot as plt
+from sklearn.naive_bayes import GaussianNB
+from sklearn.svm import SVC
+from sklearn.datasets import load_digits
+from sklearn.learning_curve import learning_curve
+
+
+digits = load_digits()
+X, y = digits.data, digits.target
+
+plt.figure()
+plt.title("Learning Curve (Naive Bayes)")
+plt.xlabel("Training examples")
+plt.ylabel("Score")
+train_sizes, train_scores, test_scores = learning_curve(
+    GaussianNB(), X, y, cv=10, n_jobs=1)
+plt.plot(train_sizes, train_scores, label="Training score")
+plt.plot(train_sizes, test_scores, label="Cross-validation score")
+plt.legend(loc="best")
+
+plt.figure()
+plt.title("Learning Curve (SVM, RBF kernel, $\gamma=0.001$)")
+plt.xlabel("Training examples")
+plt.ylabel("Score")
+train_sizes, train_scores, test_scores = learning_curve(
+    SVC(gamma=0.001), X, y, cv=10, n_jobs=1)
+plt.plot(train_sizes, train_scores, label="Training score")
+plt.plot(train_sizes, test_scores, label="Cross-validation score")
+plt.legend(loc="best")
+
+plt.show()
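The example above relies on the new `learning_curve` helper. As a rough illustration of what such a curve measures — not the helper's actual implementation, which averages scores over the cross-validation folds — here is a minimal hand-rolled sketch: hold out a fixed validation set, fit on training subsets of increasing size, and record both scores. The import path `sklearn.model_selection` is the modern one; at the time of this commit the equivalent lived in `sklearn.cross_validation`.

import numpy as np
from sklearn.datasets import load_digits
from sklearn.naive_bayes import GaussianNB
from sklearn.model_selection import train_test_split

digits = load_digits()
# Fixed train/validation split; learning_curve instead averages over CV folds.
X_train, X_val, y_train, y_val = train_test_split(
    digits.data, digits.target, test_size=0.25, random_state=0)

for frac in np.linspace(0.1, 1.0, 5):  # illustrative subset sizes
    n = int(frac * len(X_train))
    clf = GaussianNB().fit(X_train[:n], y_train[:n])
    print("n=%4d  train=%.3f  validation=%.3f"
          % (n, clf.score(X_train[:n], y_train[:n]), clf.score(X_val, y_val)))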

sklearn/grid_search.py

Lines changed: 74 additions & 52 deletions
@@ -236,59 +236,79 @@ def fit_grid_point(X, y, base_estimator, parameters, train, test, scorer,
         print("[GridSearchCV] %s %s" % (msg, (64 - len(msg)) * '.'))
 
     # update parameters of the classifier after a copy of its base structure
-    clf = clone(base_estimator)
-    clf.set_params(**parameters)
+    estimator = clone(base_estimator)
+    estimator.set_params(**parameters)
 
-    if hasattr(base_estimator, 'kernel') and callable(base_estimator.kernel):
+    X_train, y_train = _split(estimator, X, y, train)
+    X_test, y_test = _split(estimator, X, y, test, train)
+    _fit(estimator.fit, X_train, y_train, **fit_params)
+    this_score = _score(estimator, X_test, y_test, scorer)
+
+    if verbose > 2:
+        msg += ", score=%f" % this_score
+    if verbose > 1:
+        end_msg = "%s -%s" % (msg,
+                              logger.short_format_time(time.time() -
+                                                       start_time))
+        print("[GridSearchCV] %s %s" % ((64 - len(end_msg)) * '.', end_msg))
+
+    return this_score, parameters, _num_samples(X_test)
+
+
+def _split(estimator, X, y, indices, train_indices=None):
+    """Create subset of dataset."""
+    if hasattr(estimator, 'kernel') and callable(estimator.kernel):
         # cannot compute the kernel values with custom function
         raise ValueError("Cannot use a custom kernel function. "
                          "Precompute the kernel matrix instead.")
 
     if not hasattr(X, "shape"):
-        if getattr(base_estimator, "_pairwise", False):
+        if getattr(estimator, "_pairwise", False):
            raise ValueError("Precomputed kernels or affinity matrices have "
                             "to be passed as arrays or sparse matrices.")
-        X_train = [X[idx] for idx in train]
-        X_test = [X[idx] for idx in test]
+        X_subset = [X[idx] for idx in indices]
    else:
-        if getattr(base_estimator, "_pairwise", False):
+        if getattr(estimator, "_pairwise", False):
            # X is a precomputed square kernel matrix
            if X.shape[0] != X.shape[1]:
                raise ValueError("X should be a square kernel matrix")
-            X_train = X[np.ix_(train, train)]
-            X_test = X[np.ix_(test, train)]
+            if train_indices is None:
+                X_subset = X[np.ix_(indices, indices)]
+            else:
+                X_subset = X[np.ix_(indices, train_indices)]
        else:
-            X_train = X[safe_mask(X, train)]
-            X_test = X[safe_mask(X, test)]
+            X_subset = X[safe_mask(X, indices)]
 
    if y is not None:
-        y_test = y[safe_mask(y, test)]
-        y_train = y[safe_mask(y, train)]
-        clf.fit(X_train, y_train, **fit_params)
+        y_subset = y[safe_mask(y, indices)]
+    else:
+        y_subset = None
+
+    return X_subset, y_subset
+
 
-        if scorer is not None:
-            this_score = scorer(clf, X_test, y_test)
+def _fit(fit_function, X_train, y_train, **fit_params):
+    """Fit and estimator on a given training set."""
+    if y_train is None:
+        fit_function(X_train, **fit_params)
+    else:
+        fit_function(X_train, y_train, **fit_params)
+
+
+def _score(estimator, X_test, y_test, scorer):
+    """Compute the score of an estimator on a given test set."""
+    if y_test is None:
+        if scorer is None:
+            this_score = estimator.score(X_test)
        else:
-            this_score = clf.score(X_test, y_test)
+            this_score = scorer(estimator, X_test)
    else:
-        clf.fit(X_train, **fit_params)
-        if scorer is not None:
-            this_score = scorer(clf, X_test)
+        if scorer is None:
+            this_score = estimator.score(X_test, y_test)
        else:
-            this_score = clf.score(X_test)
-
-    if not isinstance(this_score, numbers.Number):
-        raise ValueError("scoring must return a number, got %s (%s)"
-                         " instead." % (str(this_score), type(this_score)))
+            this_score = scorer(estimator, X_test, y_test)
 
-    if verbose > 2:
-        msg += ", score=%f" % this_score
-    if verbose > 1:
-        end_msg = "%s -%s" % (msg,
-                              logger.short_format_time(time.time() -
-                                                       start_time))
-        print("[GridSearchCV] %s %s" % ((64 - len(end_msg)) * '.', end_msg))
-    return this_score, parameters, _num_samples(X_test)
+    return this_score
 
 
 def _check_param_grid(param_grid):
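The `train_indices` argument to the new `_split` helper exists for the precomputed-kernel case: the test rows of the kernel matrix must be paired with the *training* columns, because the fitted model measures similarity of new points against the points it was trained on. A toy sketch of that indexing (the 4×4 matrix and index choice are made up for illustration):

import numpy as np

K = np.arange(16.0).reshape(4, 4)   # stand-in for a precomputed kernel matrix
train, test = [0, 1, 2], [3]

K_train = K[np.ix_(train, train)]   # what _split(est, K, y, train) selects
K_test = K[np.ix_(test, train)]     # what _split(est, K, y, test, train) selects
print(K_train.shape, K_test.shape)  # (3, 3) (1, 3): test rows x train columns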
@@ -331,6 +351,24 @@ def __repr__(self):
            self.parameters)
 
 
+def _check_scorable(estimator, scoring=None, loss_func=None, score_func=None):
+    """Check that estimator can be fitted and score can be computed."""
+    if (not hasattr(estimator, 'fit') or
+            not (hasattr(estimator, 'predict')
+                 or hasattr(estimator, 'score'))):
+        raise TypeError("estimator should a be an estimator implementing"
+                        " 'fit' and 'predict' or 'score' methods,"
+                        " %s (type %s) was passed" %
+                        (estimator, type(estimator)))
+    if (scoring is None and loss_func is None and score_func
+            is None):
+        if not hasattr(estimator, 'score'):
+            raise TypeError(
+                "If no scoring is specified, the estimator passed "
+                "should have a 'score' method. The estimator %s "
+                "does not." % estimator)
+
+
 class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
                                       MetaEstimatorMixin)):
     """Base class for hyper parameter search with cross-validation."""
@@ -351,7 +389,8 @@ def __init__(self, estimator, scoring=None, loss_func=None,
         self.cv = cv
         self.verbose = verbose
         self.pre_dispatch = pre_dispatch
-        self._check_estimator()
+        _check_scorable(self.estimator, scoring=self.scoring,
+                        loss_func=self.loss_func, score_func=self.score_func)
 
     def score(self, X, y=None):
         """Returns the score on the given test data and labels, if the search
@@ -396,24 +435,7 @@ def decision_function(self):
     @property
     def transform(self):
         return self.best_estimator_.transform
-
-    def _check_estimator(self):
-        """Check that estimator can be fitted and score can be computed."""
-        if (not hasattr(self.estimator, 'fit') or
-                not (hasattr(self.estimator, 'predict')
-                     or hasattr(self.estimator, 'score'))):
-            raise TypeError("estimator should a be an estimator implementing"
-                            " 'fit' and 'predict' or 'score' methods,"
-                            " %s (type %s) was passed" %
-                            (self.estimator, type(self.estimator)))
-        if (self.scoring is None and self.loss_func is None and self.score_func
-                is None):
-            if not hasattr(self.estimator, 'score'):
-                raise TypeError(
-                    "If no scoring is specified, the estimator passed "
-                    "should have a 'score' method. The estimator %s "
-                    "does not." % self.estimator)
-
+
     def _fit(self, X, y, parameter_iterable):
         """Actual fitting, performing the search over parameters."""
 