8000 ENH/FIX make best_estimator_'s predict functions available in paramet… · scikit-learn/scikit-learn@83f67ee · GitHub
[go: up one dir, main page]

Skip to content

Commit 83f67ee

Browse files
jnothmanamueller
authored andcommitted
ENH/FIX make best_estimator_'s predict functions available in parameter search
Avoids copying an unpicklable method to parameter search's __dict__. Also adds decision_function and transform where only predict and predict_proba were available before.
1 parent 46292a1 commit 83f67ee

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

sklearn/grid_search.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,9 @@ def __init__(self, estimator, scoring=None, loss_func=None,
340340
self._check_estimator()
341341

342342
def score(self, X, y=None):
343-
"""Returns the mean accuracy on the given test data and labels.
343+
"""Returns the score on the given test data and labels, if the search
344+
estimator has been refit. The ``score`` function of the best estimator
345+
is used, or the ``scoring`` parameter where unavailable.
344346
345347
Parameters
346348
----------
@@ -364,6 +366,22 @@ def score(self, X, y=None):
364366
y_predicted = self.predict(X)
365367
return self.scorer(y, y_predicted)
366368

369+
@property
370+
def predict(self):
371+
return self.best_estimator_.predict
372+
373+
@property
374+
def predict_proba(self):
375+
return self.best_estimator_.predict_proba
376+
377+
@property
378+
def decision_function(self):
379+
return self.best_estimator_.decision_function
380+
381+
@property
382+
def transform(self):
383+
return self.best_estimator_.transform
384+
367385
def _check_estimator(self):
368386
"""Check that estimator can be fitted and score can be computed."""
369387
if (not hasattr(self.estimator, 'fit') or
@@ -381,13 +399,6 @@ def _check_estimator(self):
381399
"should have a 'score' method. The estimator %s "
382400
"does not." % self.estimator)
383401

384-
def _set_methods(self):
385-
"""Create predict and predict_proba if present in best estimator."""
386-
if hasattr(self.best_estimator_, 'predict'):
387-
self.predict = self.best_estimator_.predict
388-
if hasattr(self.best_estimator_, 'predict_proba'):
389-
self.predict_proba = self.best_estimator_.predict_proba
390-
391402
def _fit(self, X, y, parameter_iterator, **params):
392403
"&q 8000 uot;"Actual fitting, performing the search over parameters."""
393404
estimator = self.estimator
@@ -492,7 +503,6 @@ def _fit(self, X, y, parameter_iterator, **params):
492503
else:
493504
best_estimator.fit(X, **self.fit_params)
494505
self.best_estimator_ = best_estimator
495-
self._set_methods()
496506

497507
# Store the computed scores
498508
CVScoreTuple = namedtuple('CVScoreTuple', ('parameters',

sklearn/tests/test_grid_search.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def fit(self, X, Y):
4747
def predict(self, T):
4848
return T.shape[0]
4949

50+
predict_proba = predict
51+
decision_function = predict
52+
transform = predict
53+
5054
def score(self, X=None, Y=None):
5155
if self.foo_param > 1:
5256
score = 1.
@@ -133,8 +137,11 @@ def test_grid_search():
133137
for i, foo_i in enumerate([1, 2, 3]):
134138
assert_true(grid_search.cv_scores_[i][0]
135139
== {'foo_param': foo_i})
136-
# Smoke test the score:
140+
# Smoke test the score etc:
137141
grid_search.score(X, y)
142+
grid_search.predict_proba(X)
143+
grid_search.decision_function(X)
144+
grid_search.transform(X)
138145

139146

140147
def test_trivial_cv_scores():

0 commit comments

Comments
 (0)
0