Resolved issues #6894 and #6895: · raghavrv/scikit-learn@57c4781 · GitHub
[go: up one dir, main page]

Skip to content

Commit 57c4781

Browse files
Eugene Chen, raghavrv
Eugene Chen
authored and committed
Now *SearchCV.results_ includes both timing and training scores.
1 parent 9a12555 commit 57c4781

File tree

1 file changed

+51
-14
lines changed

1 file changed

+51
-14
lines changed

sklearn/model_selection/_search.py

Lines changed: 51 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -374,7 +374,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
374374
def __init__(self, estimator, scoring=None,
375375
fit_params=None, n_jobs=1, iid=True,
376376
refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs',
377-
error_score='raise'):
377+
error_score='raise', return_train_score=False):
378378

379379
self.scoring = scoring
380380
self.estimator = estimator
@@ -386,6 +386,7 @@ def __init__(self, estimator, scoring=None,
386386
self.verbose = verbose
387387
self.pre_dispatch = pre_dispatch
388388
self.error_score = error_score
389+
self.return_train_score = return_train_score
389390

390391
@property
391392
def _estimator_type(self):
@@ -551,16 +552,28 @@ def _fit(self, X, y, groups, parameter_iterable):
551552
pre_dispatch=pre_dispatch
552553
)(delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_,
553554
train, test, self.verbose, parameters,
554-
self.fit_params, return_parameters=True,
555+
self.fit_params,
556+
return_train_score=self.return_train_score,
557+
return_parameters=True,
555558
error_score=self.error_score)
556559
for parameters in parameter_iterable
557560
for train, test in cv.split(X, y, groups))
558561

559-
test_scores, test_sample_counts, _, parameters = zip(*out)
562+
# If the user chooses to see train scores, out will also contain train score info.
563+
if self.return_train_score:
564+
train_scores, test_scores, test_sample_counts, _, parameters =\
565+
zip(*out)
566+
else:
567+
test_scores, test_sample_counts, _, parameters = zip(*out)
560568

561569
candidate_params = parameters[::n_splits]
562570
n_candidates = len(candidate_params)
563571

572+
# If the user chooses to return train scores, reshape the train_scores array
573+
if self.return_train_score:
574+
train_scores = np.array(train_scores,
575+
dtype=np.float64).reshape(n_candidates,
576+
n_splits)
564577
test_scores = np.array(test_scores,
565578
dtype=np.float64).reshape(n_candidates,
566579
n_splits)
@@ -570,9 +583,21 @@ def _fit(self, X, y, groups, parameter_iterable):
570583

571584
# Compute the (weighted) mean and std for all the candidates
572585
weights = test_sample_counts if self.iid else None
573-
means = np.average(test_scores, axis=1, weights=weights)
574-
stds = np.sqrt(np.average((test_scores - means[:, np.newaxis]) ** 2,
575-
axis=1, weights=weights))
586+
587+
time = np.array(_, dtype=np.float64).reshape(n_candidates, n_splits)
588+
time_means = np.average(time, axis=1, weights=weights)
589+
time_stds = np.sqrt(
590+
np.average((time - time_means[:, np.newaxis]) ** 2,
591+
axis=1, weights=weights))
592+
if self.return_train_score:
593+
train_means = np.average(train_scores, axis=1, weights=weights)
594+
train_stds = np.sqrt(
595+
np.average((train_scores - train_means[:, np.newaxis]) ** 2,
596+
axis=1, weights=weights))
597+
test_means = np.average(test_scores, axis=1, weights=weights)
598+
test_stds = np.sqrt(
599+
np.average((test_scores - test_means[:, np.newaxis]) ** 2, axis=1,
600+
weights=weights))
576601

577602
cv_results = dict()
578603
for split_i in range(n_splits):
@@ -581,7 +606,17 @@ def _fit(self, X, y, groups, parameter_iterable):
581606
cv_results["mean_test_score"] = means
582607
cv_results["std_test_score"] = stds
583608

584-
ranks = np.asarray(rankdata(-means, method='min'), dtype=np.int32)
609+
if self.return_train_score:
610+
for split_i in range(n_splits):
611+
results["train_split%d_score" % split_i] = (
612+
train_scores[:, split_i])
613+
results["mean_train_score"] = train_means
614+
results["std_train_scores"] = train_stds
615+
616+
results["mean_test_time"] = time_means
617+
results["std_test_time"] = time_stds
618+
619+
ranks = np.asarray(rankdata(-test_means, method='min'), dtype=np.int32)
585620

586621
best_index = np.flatnonzero(ranks == 1)[0]
587622
best_parameters = candidate_params[best_index]
@@ -868,11 +903,13 @@ class GridSearchCV(BaseSearchCV):
868903

869904
def __init__(self, estimator, param_grid, scoring=None, fit_params=None,
870905
n_jobs=1, iid=True, refit=True, cv=None, verbose=0,
871-
pre_dispatch='2*n_jobs', error_score='raise'):
906+
pre_dispatch='2*n_jobs', error_score='raise',
907+
return_train_score=False):
872908
super(GridSearchCV, self).__init__(
873909
estimator=estimator, scoring=scoring, fit_params=fit_params,
874910
n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
875-
pre_dispatch=pre_dispatch, error_score=error_score)
911+
pre_dispatch=pre_dispatch, error_score=error_score,
912+
return_train_score=return_train_score)
876913
self.param_grid = param_grid
877914
_check_param_grid(param_grid)
878915

@@ -1094,15 +1131,15 @@ class RandomizedSearchCV(BaseSearchCV):
10941131
def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
10951132
fit_params=None, n_jobs=1, iid=True, refit=True, cv=None,
10961133
verbose=0, pre_dispatch='2*n_jobs', random_state=None,
1097-
error_score='raise'):
1098-
1134+
error_score='raise', return_train_score=False):
10991135
self.param_distributions = param_distributions
11001136
self.n_iter = n_iter
11011137
self.random_state = random_state
11021138
super(RandomizedSearchCV, self).__init__(
1103-
estimator=estimator, scoring=scoring, fit_params=fit_params,
1104-
n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
1105-
pre_dispatch=pre_dispatch, error_score=error_score)
1139+
estimator=estimator, scoring=scoring, fit_params=fit_params,
1140+
n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
1141+
pre_dispatch=pre_dispatch, error_score=error_score,
1142+
return_train_score=return_train_score)
11061143

11071144
def fit(self, X, y=None, groups=None):
11081145
"""Run fit on the estimator with randomly drawn parameters.

0 commit comments

Comments
 (0)
0