[MRG+2] Timing and training score in GridSearchCV by raghavrv · Pull Request #7325 · scikit-learn/scikit-learn


Merged · 3 commits · Sep 27, 2016
22 changes: 22 additions & 0 deletions doc/whats_new.rst
@@ -100,6 +100,20 @@ Model Selection Enhancements and API Changes
The parameter ``n_labels`` in the newly renamed
:class:`model_selection.LeavePGroupsOut` is changed to ``n_groups``.

- Training scores and timing information

``cv_results_`` also includes the training scores for each
cross-validation split (with keys such as ``'split0_train_score'``), as
well as their mean (``'mean_train_score'``) and standard deviation
(``'std_train_score'``). To avoid the cost of evaluating training score,
set ``return_train_score=False``.

Additionally, the mean and standard deviation of the times taken to fit
and score the model across all the cross-validation splits are available
at the keys ``'mean_fit_time'``, ``'std_fit_time'``, ``'mean_score_time'``
and ``'std_score_time'`` respectively (all in seconds).
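
For example, the new keys can be read off a fitted search as follows (a
minimal sketch, assuming the iris data and a small SVC grid)::

    from sklearn.datasets import load_iris
    from sklearn.model_selection import GridSearchCV
    from sklearn.svm import SVC

    iris = load_iris()
    # return_train_score=True is the default; passed here for emphasis
    search = GridSearchCV(SVC(), {'C': [1, 10]}, return_train_score=True)
    search.fit(iris.data, iris.target)

    results = search.cv_results_
    print(results['split0_train_score'])  # train scores on split 0
    print(results['mean_train_score'], results['std_train_score'])
    # timing keys, all in seconds
    print(results['mean_fit_time'], results['std_fit_time'])
    print(results['mean_score_time'], results['std_score_time'])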

Changelog
---------

New features
............
@@ -349,6 +363,12 @@ Enhancements
now accept arbitrary kernel functions in addition to strings ``knn`` and ``rbf``.
(`#5762 <https://github.com/scikit-learn/scikit-learn/pull/5762>`_) By `Utkarsh Upadhyay`_.

- The training scores and the time taken for fitting and scoring each
search candidate are now available in the ``cv_results_`` dict.
See :ref:`model_selection_changes` for more information.
(`#7325 <https://github.com/scikit-learn/scikit-learn/pull/7325>`_)
By `Eugene Chen`_ and `Raghav RV`_.


Bug fixes
.........
@@ -4651,3 +4671,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
.. _Russell Smith: https://github.com/rsmith54

.. _Utkarsh Upadhyay: https://github.com/musically-ut

.. _Eugene Chen: https://github.com/eyc88
154 changes: 106 additions & 48 deletions sklearn/model_selection/_search.py
@@ -319,7 +319,9 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer,
"""
score, n_samples_test, _ = _fit_and_score(estimator, X, y, scorer, train,
test, verbose, parameters,
fit_params, error_score)
fit_params=fit_params,
return_n_test_samples=True,
error_score=error_score)
return score, parameters, n_samples_test


@@ -374,7 +376,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
def __init__(self, estimator, scoring=None,
fit_params=None, n_jobs=1, iid=True,
refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs',
error_score='raise'):
error_score='raise', return_train_score=True):

self.scoring = scoring
self.estimator = estimator
@@ -386,6 +388,7 @@ def __init__(self, estimator, scoring=None,
self.verbose = verbose
self.pre_dispatch = pre_dispatch
self.error_score = error_score
self.return_train_score = return_train_score

@property
def _estimator_type(self):
@@ -551,41 +554,61 @@ def _fit(self, X, y, groups, parameter_iterable):
pre_dispatch=pre_dispatch
)(delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_,
train, test, self.verbose, parameters,
self.fit_params, return_parameters=True,
fit_params=self.fit_params,
return_train_score=self.return_train_score,
return_n_test_samples=True,
return_times=True, return_parameters=True,
error_score=self.error_score)
for parameters in parameter_iterable
for train, test in cv.split(X, y, groups))

test_scores, test_sample_counts, _, parameters = zip(*out)
# if one chooses to see train scores, "out" will contain train score info
if self.return_train_score:
(train_scores, test_scores, test_sample_counts,
fit_time, score_time, parameters) = zip(*out)
else:
(test_scores, test_sample_counts,
fit_time, score_time, parameters) = zip(*out)
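# "out" is in candidate-major order: one entry per (candidate, split)
# pair, so each candidate's parameter dict repeats n_splits times and
# parameters[::n_splits] below recovers one copy per candidate.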

candidate_params = parameters[::n_splits]
n_candidates = len(candidate_params)

test_scores = np.array(test_scores,
dtype=np.float64).reshape(n_candidates,
n_splits)
results = dict()

def _store(key_name, array, weights=None, splits=False, rank=False):
"""A small helper to store the scores/times to the cv_results_"""
array = np.array(array, dtype=np.float64).reshape(n_candidates,
n_splits)
if splits:
for split_i in range(n_splits):
results["split%d_%s"
% (split_i, key_name)] = array[:, split_i]

array_means = np.average(array, axis=1, weights=weights)
results['mean_%s' % key_name] = array_means
# Weighted std is not directly available in numpy
array_stds = np.sqrt(np.average((array -
array_means[:, np.newaxis]) ** 2,
axis=1, weights=weights))
results['std_%s' % key_name] = array_stds

if rank:
results["rank_%s" % key_name] = np.asarray(
rankdata(-array_means, method='min'), dtype=np.int32)

# Compute the (weighted) mean and std for the test scores alone
# NOTE: test sample counts (weights) remain the same for all candidates
test_sample_counts = np.array(test_sample_counts[:n_splits],
dtype=np.int)

# Computed the (weighted) mean and std for all the candidates
weights = test_sample_counts if self.iid else None
means = np.average(test_scores, axis=1, weights=weights)
stds = np.sqrt(np.average((test_scores - means[:, np.newaxis]) ** 2,
axis=1, weights=weights))

cv_results = dict()
for split_i in range(n_splits):
cv_results["split%d_test_score" % split_i] = test_scores[:,
split_i]
cv_results["mean_test_score"] = means
cv_results["std_test_score"] = stds

ranks = np.asarray(rankdata(-means, method='min'), dtype=np.int32)
_store('test_score', test_scores, splits=True, rank=True,
weights=test_sample_counts if self.iid else None)
if self.return_train_score:
_store('train_score', train_scores, splits=True)
_store('fit_time', fit_time)
_store('score_time', score_time)

best_index = np.flatnonzero(ranks == 1)[0]
best_index = np.flatnonzero(results["rank_test_score"] == 1)[0]
best_parameters = candidate_params[best_index]
cv_results["rank_test_score"] = ranks

# Use one np.MaskedArray and mask all the places where the param is not
# applicable for that candidate. Use defaultdict as each candidate may
@@ -599,12 +622,12 @@ def _fit(self, X, y, groups, parameter_iterable):
# Setting the value at an index also unmasks that index
param_results["param_%s" % name][cand_i] = value

cv_results.update(param_results)
results.update(param_results)

# Store a list of param dicts at the key 'params'
cv_results['params'] = candidate_params
results['params'] = candidate_params

self.cv_results_ = cv_results
self.cv_results_ = results
self.best_index_ = best_index
self.n_splits_ = n_splits
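
As a standalone illustration of the weighted mean/std/rank logic that
``_store`` implements, here is a minimal sketch (the scores and weights
are made-up values, not output of this PR's code)::

    import numpy as np
    from scipy.stats import rankdata

    # (n_candidates, n_splits) test scores and per-split test sample
    # counts; the weights are only applied when iid=True
    scores = np.array([[0.8, 0.9, 0.7],
                       [0.6, 0.5, 0.7]])
    weights = np.array([34, 33, 33])

    means = np.average(scores, axis=1, weights=weights)
    # NumPy has no weighted std, so take the square root of the weighted
    # average of squared deviations, exactly as _store does above
    stds = np.sqrt(np.average((scores - means[:, np.newaxis]) ** 2,
                              axis=1, weights=weights))
    # rank 1 is the best mean score; ties share the lowest rank ('min')
    ranks = np.asarray(rankdata(-means, method='min'), dtype=np.int32)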

@@ -746,6 +769,10 @@ class GridSearchCV(BaseSearchCV):
FitFailedWarning is raised. This parameter does not affect the refit
step, which will always raise the error.

return_train_score : boolean, default=True
If ``False``, the ``cv_results_`` attribute will not include training
scores.

[Review comment] Member: versionadded?


Examples
--------
@@ -764,13 +791,16 @@ class GridSearchCV(BaseSearchCV):
random_state=None, shrinking=True, tol=...,
verbose=False),
fit_params={}, iid=..., n_jobs=1,
param_grid=..., pre_dispatch=..., refit=...,
param_grid=..., pre_dispatch=..., refit=..., return_train_score=...,
scoring=..., verbose=...)
>>> sorted(clf.cv_results_.keys())
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
['mean_test_score', 'param_C', 'param_kernel', 'params',...
'rank_test_score', 'split0_test_score', 'split1_test_score',...
'split2_test_score', 'std_test_score']
['mean_fit_time', 'mean_score_time', 'mean_test_score',...
'mean_train_score', 'param_C', 'param_kernel', 'params',...
'rank_test_score', 'split0_test_score',...
'split0_train_score', 'split1_test_score', 'split1_train_score',...
'split2_test_score', 'split2_train_score',...
'std_fit_time', 'std_score_time', 'std_test_score', 'std_train_score'...]
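
Not part of the docstring, but a convenient way to browse these results is
a pandas DataFrame (a sketch, assuming pandas is installed and ``clf`` is
the fitted search from the doctest above)::

    import pandas as pd

    # one row per candidate, best-ranked candidates first
    df = pd.DataFrame(clf.cv_results_).sort_values('rank_test_score')
    print(df[['param_C', 'param_kernel', 'mean_test_score',
              'mean_train_score', 'mean_fit_time']])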

Attributes
----------
@@ -801,17 +831,28 @@ class GridSearchCV(BaseSearchCV):
mask = [ True True False False]...),
'param_degree': masked_array(data = [2.0 3.0 -- --],
mask = [False False True True]...),
'split0_test_score' : [0.8, 0.7, 0.8, 0.9],
'split1_test_score' : [0.82, 0.5, 0.7, 0.78],
'mean_test_score' : [0.81, 0.60, 0.75, 0.82],
'std_test_score' : [0.02, 0.01, 0.03, 0.03],
'rank_test_score' : [2, 4, 3, 1],
'params' : [{'kernel': 'poly', 'degree': 2}, ...],
'split0_test_score' : [0.8, 0.7, 0.8, 0.9],
'split1_test_score' : [0.82, 0.5, 0.7, 0.78],
'mean_test_score' : [0.81, 0.60, 0.75, 0.82],
'std_test_score' : [0.02, 0.01, 0.03, 0.03],
'rank_test_score' : [2, 4, 3, 1],
'split0_train_score' : [0.8, 0.9, 0.7, 0.8],
'split1_train_score' : [0.82, 0.5, 0.7, 0.78],
'mean_train_score' : [0.81, 0.7, 0.7, 0.79],
'std_train_score' : [0.03, 0.03, 0.04, 0.01],
'mean_fit_time' : [0.73, 0.63, 0.43, 0.49],
'std_fit_time' : [0.01, 0.02, 0.01, 0.01],
'mean_score_time' : [0.007, 0.06, 0.04, 0.04],
'std_score_time' : [0.001, 0.002, 0.003, 0.005],
'params' : [{'kernel': 'poly', 'degree': 2}, ...],
}

NOTE that the key ``'params'`` is used to store a list of parameter
settings dicts for all the parameter candidates.

The ``mean_fit_time``, ``std_fit_time``, ``mean_score_time`` and
``std_score_time`` are all in seconds.

best_estimator_ : estimator
Estimator that was chosen by the search, i.e. estimator
which gave highest score (or smallest loss if specified)
@@ -868,11 +909,13 @@ class GridSearchCV(BaseSearchCV):

def __init__(self, estimator, param_grid, scoring=None, fit_params=None,
n_jobs=1, iid=True, refit=True, cv=None, verbose=0,
pre_dispatch='2*n_jobs', error_score='raise'):
pre_dispatch='2*n_jobs', error_score='raise',
return_train_score=True):
super(GridSearchCV, self).__init__(
estimator=estimator, scoring=scoring, fit_params=fit_params,
n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
pre_dispatch=pre_dispatch, error_score=error_score)
pre_dispatch=pre_dispatch, error_score=error_score,
return_train_score=return_train_score)
self.param_grid = param_grid
_check_param_grid(param_grid)

@@ -1006,6 +1049,10 @@ class RandomizedSearchCV(BaseSearchCV):
FitFailedWarning is raised. This parameter does not affect the refit
step, which will always raise the error.

return_train_score : boolean, default=True
If ``False``, the ``cv_results_`` attribute will not include training
scores.

[Review thread]
Member: versionadded?
Member (author): Since model_selection is not released yet, do we need a versionadded?
Member: good catch, whoops

Attributes
----------
cv_results_ : dict of numpy (masked) ndarrays
@@ -1030,17 +1077,28 @@ class RandomizedSearchCV(BaseSearchCV):
'param_kernel' : masked_array(data = ['rbf', 'rbf', 'rbf'],
mask = False),
'param_gamma' : masked_array(data = [0.1 0.2 0.3], mask = False),
'split0_test_score' : [0.8, 0.9, 0.7],
'split1_test_score' : [0.82, 0.5, 0.7],
'mean_test_score' : [0.81, 0.7, 0.7],
'std_test_score' : [0.02, 0.2, 0.],
'rank_test_score' : [3, 1, 1],
'split0_test_score' : [0.8, 0.9, 0.7],
'split1_test_score' : [0.82, 0.5, 0.7],
'mean_test_score' : [0.81, 0.7, 0.7],
'std_test_score' : [0.02, 0.2, 0.],
'rank_test_score' : [3, 1, 1],
'split0_train_score' : [0.8, 0.9, 0.7],
'split1_train_score' : [0.82, 0.5, 0.7],
'mean_train_score' : [0.81, 0.7, 0.7],
'std_train_score' : [0.03, 0.03, 0.04],
'mean_fit_time' : [0.73, 0.63, 0.43],
'std_fit_time' : [0.01, 0.02, 0.01],
'mean_score_time' : [0.007, 0.06, 0.04],
'std_score_time' : [0.001, 0.002, 0.003],
'params' : [{'kernel' : 'rbf', 'gamma' : 0.1}, ...],
}

NOTE that the key ``'params'`` is used to store a list of parameter
settings dicts for all the parameter candidates.

The ``mean_fit_time``, ``std_fit_time``, ``mean_score_time`` and
``std_score_time`` are all in seconds.
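
A short sketch of opting out of training scores to save compute (the data
and the ``C`` distribution here are illustrative)::

    from scipy.stats import expon
    from sklearn.datasets import load_iris
    from sklearn.model_selection import RandomizedSearchCV
    from sklearn.svm import SVC

    iris = load_iris()
    search = RandomizedSearchCV(SVC(), {'C': expon(scale=10)}, n_iter=5,
                                return_train_score=False, random_state=0)
    search.fit(iris.data, iris.target)
    # no '*_train_score' keys now; timing keys are always present
    assert not any(k.endswith('train_score') for k in search.cv_results_)
    assert 'mean_fit_time' in search.cv_results_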

best_estimator_ : estimator
Estimator that was chosen by the search, i.e. estimator
which gave highest score (or smallest loss if specified)
@@ -1094,15 +1152,15 @@ class RandomizedSearchCV(BaseSearchCV):
def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
fit_params=None, n_jobs=1, iid=True, refit=True, cv=None,
verbose=0, pre_dispatch='2*n_jobs', random_state=None,
error_score='raise'):

error_score='raise', return_train_score=True):
self.param_distributions = param_distributions
self.n_iter = n_iter
self.random_state = random_state
super(RandomizedSearchCV, self).__init__(
estimator=estimator, scoring=scoring, fit_params=fit_params,
n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
pre_dispatch=pre_dispatch, error_score=error_score)
estimator=estimator, scoring=scoring, fit_params=fit_params,
n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
pre_dispatch=pre_dispatch, error_score=error_score,
return_train_score=return_train_score)

def fit(self, X, y=None, groups=None):
"""Run fit on the estimator with randomly drawn parameters.