16
16
from itertools import product
17
17
import numbers
18
18
import operator
19
- import time
20
19
import warnings
21
20
22
21
import numpy as np
23
22
24
23
from .base import BaseEstimator , is_classifier , clone
25
24
from .base import MetaEstimatorMixin
26
25
from .cross_validation import _check_cv as check_cv
27
- from .cross_validation import _check_scorable , _split , _fit , _score
28
- from .externals .joblib import Parallel , delayed , logger
26
+ from .cross_validation import _check_scorable , _cross_val_score
27
+ from .externals .joblib import Parallel , delayed
29
28
from .externals import six
30
29
from .utils import safe_mask , check_random_state
31
30
from .utils .validation import _num_samples , check_arrays
@@ -184,7 +183,7 @@ def __len__(self):
184
183
return self .n_iter
185
184
186
185
187
- def fit_grid_point (X , y , base_estimator , parameters , train , test , scorer ,
186
+ def fit_grid_point (X , y , estimator , parameters , train , test , scorer ,
188
187
verbose , loss_func = None , ** fit_params ):
189
188
"""Run fit on one set of parameters.
190
189
@@ -196,11 +195,11 @@ def fit_grid_point(X, y, base_estimator, parameters, train, test, scorer,
196
195
y : array-like or None
197
196
Targets for input data.
198
197
199
- base_estimator : estimator object
198
+ estimator : estimator object
200
199
This estimator will be cloned and then fitted.
201
200
202
201
parameters : dict
203
- Parameters to be set on base_estimator clone for this grid point.
202
+ Parameters to be set on estimator for this grid point.
204
203
205
204
train : ndarray, dtype int or bool
206
205
Boolean mask or indices for training set.
@@ -230,29 +229,11 @@ def fit_grid_point(X, y, base_estimator, parameters, train, test, scorer,
230
229
n_samples_test : int
231
230
Number of test samples in this split.
232
231
"""
233
- if verbose > 1 :
234
- start_time = time .time ()
235
- msg = '%s' % (', ' .join ('%s=%s' % (k , v )
236
- for k , v in parameters .items ()))
237
- print ("[GridSearchCV] %s %s" % (msg , (64 - len (msg )) * '.' ))
238
-
239
- # update parameters of the classifier after a copy of its base structure
240
- estimator = clone (base_estimator )
241
- estimator .set_params (** parameters )
242
-
243
- X_train , y_train = _split (estimator , X , y , train )
244
- X_test , y_test = _split (estimator , X , y , test , train )
245
- _fit (estimator .fit , X_train , y_train , ** fit_params )
246
- score = _score (estimator , X_test , y_test , scorer )
247
-
248
- if verbose > 2 :
249
- msg += ", score=%f" % score
250
- if verbose > 1 :
251
- end_msg = "%s -%s" % (msg , logger .short_format_time (time .time () -
252
- start_time ))
253
- print ("[GridSearchCV] %s %s" % ((64 - len (end_msg )) * '.' , end_msg ))
254
-
255
- return score , parameters , _num_samples (X_test )
232
+ score , n_samples_test = _cross_val_score (estimator , X , y , scorer , train ,
233
+ test , parameters , verbose ,
234
+ fit_params ,
235
+ log_label = "GridSearchCV" )
236
+ return score , parameters , n_samples_test
256
237
257
238
258
239
def _check_param_grid (param_grid ):
@@ -397,7 +378,7 @@ def _fit(self, X, y, parameter_iterable):
397
378
n_jobs = self .n_jobs , verbose = self .verbose ,
398
379
pre_dispatch = pre_dispatch )(
399
380
delayed (fit_grid_point )(
400
- X , y , base_estimator , parameters , train , test ,
381
+ X , y , clone ( base_estimator ) , parameters , train , test ,
401
382
self .scorer_ , self .verbose , ** self .fit_params )
402
383
for parameters in parameter_iterable
403
384
for train , test in cv )
0 commit comments