@@ -236,59 +236,79 @@ def fit_grid_point(X, y, base_estimator, parameters, train, test, scorer,
236
236
print ("[GridSearchCV] %s %s" % (msg , (64 - len (msg )) * '.' ))
237
237
238
238
# update parameters of the classifier after a copy of its base structure
239
- clf = clone (base_estimator )
240
- clf .set_params (** parameters )
239
+ estimator = clone (base_estimator )
240
+ estimator .set_params (** parameters )
241
241
242
- if hasattr (base_estimator , 'kernel' ) and callable (base_estimator .kernel ):
242
+ X_train , y_train = _split (estimator , X , y , train )
243
+ X_test , y_test = _split (estimator , X , y , test , train )
244
+ _fit (estimator .fit , X_train , y_train , ** fit_params )
245
+ this_score = _score (estimator , X_test , y_test , scorer )
246
+
247
+ if verbose > 2 :
248
+ msg += ", score=%f" % this_score
249
+ if verbose > 1 :
250
+ end_msg = "%s -%s" % (msg ,
251
+ logger .short_format_time (time .time () -
252
+ start_time ))
253
+ print ("[GridSearchCV] %s %s" % ((64 - len (end_msg )) * '.' , end_msg ))
254
+
255
+ return this_score , parameters , _num_samples (X_test )
256
+
257
+
258
+ def _split (estimator , X , y , indices , train_indices = None ):
259
+ """Create subset of dataset."""
260
+ if hasattr (estimator , 'kernel' ) and callable (estimator .kernel ):
243
261
# cannot compute the kernel values with custom function
244
262
raise ValueError ("Cannot use a custom kernel function. "
245
263
"Precompute the kernel matrix instead." )
246
264
247
265
if not hasattr (X , "shape" ):
248
- if getattr (base_estimator , "_pairwise" , False ):
266
+ if getattr (estimator , "_pairwise" , False ):
249
267
raise ValueError ("Precomputed kernels or affinity matrices have "
250
268
"to be passed as arrays or sparse matrices." )
251
- X_train = [X [idx ] for idx in train ]
252
- X_test = [X [idx ] for idx in test ]
269
+ X_subset = [X [idx ] for idx in indices ]
253
270
else :
254
- if getattr (base_estimator , "_pairwise" , False ):
271
+ if getattr (estimator , "_pairwise" , False ):
255
272
# X is a precomputed square kernel matrix
256
273
if X .shape [0 ] != X .shape [1 ]:
257
274
raise ValueError ("X should be a square kernel matrix" )
258
- X_train = X [np .ix_ (train , train )]
259
- X_test = X [np .ix_ (test , train )]
275
+ if train_indices is None :
276
+ X_subset = X [np .ix_ (indices , indices )]
277
+ else :
278
+ X_subset = X [np .ix_ (indices , train_indices )]
260
279
else :
261
- X_train = X [safe_mask (X , train )]
262
- X_test = X [safe_mask (X , test )]
280
+ X_subset = X [safe_mask (X , indices )]
263
281
264
282
if y is not None :
265
- y_test = y [safe_mask (y , test )]
266
- y_train = y [safe_mask (y , train )]
267
- clf .fit (X_train , y_train , ** fit_params )
283
+ y_subset = y [safe_mask (y , indices )]
284
+ else :
285
+ y_subset = None
286
+
287
+ return X_subset , y_subset
288
+
268
289
269
- if scorer is not None :
270
- this_score = scorer (clf , X_test , y_test )
290
+ def _fit (fit_function , X_train , y_train , ** fit_params ):
291
+ """Fit and estimator on a given training set."""
292
+ if y_train is None :
293
+ fit_function (X_train , ** fit_params )
294
+ else :
295
+ fit_function (X_train , y_train , ** fit_params )
296
+
297
+
298
+ def _score (estimator , X_test , y_test , scorer ):
299
+ """Compute the score of an estimator on a given test set."""
300
+ if y_test is None :
301
+ if scorer is None :
302
+ this_score = estimator .score (X_test )
271
303
else :
272
- this_score = clf . score ( X_test , y_test )
304
+ this_score = scorer ( estimator , X_test )
273
305
else :
274
- clf .fit (X_train , ** fit_params )
275
- if scorer is not None :
276
- this_score = scorer (clf , X_test )
306
+ if scorer is None :
307
+ this_score = estimator .score (X_test , y_test )
277
308
else :
278
- this_score = clf .score (X_test )
279
-
280
- if not isinstance (this_score , numbers .Number ):
281
- raise ValueError ("scoring must return a number, got %s (%s)"
282
- " instead." % (str (this_score ), type (this_score )))
309
+ this_score = scorer (estimator , X_test , y_test )
283
310
284
- if verbose > 2 :
285
- msg += ", score=%f" % this_score
286
- if verbose > 1 :
287
- end_msg = "%s -%s" % (msg ,
288
- logger .short_format_time (time .time () -
289
- start_time ))
290
- print ("[GridSearchCV] %s %s" % ((64 - len (end_msg )) * '.' , end_msg ))
291
- return this_score , parameters , _num_samples (X_test )
311
+ return this_score
292
312
293
313
294
314
def _check_param_grid (param_grid ):
@@ -331,6 +351,24 @@ def __repr__(self):
331
351
self .parameters )
332
352
333
353
354
+ def _check_scorable (estimator , scoring = None , loss_func = None , score_func = None ):
355
+ """Check that estimator can be fitted and score can be computed."""
356
+ if (not hasattr (estimator , 'fit' ) or
357
+ not (hasattr (estimator , 'predict' )
358
+ or hasattr (estimator , 'score' ))):
359
+ raise TypeError ("estimator should a be an estimator implementing"
360
+ " 'fit' and 'predict' or 'score' methods,"
361
+ " %s (type %s) was passed" %
362
+ (estimator , type (estimator )))
363
+ if (scoring is None and loss_func is None and score_func
364
+ is None ):
365
+ if not hasattr (estimator , 'score' ):
366
+ raise TypeError (
367
+ "If no scoring is specified, the estimator passed "
368
+ "should have a 'score' method. The estimator %s "
369
+ "does not." % estimator )
370
+
371
+
334
372
class BaseSearchCV (six .with_metaclass (ABCMeta , BaseEstimator ,
335
373
MetaEstimatorMixin )):
336
374
"""Base class for hyper parameter search with cross-validation."""
@@ -351,7 +389,8 @@ def __init__(self, estimator, scoring=None, loss_func=None,
351
389
self .cv = cv
352
390
self .verbose = verbose
353
391
self .pre_dispatch = pre_dispatch
354
- self ._check_estimator ()
392
+ _check_scorable (self .estimator , scoring = self .scoring ,
393
+ loss_func = self .loss_func , score_func = self .score_func )
355
394
356
395
def score (self , X , y = None ):
357
396
"""Returns the score on the given test data and labels, if the search
@@ -396,24 +435,7 @@ def decision_function(self):
396
435
@property
397
436
def transform (self ):
398
437
return self .best_estimator_ .transform
399
-
400
- def _check_estimator (self ):
401
- """Check that estimator can be fitted and score can be computed."""
402
- if (not hasattr (self .estimator , 'fit' ) or
403
- not (hasattr (self .estimator , 'predict' )
404
- or hasattr (self .estimator , 'score' ))):
405
- raise TypeError ("estimator should a be an estimator implementing"
406
- " 'fit' and 'predict' or 'score' methods,"
407
- " %s (type %s) was passed" %
408
- (self .estimator , type (self .estimator )))
409
- if (self .scoring is None and self .loss_func is None and self .score_func
410
- is None ):
411
- if not hasattr (self .estimator , 'score' ):
412
- raise TypeError (
413
- "If no scoring is specified, the estimator passed "
414
- "should have a 'score' method. The estimator %s "
415
- "does not." % self .estimator )
416
-
438
+
417
439
def _fit (self , X , y , parameter_iterable ):
418
440
"""Actual fitting, performing the search over parameters."""
419
441
0 commit comments