Merge `fit_grid_point` into `_cross_val_score` · scikit-learn/scikit-learn@4b5f468 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4b5f468

Browse files
Merge fit_grid_point into _cross_val_score
1 parent 5e52031 commit 4b5f468

File tree

2 files changed

+38
-38
lines changed

2 files changed

+38
-38
lines changed

sklearn/cross_validation.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from itertools import chain, combinations
1616
from math import ceil, floor, factorial
1717
import numbers
18+
import time
1819
from abc import ABCMeta, abstractmethod
1920

2021
import numpy as np
@@ -24,7 +25,7 @@
2425
from .utils import check_arrays, check_random_state, safe_mask
2526
from .utils.validation import _num_samples
2627
from .utils.fixes import unique
27-
from .externals.joblib import Parallel, delayed
28+
from .externals.joblib import Parallel, delayed, logger
2829
from .externals.six import string_types, with_metaclass
2930
from .metrics.scorer import _deprecate_loss_and_score_funcs
3031

@@ -1095,17 +1096,30 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
10951096
pre_dispatch=pre_dispatch)
10961097
scores = parallel(
10971098
delayed(_cross_val_score)(clone(estimator), X, y, scorer, train, test,
1098-
verbose, fit_params)
1099+
parameters=None, verbose=verbose,
1100+
fit_params=fit_params,
1101+
log_label="cross_val_score")
10991102
for train, test in cv)
1100-
return np.array(scores)
1103+
return np.array(scores)[:, 0]
11011104

11021105

1103-
def _cross_val_score(estimator, X, y, scorer, train, test, verbose,
1104-
fit_params):
1106+
def _cross_val_score(estimator, X, y, scorer, train, test, parameters, verbose,
1107+
fit_params, log_label):
11051108
"""Inner loop for cross validation"""
1109+
if parameters is not None:
1110+
estimator.set_params(**parameters)
1111+
if verbose > 1:
1112+
start_time = time.time()
1113+
if parameters is None:
1114+
msg = "Evaluating..."
1115+
else:
1116+
msg = '%s' % (', '.join('%s=%s' % (k, v)
1117+
for k, v in parameters.items()))
1118+
print("[%s] %s %s" % (log_label, msg, (64 - len(msg)) * '.'))
1119+
11061120
n_samples = _num_samples(X)
11071121
fit_params = fit_params if fit_params is not None else {}
1108-
fit_params = dict([(k, np.asarray(v)[train] # TODO why is this necessary?
1122+
fit_params = dict([(k, np.asarray(v)[train]
11091123
if hasattr(v, '__len__') and len(v) == n_samples else v)
11101124
for k, v in fit_params.items()])
11111125

@@ -1114,9 +1128,14 @@ def _cross_val_score(estimator, X, y, scorer, train, test, verbose,
11141128
_fit(estimator.fit, X_train, y_train, **fit_params)
11151129
score = _score(estimator, X_test, y_test, scorer)
11161130

1131+
if verbose > 2:
1132+
msg += ", score=%f" % score
11171133
if verbose > 1:
1118-
print("score: %f" % score)
1119-
return score
1134+
end_msg = "%s -%s" % (msg, logger.short_format_time(time.time() -
1135+
start_time))
1136+
print("[%s] %s %s" % (log_label, (64 - len(end_msg)) * '.', end_msg))
1137+
1138+
return score, _num_samples(X_test)
11201139

11211140

11221141
def _split(estimator, X, y, indices, train_indices=None):

sklearn/grid_search.py

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,15 @@
1616
from itertools import product
1717
import numbers
1818
import operator
19-
import time
2019
import warnings
2120

2221
import numpy as np
2322

2423
from .base import BaseEstimator, is_classifier, clone
2524
from .base import MetaEstimatorMixin
2625
from .cross_validation import _check_cv as check_cv
27-
from .cross_validation import _check_scorable, _split, _fit, _score
28-
from .externals.joblib import Parallel, delayed, logger
26+
from .cross_validation import _check_scorable, _cross_val_score
27+
from .externals.joblib import Parallel, delayed
2928
from .externals import six
3029
from .utils import safe_mask, check_random_state
3130
from .utils.validation import _num_samples, check_arrays
@@ -184,7 +183,7 @@ def __len__(self):
184183
return self.n_iter
185184

186185

187-
def fit_grid_point(X, y, base_estimator, parameters, train, test, scorer,
186+
def fit_grid_point(X, y, estimator, parameters, train, test, scorer,
188187
verbose, loss_func=None, **fit_params):
189188
"""Run fit on one set of parameters.
190189
@@ -196,11 +195,11 @@ def fit_grid_point(X, y, base_estimator, parameters, train, test, scorer,
196195
y : array-like or None
197196
Targets for input data.
198197
199-
base_estimator : estimator object
198+
estimator : estimator object
200199
This estimator will be cloned and then fitted.
201200
202201
parameters : dict
203-
Parameters to be set on base_estimator clone for this grid point.
202+
Parameters to be set on estimator for this grid point.
204203
205204
train : ndarray, dtype int or bool
206205
Boolean mask or indices for training set.
@@ -230,29 +229,11 @@ def fit_grid_point(X, y, base_estimator, parameters, train, test, scorer,
230229
n_samples_test : int
231230
Number of test samples in this split.
232231
"""
233-
if verbose > 1:
234-
start_time = time.time()
235-
msg = '%s' % (', '.join('%s=%s' % (k, v)
236-
for k, v in parameters.items()))
237-
print("[GridSearchCV] %s %s" % (msg, (64 - len(msg)) * '.'))
238-
239-
# update parameters of the classifier after a copy of its base structure
240-
estimator = clone(base_estimator)
241-
estimator.set_params(**parameters)
242-
243-
X_train, y_train = _split(estimator, X, y, train)
244-
X_test, y_test = _split(estimator, X, y, test, train)
245-
_fit(estimator.fit, X_train, y_train, **fit_params)
246-
score = _score(estimator, X_test, y_test, scorer)
247-
248-
if verbose > 2:
249-
msg += ", score=%f" % score
250-
if verbose > 1:
251-
end_msg = "%s -%s" % (msg, logger.short_format_time(time.time() -
252-
start_time))
253-
print("[GridSearchCV] %s %s" % ((64 - len(end_msg)) * '.', end_msg))
254-
255-
return score, parameters, _num_samples(X_test)
232+
score, n_samples_test = _cross_val_score(estimator, X, y, scorer, train,
233+
test, parameters, verbose,
234+
fit_params,
235+
log_label="GridSearchCV")
236+
return score, parameters, n_samples_test
256237

257238

258239
def _check_param_grid(param_grid):
@@ -397,7 +378,7 @@ def _fit(self, X, y, parameter_iterable):
397378
n_jobs=self.n_jobs, verbose=self.verbose,
398379
pre_dispatch=pre_dispatch)(
399380
delayed(fit_grid_point)(
400-
X, y, base_estimator, parameters, train, test,
381+
X, y, clone(base_estimator), parameters, train, test,
401382
self.scorer_, self.verbose, **self.fit_params)
402383
for parameters in parameter_iterable
403384
for train, test in cv)

0 commit comments

Comments (0)