-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
[MRG] Learning curves #2701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Learning curves #2701
Changes from all commits
0ff07a0
be6e185
96fc078
2f6373d
811838a
0972b2e
30fd1c0
a0b1190
ef23d62
5fccd17
b3cb14e
d3e52f6
08a0b49
cd41b09
9fc5f3b
6a516b4
39dcb68
c42cae6
9dd0601
cec31e2
a64ef4e
36be74d
b9c838d
ff1aef4
3283298
754a104
9ede03c
93f1acb
9748127
822bd7b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
""" | ||
======================== | ||
Plotting Learning Curves | ||
======================== | ||
|
||
A learning curve shows the validation and training score of a learning | ||
algorithm for varying numbers of training samples. It is a tool to | ||
find out how much we benefit from adding more training data. If both | ||
the validation score and the training score converge to a value that is | ||
too low, we will not benefit much from more training data and we will | ||
probably have to use a learning algorithm or a parametrization of the | ||
current learning algorithm that can learn more complex concepts (i.e. | ||
has a lower bias). | ||
|
||
In this example, on the left side the learning curve of a naive Bayes | ||
classifier is shown for the digits dataset. Note that the training score | ||
and the cross-validation score are both not very good at the end. However, | ||
the shape of the curve can be found in more complex datasets very often: | ||
the training score is very high at the beginning and decreases and the | ||
cross-validation score is very low at the beginning and increases. On the | ||
right side we see the learning curve of an SVM with RBF kernel. We can | ||
see clearly that the training score is still around the maximum and the | ||
validation score could be increased with more training samples. | ||
""" | ||
print(__doc__) | ||
|
||
import matplotlib.pyplot as plt | ||
from sklearn.naive_bayes import GaussianNB | ||
from sklearn.svm import SVC | ||
from sklearn.datasets import load_digits | ||
from sklearn.learning_curve import learning_curve | ||
|
||
|
||
digits = load_digits() | ||
X, y = digits.data, digits.target | ||
|
||
plt.figure() | ||
plt.title("Learning Curve (Naive Bayes)") | ||
plt.xlabel("Training examples") | ||
plt.ylabel("Score") | ||
train_sizes, train_scores, test_scores = learning_curve( | ||
GaussianNB(), X, y, cv=10, n_jobs=1) | ||
plt.plot(train_sizes, train_scores, label="Training score") | ||
plt.plot(train_sizes, test_scores, label="Cross-validation score") | ||
plt.legend(loc="best") | ||
|
||
plt.figure() | ||
plt.title("Learning Curve (SVM, RBF kernel, $\gamma=0.001$)") | ||
plt.xlabel("Training examples") | ||
plt.ylabel("Score") | ||
train_sizes, train_scores, test_scores = learning_curve( | ||
SVC(gamma=0.001), X, y, cv=10, n_jobs=1) | ||
plt.plot(train_sizes, train_scores, label="Training score") | ||
plt.plot(train_sizes, test_scores, label="Cross-validation score") | ||
plt.legend(loc="best") | ||
|
||
plt.show() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -236,59 +236,79 @@ def fit_grid_point(X, y, base_estimator, parameters, train, test, scorer, | |
print("[GridSearchCV] %s %s" % (msg, (64 - len(msg)) * '.')) | ||
|
||
# update parameters of the classifier after a copy of its base structure | ||
clf = clone(base_estimator) | ||
clf.set_params(**parameters) | ||
estimator = clone(base_estimator) | ||
estimator.set_params(**parameters) | ||
|
||
if hasattr(base_estimator, 'kernel') and callable(base_estimator.kernel): | ||
X_train, y_train = _split(estimator, X, y, train) | ||
X_test, y_test = _split(estimator, X, y, test, train) | ||
_fit(estimator.fit, X_train, y_train, **fit_params) | ||
this_score = _score(estimator, X_test, y_test, scorer) | ||
|
||
if verbose > 2: | ||
msg += ", score=%f" % this_score | ||
if verbose > 1: | ||
end_msg = "%s -%s" % (msg, | ||
logger.short_format_time(time.time() - | ||
start_time)) | ||
print("[GridSearchCV] %s %s" % ((64 - len(end_msg)) * '.', end_msg)) | ||
|
||
return this_score, parameters, _num_samples(X_test) | ||
|
||
|
||
def _split(estimator, X, y, indices, train_indices=None):
    """Create a subset of a dataset.

    Parameters
    ----------
    estimator : estimator object
        Only inspected for a callable ``kernel`` attribute (custom kernels
        are rejected) and for the ``_pairwise`` flag (precomputed
        kernel/affinity matrices).
    X : array-like, sparse matrix, precomputed kernel matrix, or list
        Data to take the subset from.
    y : array-like or None
        Targets; ``None`` in the unsupervised case.
    indices : array-like of int
        Indices of the samples that form the subset.
    train_indices : array-like of int or None
        Only used when ``X`` is a precomputed square kernel matrix: the
        column indices (training samples) to slice against.  If ``None``,
        ``indices`` is used for both axes.

    Returns
    -------
    X_subset, y_subset
        The selected samples; ``y_subset`` is ``None`` when ``y`` is ``None``.
    """
    if hasattr(estimator, 'kernel') and callable(estimator.kernel):
        # cannot compute the kernel values with custom function
        raise ValueError("Cannot use a custom kernel function. "
                         "Precompute the kernel matrix instead.")

    if not hasattr(X, "shape"):
        # X is list-like; a pairwise matrix must be an array/sparse matrix.
        if getattr(estimator, "_pairwise", False):
            raise ValueError("Precomputed kernels or affinity matrices have "
                             "to be passed as arrays or sparse matrices.")
        X_subset = [X[idx] for idx in indices]
    else:
        if getattr(estimator, "_pairwise", False):
            # X is a precomputed square kernel matrix
            if X.shape[0] != X.shape[1]:
                raise ValueError("X should be a square kernel matrix")
            if train_indices is None:
                X_subset = X[np.ix_(indices, indices)]
            else:
                X_subset = X[np.ix_(indices, train_indices)]
        else:
            X_subset = X[safe_mask(X, indices)]

    if y is not None:
        y_subset = y[safe_mask(y, indices)]
    else:
        y_subset = None

    return X_subset, y_subset
|
||
|
||
if scorer is not None: | ||
this_score = scorer(clf, X_test, y_test) | ||
def _fit(fit_function, X_train, y_train, **fit_params):
    """Fit an estimator on a given training set.

    Calls ``fit_function(X_train, **fit_params)`` when ``y_train`` is
    ``None`` (unsupervised case), otherwise
    ``fit_function(X_train, y_train, **fit_params)``.
    """
    if y_train is None:
        fit_function(X_train, **fit_params)
    else:
        fit_function(X_train, y_train, **fit_params)
|
||
|
||
def _score(estimator, X_test, y_test, scorer):
    """Compute the score of an estimator on a given test set.

    Uses ``scorer`` when one is given, otherwise falls back to the
    estimator's own ``score`` method.  ``y_test`` may be ``None``
    (unsupervised case), in which case it is simply not passed along.

    Raises
    ------
    ValueError
        If the computed score is not a number.
    """
    if y_test is None:
        if scorer is None:
            this_score = estimator.score(X_test)
        else:
            this_score = scorer(estimator, X_test)
    else:
        if scorer is None:
            this_score = estimator.score(X_test, y_test)
        else:
            this_score = scorer(estimator, X_test, y_test)

    if not isinstance(this_score, numbers.Number):
        raise ValueError("scoring must return a number, got %s (%s)"
                         " instead." % (str(this_score), type(this_score)))

    return this_score
|
||
|
||
def _check_param_grid(param_grid): | ||
|
@@ -331,6 +351,24 @@ def __repr__(self): | |
self.parameters) | ||
|
||
|
||
def _check_scorable(estimator, scoring=None, loss_func=None, score_func=None):
    """Check that estimator can be fitted and score can be computed.

    Raises
    ------
    TypeError
        If the estimator lacks a ``fit`` method together with either
        ``predict`` or ``score``, or if no scoring callable
        (``scoring``/``loss_func``/``score_func``) is given and the
        estimator has no ``score`` method of its own.
    """
    if (not hasattr(estimator, 'fit') or
            not (hasattr(estimator, 'predict')
                 or hasattr(estimator, 'score'))):
        # NOTE: the original message read "should a be an"; fixed grammar.
        raise TypeError("estimator should be an estimator implementing"
                        " 'fit' and 'predict' or 'score' methods,"
                        " %s (type %s) was passed" %
                        (estimator, type(estimator)))
    if scoring is None and loss_func is None and score_func is None:
        if not hasattr(estimator, 'score'):
            raise TypeError(
                "If no scoring is specified, the estimator passed "
                "should have a 'score' method. The estimator %s "
                "does not." % estimator)
|
||
|
||
class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, | ||
MetaEstimatorMixin)): | ||
"""Base class for hyper parameter search with cross-validation.""" | ||
|
@@ -351,7 +389,8 @@ def __init__(self, estimator, scoring=None, loss_func=None, | |
self.cv = cv | ||
self.verbose = verbose | ||
self.pre_dispatch = pre_dispatch | ||
self._check_estimator() | ||
_check_scorable(self.estimator, scoring=self.scoring, | ||
loss_func=self.loss_func, score_func=self.score_func) | ||
|
||
def score(self, X, y=None): | ||
"""Returns the score on the given test data and labels, if the search | ||
|
@@ -396,24 +435,7 @@ def decision_function(self): | |
@property | ||
def transform(self): | ||
return self.best_estimator_.transform | ||
|
||
def _check_estimator(self): | ||
"""Check that estimator can be fitted and score can be computed.""" | ||
if (not hasattr(self.estimator, 'fit') or | ||
not (hasattr(self.estimator, 'predict') | ||
or hasattr(self.estimator, 'score'))): | ||
raise TypeError("estimator should a be an estimator implementing" | ||
" 'fit' and 'predict' or 'score' methods," | ||
" %s (type %s) was passed" % | ||
(self.estimator, type(self.estimator))) | ||
if (self.scoring is None and self.loss_func is None and self.score_func | ||
is None): | ||
if not hasattr(self.estimator, 'score'): | ||
raise TypeError( | ||
"If no scoring is specified, the estimator passed " | ||
"should have a 'score' method. The estimator %s " | ||
"does not." % self.estimator) | ||
|
||
|
||
def _fit(self, X, y, parameter_iterable): | ||
"""Actual fitting, performing the search over parameters.""" | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not entirely convinced that it is good to use this function here instead of writing the `if` and defining the function only in `learning_curve`, but I guess it is a matter of style and I don't have a strong opinion.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I prefer this version. The other version would require handling the unsupervised case twice.