@@ -340,7 +340,9 @@ def __init__(self, estimator, scoring=None, loss_func=None,
340
340
self ._check_estimator ()
341
341
342
342
def score (self , X , y = None ):
343
- """Returns the mean accuracy on the given test data and labels.
343
+ """Returns the score on the given test data and labels, if the search
344
+ estimator has been refit. The ``score`` function of the best estimator
345
+ is used, or the ``scoring`` parameter where unavailable.
344
346
345
347
Parameters
346
348
----------
@@ -364,6 +366,22 @@ def score(self, X, y=None):
364
366
y_predicted = self .predict (X )
365
367
return self .scorer (y , y_predicted )
366
368
369
+ @property
370
+ def predict (self ):
371
+ return self .best_estimator_ .predict
372
+
373
+ @property
374
+ def predict_proba (self ):
375
+ return self .best_estimator_ .predict_proba
376
+
377
+ @property
378
+ def decision_function (self ):
379
+ return self .best_estimator_ .decision_function
380
+
381
+ @property
382
+ def transform (self ):
383
+ return self .best_estimator_ .transform
384
+
367
385
def _check_estimator (self ):
368
386
"""Check that estimator can be fitted and score can be computed."""
369
387
if (not hasattr (self .estimator , 'fit' ) or
@@ -381,13 +399,6 @@ def _check_estimator(self):
381
399
"should have a 'score' method. The estimator %s "
382
400
"does not." % self .estimator )
383
401
384
- def _set_methods (self ):
385
- """Create predict and predict_proba if present in best estimator."""
386
- if hasattr (self .best_estimator_ , 'predict' ):
387
- self .predict = self .best_estimator_ .predict
388
- if hasattr (self .best_estimator_ , 'predict_proba' ):
389
- self .predict_proba = self .best_estimator_ .predict_proba
390
-
391
402
def _fit (self , X , y , parameter_iterator , ** params ):
392
403
"&q
8000
uot;"Actual fitting, performing the search over parameters."""
393
404
estimator = self .estimator
@@ -492,7 +503,6 @@ def _fit(self, X, y, parameter_iterator, **params):
492
503
else :
493
504
best_estimator .fit (X , ** self .fit_params )
494
505
self .best_estimator_ = best_estimator
495
- self ._set_methods ()
496
506
497
507
# Store the computed scores
498
508
CVScoreTuple = namedtuple ('CVScoreTuple' , ('parameters' ,
0 commit comments