Fit different training set sizes in parallel · scikit-learn/scikit-learn@8b8c7bb

Commit 8b8c7bb

Fit different training set sizes in parallel
1 parent 37fa56a · commit 8b8c7bb

1 file changed: +18 −15

sklearn/learning_curve.py

Lines changed: 18 additions & 15 deletions
@@ -6,7 +6,7 @@
 from .metrics.scorer import _deprecate_loss_and_score_funcs
 
 def learning_curve(estimator, X, y,
-                   n_samples_range=np.arange(0.1, 1.1, 0.1), cv=None, scoring=None,
+                   n_samples_range=np.linspace(0.1, 1.0, 10), cv=None, scoring=None,
                    n_jobs=1, verbose=False, random_state=None):
     """ TODO document me
     Parameters
@@ -28,7 +28,7 @@ def learning_curve(estimator, X, y,
     n_samples_range = np.asarray(n_samples_range)
     n_min_required_samples = np.min(n_samples_range)
     n_max_required_samples = np.max(n_samples_range)
-    if np.issubdtype(n_samples_range.dtype, float):
+    if np.issubdtype(n_samples_range.dtype, np.float):
         if n_min_required_samples <= 0.0 or n_max_required_samples > 1.0:
             raise ValueError("n_samples_range must be within ]0, 1], "
                              "but is within [%f, %f]."
@@ -61,19 +61,22 @@ def learning_curve(estimator, X, y,
                          "does not." % estimator)
     scorer = _deprecate_loss_and_score_funcs(scoring=scoring)
 
-    scores = []
-    for n_train_samples in n_samples_range:
-        out = Parallel(
-            # TODO set pre_dispatch parameter? what is it good for?
-            n_jobs=n_jobs, verbose=verbose)(
-                delayed(_fit_estimator)(
-                    estimator, X, y, train[:n_train_samples], test, scorer,
-                    verbose)
-                for train, test in cv)
-        scores.append(np.mean(out, axis=0))
-    scores = np.array(scores)
+    out = Parallel(
+        # TODO use pre_dispatch parameter? what is it good for?
+        n_jobs=n_jobs, verbose=verbose)(
+            delayed(_fit_estimator)(
+                estimator, X, y, train[:n_train_samples], test, scorer,
+                verbose)
+            for train, test in cv for n_train_samples in n_samples_range)
 
-    return n_samples_range, scores[:, 0], scores[:, 1]
+    out = np.asarray(out)
+    train_scores = np.zeros(n_samples_range.shape, dtype=np.float)
+    test_scores = np.zeros(n_samples_range.shape, dtype=np.float)
+    for i, n_train_samples in enumerate(n_samples_range):
+        res_indices = np.where(out[:, 0] == n_train_samples)
+        train_scores[i], test_scores[i] = out[res_indices[0], 1:].mean(axis=0)
+
+    return n_samples_range, train_scores, test_scores
 
 
 def _fit_estimator(base_estimator, X, y, train, test, scorer, verbose):
     # TODO similar to fit_grid_point from grid search, refactor
@@ -85,4 +88,4 @@ def _fit_estimator(base_estimator, X, y, train, test, scorer, verbose):
     else:
         train_score = scorer(estimator, X[train], y[train])
     test_score = scorer(estimator, X[test], y[test])
-    return train_score, test_score
+    return train.shape[0], train_score, test_score
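The substance of the change, for reference: the old code launched a separate Parallel batch for each training set size, so each batch could only parallelize over the CV folds. The new code dispatches a single flat batch over every (fold, size) pair, and each task now reports its own training set size (the added train.shape[0] in _fit_estimator's return value) so that the flat result list can be regrouped and averaged per size afterwards. A minimal, self-contained sketch of that dispatch-then-regroup pattern, with a dummy task standing in for _fit_estimator and the standalone joblib package standing in for the sklearn-internal import:

    import numpy as np
    from joblib import Parallel, delayed

    def _fit_one(n_train_samples, fold):
        # Stand-in for _fit_estimator: return the group key first, then
        # dummy (train_score, test_score) values.
        return n_train_samples, 0.90, 0.80 - 0.01 * fold

    n_samples_range = np.array([10, 20, 30])
    n_folds = 3

    # One flat dispatch: n_folds * len(n_samples_range) tasks at once.
    out = np.asarray(Parallel(n_jobs=2)(
        delayed(_fit_one)(n_train_samples, fold)
        for fold in range(n_folds)
        for n_train_samples in n_samples_range))

    # Regroup by the first column (the training set size) and average
    # the score pairs over folds, as the patched learning_curve does.
    train_scores = np.zeros(n_samples_range.shape)
    test_scores = np.zeros(n_samples_range.shape)
    for i, n_train_samples in enumerate(n_samples_range):
        rows = np.where(out[:, 0] == n_train_samples)[0]
        train_scores[i], test_scores[i] = out[rows, 1:].mean(axis=0)

Because every result tuple carries its own size key, the regrouping step stays correct no matter how the tasks are interleaved across workers.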

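For completeness, a hedged usage sketch against the signature as it stands at this commit; the docstring is still a TODO, so everything here is inferred from the diff: the module path from the file header, the return triple from the final return statement, and the fraction semantics of float n_samples_range values from the ]0, 1] validation. The explicit KFold iterator is an assumption, since the hunks only show cv being iterated as (train, test) index pairs:

    import numpy as np
    from sklearn.cross_validation import KFold        # pre-0.18 CV module
    from sklearn.datasets import make_classification
    from sklearn.learning_curve import learning_curve
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(n_samples=500, random_state=0)
    cv = KFold(len(y), n_folds=5)
    sizes, train_scores, test_scores = learning_curve(
        LogisticRegression(), X, y,
        n_samples_range=np.linspace(0.1, 1.0, 10),  # fractions in ]0, 1]
        cv=cv, n_jobs=2)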
0 commit comments

Comments
 (0)
0