@@ -6,7 +6,7 @@
 from .metrics.scorer import _deprecate_loss_and_score_funcs
 
 def learning_curve(estimator, X, y,
-                   n_samples_range=np.arange(0.1, 1.1, 0.1), cv=None, scoring=None,
+                   n_samples_range=np.linspace(0.1, 1.0, 10), cv=None, scoring=None,
                    n_jobs=1, verbose=False, random_state=None):
     """TODO document me
     Parameters
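A note on the changed default: np.arange with a non-integer step makes no guarantee about the length or the endpoint of the result (numpy's own documentation recommends linspace for this), so the old default could yield a value slightly above 1.0 and trip the ]0, 1] check below. np.linspace pins both. A standalone illustration, not part of the patch:

import numpy as np

# Old default: the element count and endpoint depend on floating-point
# rounding of the 0.1 step, so the last value may exceed 1.0.
print(np.arange(0.1, 1.1, 0.1))

# New default: exactly 10 evenly spaced fractions, the last exactly 1.0.
print(np.linspace(0.1, 1.0, 10))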
@@ -28,7 +28,7 @@ def learning_curve(estimator, X, y,
     n_samples_range = np.asarray(n_samples_range)
     n_min_required_samples = np.min(n_samples_range)
     n_max_required_samples = np.max(n_samples_range)
-    if np.issubdtype(n_samples_range.dtype, float):
+    if np.issubdtype(n_samples_range.dtype, np.float):
         if n_min_required_samples <= 0.0 or n_max_required_samples > 1.0:
             raise ValueError("n_samples_range must be within ]0, 1], "
                              "but is within [%f, %f]."
@@ -61,19 +61,22 @@ def learning_curve(estimator, X, y,
                          "does not." % estimator)
     scorer = _deprecate_loss_and_score_funcs(scoring=scoring)
 
-    scores = []
-    for n_train_samples in n_samples_range:
-        out = Parallel(
-            # TODO set pre_dispatch parameter? what is it good for?
-            n_jobs=n_jobs, verbose=verbose)(
-            delayed(_fit_estimator)(
-                estimator, X, y, train[:n_train_samples], test, scorer,
-                verbose)
-            for train, test in cv)
-        scores.append(np.mean(out, axis=0))
-    scores = np.array(scores)
+    out = Parallel(
+        # TODO use pre_dispatch parameter? what is it good for?
+        n_jobs=n_jobs, verbose=verbose)(
+        delayed(_fit_estimator)(
+            estimator, X, y, train[:n_train_samples], test, scorer,
+            verbose)
+        for train, test in cv for n_train_samples in n_samples_range)
 
-    return n_samples_range, scores[:, 0], scores[:, 1]
+    out = np.asarray(out)
+    train_scores = np.zeros(n_samples_range.shape, dtype=np.float)
+    test_scores = np.zeros(n_samples_range.shape, dtype=np.float)
+    for i, n_train_samples in enumerate(n_samples_range):
+        res_indices = np.where(out[:, 0] == n_train_samples)
+        train_scores[i], test_scores[i] = out[res_indices[0], 1:].mean(axis=0)
+
+    return n_samples_range, train_scores, test_scores
 
 
 def _fit_estimator(base_estimator, X, y, train, test, scorer, verbose):
     # TODO similar to fit_grid_point from grid search, refactor
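The rewrite flattens the nested loops into a single Parallel call: one job per (CV split, training-set size) pair rather than a fresh worker pool per size, and each job reports the training-set size it used so the scores can be regrouped afterwards. A toy sketch of that regrouping step with made-up scores, not scikit-learn code:

import numpy as np

# Each job returns (n_train_samples, train_score, test_score); rows from
# all CV splits are averaged per training-set size after the parallel pass.
out = np.asarray([(10, 0.90, 0.70), (20, 0.95, 0.80),   # split 1
                  (10, 0.80, 0.60), (20, 0.85, 0.75)])  # split 2
for n_train_samples in np.unique(out[:, 0]):
    rows = out[out[:, 0] == n_train_samples]
    print(n_train_samples, rows[:, 1].mean(), rows[:, 2].mean())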
@@ -85,4 +88,4 @@ def _fit_estimator(base_estimator, X, y, train, test, scorer, verbose):
     else:
         train_score = scorer(estimator, X[train], y[train])
         test_score = scorer(estimator, X[test], y[test])
-    return train_score, test_score
+    return train.shape[0], train_score, test_score
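Returning train.shape[0] is what makes the regrouping above possible: each job reports how many training samples it actually fit on. A hedged usage sketch of the revised function; it assumes this module's learning_curve is importable and that the elided lines turn cv=None into a CV iterator, as is usual in scikit-learn:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.naive_bayes import GaussianNB

X, y = make_classification(n_samples=400, random_state=0)

# learning_curve is the function patched above (import path assumed).
# Float entries of n_samples_range are fractions of the training set.
n_samples_range, train_scores, test_scores = learning_curve(
    GaussianNB(), X, y, n_samples_range=np.linspace(0.1, 1.0, 5))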