Timing and training score in *SearchCV.results_ by eyc88 · Pull Request #7026 · scikit-learn/scikit-learn · GitHub

Timing and training score in *SearchCV.results_ #7026


Closed
eyc88 wants to merge 5 commits
107 changes: 84 additions & 23 deletions sklearn/model_selection/_search.py
@@ -371,7 +371,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
def __init__(self, estimator, scoring=None,
fit_params=None, n_jobs=1, iid=True,
refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs',
error_score='raise'):
error_score='raise', return_train_score=True):

self.scoring = scoring
self.estimator = estimator
@@ -383,6 +383,7 @@ def __init__(self, estimator, scoring=None,
self.verbose = verbose
self.pre_dispatch = pre_dispatch
self.error_score = error_score
self.return_train_score = return_train_score

@property
def _estimator_type(self):
@@ -533,36 +534,70 @@ def _fit(self, X, y, labels, parameter_iterable):
pre_dispatch=pre_dispatch
)(delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_,
train, test, self.verbose, parameters,
self.fit_params, return_parameters=True,
self.fit_params,
return_train_score=self.return_train_score,
return_parameters=True,
error_score=self.error_score)
for parameters in parameter_iterable
for train, test in cv.split(X, y, labels))

test_scores, test_sample_counts, _, parameters = zip(*out)
# if one chooses to see train scores, "out" will contain train score info
if self.return_train_score:
train_scores, test_scores, test_sample_counts, time, parameters =\
zip(*out)
else:
test_scores, test_sample_counts, time, parameters = zip(*out)

candidate_params = parameters[::n_splits]
n_candidates = len(candidate_params)

# if one chooses to return train scores, reshape the train_scores array
if self.return_train_score:
train_scores = np.array(train_scores,
dtype=np.float64).reshape(n_candidates,
n_splits)
test_scores = np.array(test_scores,
dtype=np.float64).reshape(n_candidates,
n_splits)
# NOTE test_sample_counts (weights) remain the same for all candidates
test_sample_counts = np.array(test_sample_counts[:n_splits],
dtype=np.int)

# Compute the (weighted) mean and std for all the candidates
# Compute the (weighted) mean and std for test scores
weights = test_sample_counts if self.iid else None
means = np.average(test_scores, axis=1, weights=weights)
stds = np.sqrt(np.average((test_scores - means[:, np.newaxis]) ** 2,
axis=1, weights=weights))
test_means = np.average(test_scores, axis=1, weights=weights)
test_stds = np.sqrt(
np.average((test_scores - test_means[:, np.newaxis]) ** 2, axis=1,
weights=weights))

time = np.array(time, dtype=np.float64).reshape(n_candidates, n_splits)
time_means = np.average(time, axis=1)
time_stds = np.sqrt(
np.average((time - time_means[:, np.newaxis]) ** 2,
axis=1))
if self.return_train_score:
train_means = np.average(train_scores, axis=1)
train_stds = np.sqrt(
np.average((train_scores - train_means[:, np.newaxis]) ** 2,
axis=1))

results = dict()
for split_i in range(n_splits):
results["test_split%d_score" % split_i] = test_scores[:, split_i]
results["test_mean_score"] = means
results["test_std_score"] = stds
results["test_mean_score"] = test_means
results["test_std_score"] = test_stds

if self.return_train_score:
for split_i in range(n_splits):
results["train_split%d_score" % split_i] =\
train_scores[:, split_i]
results["train_mean_score"] = train_means
results["train_std_score"] = train_stds

ranks = np.asarray(rankdata(-means, method='min'), dtype=np.int32)
results["test_mean_time"] = time_means
results["test_std_time"] = time_stds

ranks = np.asarray(rankdata(-test_means, method='min'), dtype=np.int32)

best_index = np.flatnonzero(ranks == 1)[0]
best_parameters = candidate_params[best_index]
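A minimal, self-contained sketch of the aggregation this hunk performs: per-candidate (weighted) means and standard deviations over the CV splits, plus the 'min'-method ranking of the mean test scores. The shapes and the iid weighting mirror the code above; the toy numbers are invented purely for illustration.

    import numpy as np
    from scipy.stats import rankdata

    n_candidates, n_splits = 3, 2
    # toy per-split test scores, shape (n_candidates, n_splits)
    test_scores = np.array([[0.8, 0.82], [0.7, 0.5], [0.9, 0.78]])
    # test-fold sizes; used as weights when iid=True, as in the code above
    test_sample_counts = np.array([50, 50])
    iid = True

    weights = test_sample_counts if iid else None
    test_means = np.average(test_scores, axis=1, weights=weights)
    test_stds = np.sqrt(np.average((test_scores - test_means[:, np.newaxis]) ** 2,
                                   axis=1, weights=weights))
    # rank 1 goes to the highest mean test score; ties share the lowest rank
    ranks = np.asarray(rankdata(-test_means, method='min'), dtype=np.int32)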
@@ -726,6 +761,10 @@ class GridSearchCV(BaseSearchCV):
FitFailedWarning is raised. This parameter does not affect the refit
step, which will always raise the error.

return_train_score : boolean, default=True
If ``False``, the ``results_`` attribute will not include training
scores.


Examples
--------
@@ -744,13 +783,13 @@ class GridSearchCV(BaseSearchCV):
random_state=None, shrinking=True, tol=...,
verbose=False),
fit_params={}, iid=..., n_jobs=1,
param_grid=..., pre_dispatch=..., refit=...,
param_grid=..., pre_dispatch=..., refit=..., return_train_score=...,
scoring=..., verbose=...)
>>> sorted(clf.results_.keys())
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
['param_C', 'param_kernel', 'params', 'test_mean_score',...
'test_rank_score', 'test_split0_score', 'test_split1_score',...
'test_split2_score', 'test_std_score']
['param_C', 'param_kernel', 'params', 'test_mean_score', 'test_mean_time',
'test_rank_score', 'test_split0_score', 'test_split1_score',
'test_split2_score', 'test_std_score', ...]

Attributes
----------
@@ -784,13 +823,21 @@ class GridSearchCV(BaseSearchCV):
'test_split0_score' : [0.8, 0.7, 0.8, 0.9],
'test_split1_score' : [0.82, 0.5, 0.7, 0.78],
'test_mean_score' : [0.81, 0.60, 0.75, 0.82],
'train_split0_score': [0.9, 0.8, 0.85, 1.],
'train_split1_score': [0.95, 0.7, 0.8, 0.8],
'train_mean_score' : [0.93, 0.75, 0.83, 0.9],
'test_mean_time' : [0.00073, 0.00063, 0.00043, 0.00049],
'test_std_time' : [1.62e-4, 3.37e-5, 1.42e-5, 1.1e-5],
'test_std_score' : [0.02, 0.01, 0.03, 0.03],
'test_rank_score' : [2, 4, 3, 1],
...
'params' : [{'kernel': 'poly', 'degree': 2}, ...],
}

NOTE that the key ``'params'`` is used to store a list of parameter
settings dicts for all the parameter candidates. In addition,
``'train_mean_score'``, ``'train_split*_score'``, ... will be present
when ``return_train_score`` is set to True.

best_estimator_ : estimator
Estimator that was chosen by the search, i.e. estimator
@@ -848,11 +895,13 @@ class GridSearchCV(BaseSearchCV):

def __init__(self, estimator, param_grid, scoring=None, fit_params=None,
n_jobs=1, iid=True, refit=True, cv=None, verbose=0,
pre_dispatch='2*n_jobs', error_score='raise'):
pre_dispatch='2*n_jobs', error_score='raise',
return_train_score=False):
super(GridSearchCV, self).__init__(
estimator=estimator, scoring=scoring, fit_params=fit_params,
n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
pre_dispatch=pre_dispatch, error_score=error_score)
pre_dispatch=pre_dispatch, error_score=error_score,
return_train_score=return_train_score)
self.param_grid = param_grid
_check_param_grid(param_grid)
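A quick usage sketch for the new parameter (toy data; the key names and the ``results_`` attribute are as defined by this branch):

    from sklearn import datasets, svm
    from sklearn.model_selection import GridSearchCV

    iris = datasets.load_iris()
    parameters = {'kernel': ('linear', 'rbf'), 'C': [1, 10]}
    clf = GridSearchCV(svm.SVC(), parameters, return_train_score=True)
    clf.fit(iris.data, iris.target)

    # timing keys are always present; train_* keys need return_train_score=True
    print(clf.results_['test_mean_time'])
    print(clf.results_['train_mean_score'])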

@@ -986,6 +1035,10 @@ class RandomizedSearchCV(BaseSearchCV):
FitFailedWarning is raised. This parameter does not affect the refit
step, which will always raise the error.

return_train_score : boolean, default=True
If ``False``, the ``results_`` attribute will not include training
scores.

Attributes
----------
results_ : dict of numpy (masked) ndarrays
@@ -1013,13 +1066,21 @@ class RandomizedSearchCV(BaseSearchCV):
'test_split0_score' : [0.8, 0.9, 0.7],
'test_split1_score' : [0.82, 0.5, 0.7],
'test_mean_score' : [0.81, 0.7, 0.7],
'train_split0_score': [0.9, 0.8, 0.85],
'train_split1_score': [0.95, 0.7, 0.8],
'train_mean_score' : [0.93, 0.75, 0.83],
'test_mean_time' : [0.00073, 0.00063, 0.00043],
'test_std_time' : [1.62e-4, 3.37e-5, 1.1e-5],
'test_std_score' : [0.02, 0.2, 0.],
'test_rank_score' : [3, 1, 1],
...
'params' : [{'kernel' : 'rbf', 'gamma' : 0.1}, ...],
}

NOTE that the key ``'params'`` is used to store a list of parameter
settings dicts for all the parameter candidates. In addition,
``'train_mean_score'``, ``'train_split*_score'``, ... will be present
when ``return_train_score`` is set to True.

Review comment (Member): Place this under the return_train_score
description.

best_estimator_ : estimator
Estimator that was chosen by the search, i.e. estimator
@@ -1074,15 +1135,15 @@ class RandomizedSearchCV(BaseSearchCV):
def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
fit_params=None, n_jobs=1, iid=True, refit=True, cv=None,
verbose=0, pre_dispatch='2*n_jobs', random_state=None,
error_score='raise'):

error_score='raise', return_train_score=False):
self.param_distributions = param_distributions
self.n_iter = n_iter
self.random_state = random_state
super(RandomizedSearchCV, self).__init__(
estimator=estimator, scoring=scoring, fit_params=fit_params,
n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
pre_dispatch=pre_dispatch, error_score=error_score)
estimator=estimator, scoring=scoring, fit_params=fit_params,
n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
pre_dispatch=pre_dispatch, error_score=error_score,
return_train_score=return_train_score)

def fit(self, X, y=None, labels=None):
"""Run fit on the estimator with randomly drawn parameters.
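The analogous sketch for RandomizedSearchCV, again with toy data and this branch's key names; with return_train_score=True the train_* keys appear alongside the timing keys:

    from scipy.stats import expon
    from sklearn import datasets, svm
    from sklearn.model_selection import RandomizedSearchCV

    iris = datasets.load_iris()
    param_dist = {'C': expon(scale=10), 'gamma': expon(scale=0.1)}
    search = RandomizedSearchCV(svm.SVC(), param_distributions=param_dist,
                                n_iter=10, return_train_score=True,
                                random_state=0)
    search.fit(iris.data, iris.target)
    print(sorted(search.results_.keys()))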
33 changes: 23 additions & 10 deletions sklearn/model_selection/tests/test_search.py
@@ -602,21 +602,30 @@ def test_grid_search_results():
params = [dict(kernel=['rbf', ], C=[1, 10], gamma=[0.1, 1]),
dict(kernel=['poly', ], degree=[1, 2])]
grid_search = GridSearchCV(SVC(), cv=n_folds, iid=False,
param_grid=params)
param_grid=params, return_train_score=True)
grid_search.fit(X, y)
grid_search_iid = GridSearchCV(SVC(), cv=n_folds, iid=True,
param_grid=params)
param_grid=params, return_train_score=True)
grid_search_iid.fit(X, y)

param_keys = ('param_C', 'param_degree', 'param_gamma', 'param_kernel')
score_keys = ('test_mean_score', 'test_rank_score',
'test_split0_score', 'test_split1_score',
'test_split2_score', 'test_std_score')
score_keys = ('test_mean_score', 'train_mean_score', 'test_mean_time',
'test_rank_score', 'test_split0_score', 'test_split1_score',
'test_split2_score', 'train_split0_score',
'train_split1_score', 'train_split2_score',
'test_std_score', 'train_std_score', 'test_std_time')
n_candidates = n_grid_points

for search, iid in zip((grid_search, grid_search_iid), (False, True)):
assert_equal(iid, search.iid)
results = search.results_
# Check if score and timing are reasonable
assert_true(all(results['test_rank_score'] >= 1))
assert_true(all(all(results[k] >= 0) for k in score_keys
                if k != 'test_rank_score'))
assert_true(all(all(results[k] <= 1) for k in score_keys
                if not k.endswith('time') and
                k != 'test_rank_score'))
# Check results structure
check_results_array_types(results, param_keys, score_keys)
check_results_keys(results, param_keys, score_keys, n_candidates)
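The helpers check_results_array_types and check_results_keys are defined elsewhere in this test module and are not part of the diff; the following is only a rough sketch of what they plausibly verify, not the PR's actual code:

    import numpy as np
    from sklearn.utils.testing import assert_true

    def check_results_keys(results, param_keys, score_keys, n_cand):
        # every expected column exists and holds one entry per candidate
        assert_true(all(results[key].shape == (n_cand,)
                        for key in param_keys + score_keys))
        assert_true('params' in results)

    def check_results_array_types(results, param_keys, score_keys):
        # param columns are masked arrays, since a given setting may not
        # apply to every candidate; score and time columns are plain floats
        assert_true(all(isinstance(results[key], np.ma.MaskedArray)
                        for key in param_keys))
        assert_true(all(results[key].dtype == np.float64
                        for key in score_keys if key != 'test_rank_score'))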
@@ -649,17 +658,21 @@ def test_random_search_results():
n_search_iter = 30
params = dict(C=expon(scale=10), gamma=expon(scale=0.1))
random_search = RandomizedSearchCV(SVC(), n_iter=n_search_iter, cv=n_folds,
iid=False, param_distributions=params)
iid=False, param_distributions=params,
return_train_score=True)
random_search.fit(X, y)
random_search_iid = RandomizedSearchCV(SVC(), n_iter=n_search_iter,
cv=n_folds, iid=True,
param_distributions=params)
param_distributions=params,
return_train_score=True)
random_search_iid.fit(X, y)

param_keys = ('param_C', 'param_gamma')
score_keys = ('test_mean_score', 'test_rank_score',
'test_split0_score', 'test_split1_score',
'test_split2_score', 'test_std_score')
score_keys = ('test_mean_score', 'train_mean_score', 'test_mean_time',
'test_rank_score', 'test_split0_score', 'test_split1_score',
'test_split2_score', 'train_split0_score',
'train_split1_score', 'train_split2_score',
'test_std_score', 'train_std_score', 'test_std_time')
n_cand = n_search_iter

for search, iid in zip((random_search, random_search_iid), (False, True)):