Add tests · scikit-learn/scikit-learn@9dc7d72 · GitHub

Commit 9dc7d72

Add tests

1 parent 7bb4946 commit 9dc7d72

4 files changed, 134 insertions(+), 34 deletions(-)

examples/plot_learning_curve.py

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 import matplotlib.pyplot as plt
 from sklearn.naive_bayes import GaussianNB
 from sklearn.datasets import load_digits
-from sklearn.learning_curve import learning_curve  # TODO should be: from sklearn import learning_curve
+from sklearn.learning_curve import learning_curve
 
 if __name__ == "__main__":
     estimator = GaussianNB()
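With the TODO resolved, the example imports learning_curve directly from sklearn.learning_curve. For orientation, here is a minimal sketch of how the example script can drive the function end to end; the digits data and the plotting calls are illustrative additions, not part of this commit, and assume the learning_curve signature shown in sklearn/learning_curve.py below.

    import matplotlib.pyplot as plt
    from sklearn.naive_bayes import GaussianNB
    from sklearn.datasets import load_digits
    from sklearn.learning_curve import learning_curve

    if __name__ == "__main__":
        digits = load_digits()
        X, y = digits.data, digits.target
        # learning_curve returns the training set sizes and the
        # CV-averaged training and test scores at each size.
        n_samples_range, train_scores, test_scores = learning_curve(
            GaussianNB(), X, y, cv=10, n_jobs=1)
        plt.plot(n_samples_range, train_scores, label="Training score")
        plt.plot(n_samples_range, test_scores, label="Test score")
        plt.xlabel("Number of training samples")
        plt.ylabel("Score")
        plt.legend(loc="best")
        plt.show()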

sklearn/grid_search.py

Lines changed: 2 additions & 1 deletion

@@ -254,7 +254,8 @@ def _split_and_score(base_estimator, X, y, parameters, train, test, scorer,
                      return_train_score=False, **fit_params):
     # update parameters of the classifier after a copy of its base structure
     estimator = clone(base_estimator)
-    estimator.set_params(**parameters)
+    if len(parameters) > 0:
+        estimator.set_params(**parameters)
 
     if hasattr(base_estimator, 'kernel') and callable(base_estimator.kernel):
         # cannot compute the kernel values with custom function
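The new guard presumably avoids calling set_params with an empty dict: a bare-bones estimator whose set_params indexes directly into **params, like the MockImprovingClassifier added in the tests below, would raise a KeyError in that case. A minimal sketch of the failure mode the guard sidesteps (plain Python, nothing sklearn-specific assumed):

    class Minimal(object):
        def set_params(self, **params):
            # Indexing an expected key fails when params is empty.
            self.n_max_train_samples = params["n_max_train_samples"]
            return self

    est = Minimal()
    parameters = {}
    if len(parameters) > 0:             # the guard from the diff above
        est.set_params(**parameters)    # skipped, so no KeyError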

sklearn/learning_curve.py

Lines changed: 47 additions & 32 deletions

@@ -1,4 +1,5 @@
 import numpy as np
+import warnings
 from .base import is_classifier, clone
 from .cross_validation import _check_cv
 from .utils import check_arrays
@@ -9,7 +10,10 @@
 def learning_curve(estimator, X, y, n_samples_range=np.linspace(0.1, 1.0, 10),
                    cv=None, scoring=None, exploit_incremental_learning=False,
                    n_jobs=1, verbose=0):
-    """ TODO document me
+    """Learning curve
+
+    Determines cross-validated training and test scores for different training
+    set sizes.
 
     Parameters
     ----------
@@ -63,49 +67,26 @@ def learning_curve(estimator, X, y, n_samples_range=np.linspace(0.1, 1.0, 10),
     test_scores : array, shape = [n_ticks,]
         Scores on test set.
     """
-    # TODO tests, doc
     # TODO use verbose argument
 
-    X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True)
-    # Make a list since we will be iterating multiple times over the folds
-    cv = list(_check_cv(cv, X, y, classifier=is_classifier(estimator)))
-
     if exploit_incremental_learning and not hasattr(estimator, 'partial_fit'):
         raise ValueError('An estimator must support the partial_fit interface '
                          'to exploit incremental learning')
 
-    # Determine range of number of training samples
+    X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True)
+    # Make a list since we will be iterating multiple times over the folds
+    cv = list(_check_cv(cv, X, y, classifier=is_classifier(estimator)))
+
     n_max_training_samples = cv[0][0].shape[0]
-    n_samples_range = np.asarray(n_samples_range)
-    n_min_required_samples = np.min(n_samples_range)
-    n_max_required_samples = np.max(n_samples_range)
-    if np.issubdtype(n_samples_range.dtype, np.float):
-        if n_min_required_samples <= 0.0 or n_max_required_samples > 1.0:
-            raise ValueError("n_samples_range must be within (0, 1], "
-                             "but is within [%f, %f]."
-                             % (n_min_required_samples,
-                                n_max_required_samples))
-        n_samples_range = np.unique((n_samples_range *
-                                     n_max_training_samples).astype(np.int))
-        # TODO we could
-        # - print a warning
-        # - *, inverse = np.unique(*, return_inverse=True); return np.take(., inverse)
-        # if there are duplicate elements
-    else:
-        if (n_min_required_samples <= 0 or
-                n_max_required_samples > n_max_training_samples):
-            raise ValueError("n_samples_range must be within (0, %d], "
-                             "but is within [%d, %d]."
-                             % (n_max_training_samples,
-                                n_min_required_samples,
-                                n_max_required_samples))
+    n_samples_range, n_unique_ticks = _translate_n_samples_range(
+        n_samples_range, n_max_training_samples)
 
     _check_scorable(estimator, scoring=scoring)
     scorer = _deprecate_loss_and_score_funcs(scoring=scoring)
 
     if exploit_incremental_learning:
+        raise NotImplementedError("Incremental learning is not supported yet")
         # TODO exploit incremental learning
-        pass
     else:
         out = Parallel(
             # TODO use pre_dispatch parameter? what is it good for?
@@ -116,13 +97,47 @@ def learning_curve(estimator, X, y, n_samples_range=np.linspace(0.1, 1.0, 10),
         for n_train_samples in n_samples_range for train, test in cv)
 
     out = np.array(out)
-    n_unique_ticks = n_samples_range.shape[0]
     n_cv_folds = out.shape[0] / n_unique_ticks
     out = out.reshape(n_unique_ticks, n_cv_folds, 2)
     avg_over_cv = out.mean(axis=1).reshape(n_unique_ticks, 2)
 
     return n_samples_range, avg_over_cv[:, 0], avg_over_cv[:, 1]
 
+
+def _translate_n_samples_range(n_samples_range, n_max_training_samples):
+    """Determine range of number of training samples"""
+    n_samples_range = np.asarray(n_samples_range)
+    n_ticks = n_samples_range.shape[0]
+    n_min_required_samples = np.min(n_samples_range)
+    n_max_required_samples = np.max(n_samples_range)
+    if np.issubdtype(n_samples_range.dtype, np.float):
+        if n_min_required_samples <= 0.0 or n_max_required_samples > 1.0:
+            raise ValueError("n_samples_range must be within (0, 1], "
+                             "but is within [%f, %f]."
+                             % (n_min_required_samples,
+                                n_max_required_samples))
+        n_samples_range = (n_samples_range * n_max_training_samples
+                           ).astype(np.int)
+        n_samples_range = np.clip(n_samples_range, 1, n_max_training_samples)
+    else:
+        if (n_min_required_samples <= 0 or
+                n_max_required_samples > n_max_training_samples):
+            raise ValueError("n_samples_range must be within (0, %d], "
+                             "but is within [%d, %d]."
+                             % (n_max_training_samples,
+                                n_min_required_samples,
+                                n_max_required_samples))
+
+    n_samples_range = np.unique(n_samples_range)
+    n_unique_ticks = n_samples_range.shape[0]
+    if n_ticks > n_unique_ticks:
+        warnings.warn("Number of ticks will be less than the size of "
+                      "'n_samples_range' (%d instead of %d)."
+                      % (n_unique_ticks, n_ticks), RuntimeWarning)
+
+    return n_samples_range, n_unique_ticks
+
+
 def _fit_estimator(base_estimator, X, y, train, test, n_train_samples,
                    scorer, verbose):
     # HACK as long as boolean indices are allowed in cv generators
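The refactored _translate_n_samples_range accepts either fractions of the maximal training set size or absolute sample counts. The following sketch works through the arithmetic by hand with plain numpy, mirroring the logic above (it uses astype(int) where the diff uses the era's np.int):

    import numpy as np

    n_max_training_samples = 20   # e.g. 30 samples with cv=3 -> 20 per fold

    # Fractions in (0, 1] are scaled to absolute counts and clipped to >= 1:
    fractions = np.linspace(0.1, 1.0, 10)
    absolute = (fractions * n_max_training_samples).astype(int)
    # -> [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]

    # With only 2 samples per fold, np.linspace(0.33, 1.0, 3) * 2 truncates
    # to [0, 1, 2], clipping yields [1, 1, 2], and np.unique collapses it
    # to [1, 2]; since the 3 requested ticks shrank to 2, the RuntimeWarning
    # exercised by the tests below is emitted.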

sklearn/tests/test_learning_curve.py

Lines changed: 84 additions & 0 deletions

@@ -0,0 +1,84 @@
+import numpy as np
+from sklearn.learning_curve import learning_curve
+from sklearn.utils.testing import assert_raises
+from sklearn.utils.testing import assert_warns
+from sklearn.utils.testing import assert_array_equal
+from sklearn.utils.testing import assert_array_almost_equal
+from sklearn.datasets import make_classification
+from sklearn.svm import SVC
+
+
+class MockImprovingClassifier(object):
+    """Dummy classifier to test the learning curve"""
+    def __init__(self, n_max_train_samples):
+        self.n_max_train_samples = n_max_train_samples
+        self.n_train_samples = 0
+
+    def fit(self, X_subset, y_subset):
+        self.X_subset = X_subset
+        self.y_subset = y_subset
+        self.n_train_samples = X_subset.shape[0]
+        return self
+
+    def predict(self, X):
+        raise NotImplementedError
+
+    def score(self, X=None, Y=None):
+        # training score becomes worse (2 -> 1), test score better (0 -> 1)
+        if X is self.X_subset:
+            return 2. - float(self.n_train_samples) / self.n_max_train_samples
+        else:
+            return float(self.n_train_samples) / self.n_max_train_samples
+
+    def get_params(self, deep=False):
+        return {"n_max_train_samples": self.n_max_train_samples}
+
+    def set_params(self, **params):
+        self.n_max_train_samples = params["n_max_train_samples"]
+        return self
+
+
+def test_learning_curve():
+    X, y = make_classification(n_samples=30, n_features=1, n_informative=1,
+                               n_redundant=0, n_classes=2,
+                               n_clusters_per_class=1, random_state=0)
+    estimator = MockImprovingClassifier(20)
+    n_samples_range, train_scores, test_scores = learning_curve(estimator,
+                                                                X, y, cv=3)
+    assert_array_equal(n_samples_range, np.linspace(2, 20, 10))
+    assert_array_almost_equal(train_scores, np.linspace(1.9, 1.0, 10))
+    assert_array_almost_equal(test_scores, np.linspace(0.1, 1.0, 10))
+
+
+def test_incremental_learning_not_possible():
+    X, y = make_classification(n_samples=2, n_features=1, n_informative=1,
+                               n_redundant=0, n_classes=2,
+                               n_clusters_per_class=1, random_state=0)
+    # The mockup does not have partial_fit()
+    estimator = MockImprovingClassifier(1)
+    assert_raises(ValueError, learning_curve, estimator, X, y,
+                  exploit_incremental_learning=True)
+
+
+def test_n_sample_range_out_of_bounds():
+    X, y = make_classification(n_samples=30, n_features=1, n_informative=1,
+                               n_redundant=0, n_classes=2,
+                               n_clusters_per_class=1, random_state=0)
+    estimator = MockImprovingClassifier(20)
+    assert_raises(ValueError, learning_curve, estimator, X, y, cv=3,
+                  n_samples_range=[0.0, 1.0])
+    assert_raises(ValueError, learning_curve, estimator, X, y, cv=3,
+                  n_samples_range=[0.1, 1.1])
+    assert_raises(ValueError, learning_curve, estimator, X, y, cv=3,
+                  n_samples_range=[0, 20])
+    assert_raises(ValueError, learning_curve, estimator, X, y, cv=3,
+                  n_samples_range=[1, 21])
+
+
+def test_remove_multiple_sample_sizes():
+    X, y = make_classification(n_samples=3, n_features=1, n_informative=1,
+                               n_redundant=0, n_classes=2,
+                               n_clusters_per_class=1, random_state=0)
+    estimator = MockImprovingClassifier(2)
+    n_samples_range, _, _ = assert_warns(RuntimeWarning,
+        learning_curve, estimator, X, y, cv=3,
+        n_samples_range=np.linspace(0.33, 1.0, 3))
+    assert_array_equal(n_samples_range, [1, 2])
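As a sanity check on the arrays expected by test_learning_curve: with 30 samples and cv=3, each training fold holds 20 samples, and the mock's scores are pure functions of the training size. A quick arithmetic sketch, mirroring only the mock's score() formulas (no sklearn involved):

    import numpy as np

    n = np.linspace(2, 20, 10)      # training sizes used on 20-sample folds
    train = 2.0 - n / 20.0          # mock training score: 1.9 ... 1.0
    test = n / 20.0                 # mock test score:     0.1 ... 1.0
    assert np.allclose(train, np.linspace(1.9, 1.0, 10))
    assert np.allclose(test, np.linspace(0.1, 1.0, 10))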
