8000 Make it possible to pickle a fit `GridSearchCV` and `RandomizedSearchCV` by jnothman · Pull Request #1801 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

Make it possible to pickle a fit GridSearchCV and RandomizedSearchCV #1801

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions sklearn/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,11 @@ def _check_param_grid(param_grid):
"list.")


_CVScoreTuple = namedtuple('_CVScoreTuple', ('parameters',
'mean_validation_score',
'cv_validation_scores'))


class BaseSearchCV(BaseEstimator, MetaEstimatorMixin):
"""Base class for hyper parameter search with cross-validation.
"""
Expand All @@ -340,7 +345,9 @@ def __init__(self, estimator, scoring=None, loss_func=None,
self._check_estimator()

def score(self, X, y=None):
"""Returns the mean accuracy on the given test data and labels.
"""Returns the score on the given test data and labels, if the search
estimator has been refit. The ``score`` function of the best estimator
is used, or the ``scoring`` parameter where unavailable.

Parameters
----------
Expand All @@ -364,6 +371,22 @@ def score(self, X, y=None):
y_predicted = self.predict(X)
return self.scorer(y, y_predicted)

@property
def predict(self):
return self.best_estimator_.predict

@property
def predict_proba(self):
return self.best_estimator_.predict_proba

@property
def decision_function(self):
return self.best_estimator_.decision_function

@property
def transform(self):
return self.best_estimator_.transform

def _check_estimator(self):
"""Check that estimator can be fitted and score can be computed."""
if (not hasattr(self.estimator, 'fit') or
Expand All @@ -381,13 +404,6 @@ def _check_estimator(self):
"should have a 'score' method. The estimator %s "
"does not." % self.estimator)

def _set_methods(self):
"""Create predict and predict_proba if present in best estimator."""
if hasattr(self.best_estimator_, 'predict'):
self.predict = self.best_estimator_.predict
if hasattr(self.best_estimator_, 'predict_proba'):
self.predict_proba = self.best_estimator_.predict_proba

def _fit(self, X, y, parameter_iterator, **params):
"""Actual fitting, performing the search over parameters."""
estimator = self.estimator
Expand Down Expand Up @@ -492,14 +508,10 @@ def _fit(self, X, y, parameter_iterator, **params):
else:
best_estimator.fit(X, **self.fit_params)
self.best_estimator_ = best_estimator
self._set_methods()

# Store the computed scores
CVScoreTuple = namedtuple('CVScoreTuple', ('parameters',
'mean_validation_score',
'cv_validation_scores'))
self.cv_scores_ = [
CVScoreTuple(clf_params, score, all_scores)
_CVScoreTuple(clf_params, score, all_scores)
for clf_params, (score, _), all_scores
in zip(parameter_iterator, scores, cv_scores)]
return self
Expand Down
21 changes: 20 additions & 1 deletion sklearn/tests/test_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import Iterable, Sized
from cStringIO import StringIO
from itertools import chain, product
import pickle
import sys
import warnings

Expand Down Expand Up @@ -46,6 +47,10 @@ def fit(self, X, Y):
def predict(self, T):
return T.shape[0]

predict_proba = predict
decision_function = predict
transform = predict

def score(self, X=None, Y=None):
if self.foo_param > 1:
score = 1.
Expand Down Expand Up @@ -132,8 +137,11 @@ def test_grid_search():
for i, foo_i in enumerate([1, 2, 3]):
assert_true(grid_search.cv_scores_[i][0]
== {'foo_param': foo_i})
# Smoke test the score:
# Smoke test the score etc:
grid_search.score(X, y)
grid_search.predict_proba(X)
grid_search.decision_function(X)
grid_search.transform(X)


def test_trivial_cv_scores():
Expand Down Expand Up @@ -483,3 +491,14 @@ def test_grid_search_score_consistency():
clf.decision_function(X[test]))
assert_almost_equal(correct_score, scores[i])
i += 1

def test_pickle():
"""Test that a fit search can be pickled"""
clf = MockClassifier()
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=True)
grid_search.fit(X, y)
pickle.dumps(grid_search) # smoke test

random_search = RandomizedSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=True)
random_search.fit(X, y)
pickle.dumps(random_search) # smoke test
0