Work on cross val and pipelines. by GaelVaroquaux · Pull Request #6 · scikit-learn/scikit-learn
Merged · 3 commits · Sep 16, 2010
52 changes: 48 additions & 4 deletions scikits/learn/base.py
@@ -6,24 +6,45 @@

 # License: BSD Style
 import inspect
+import copy
 
 import numpy as np
 
 from .metrics import explained_variance
 
 ################################################################################
-def clone(estimator):
+def clone(estimator, safe=True):
     """ Constructs a new estimator with the same parameters.
 
     Clone does a deep copy of the model in an estimator
     without actually copying attached data. It yields a new estimator
     with the same parameters that has not been fit on any data.
 
+    Parameters
+    ============
+    estimator: estimator object, or list, tuple or set of objects
+        The estimator or group of estimators to be cloned.
+    safe: boolean, optional
+        If safe is False, clone will fall back to a deepcopy on objects
+        that are not estimators.
+
     """
+    estimator_type = type(estimator)
+    # XXX: not handling dictionaries
+    if estimator_type in (list, tuple, set, frozenset):
+        return estimator_type([clone(e, safe=safe) for e in estimator])
+    elif not hasattr(estimator, '_get_params'):
+        if not safe:
+            return copy.deepcopy(estimator)
+        else:
+            raise ValueError("Cannot clone object '%s' (type %s): "
+                             "it does not seem to be a scikit-learn estimator "
+                             "as it does not implement a '_get_params' method."
+                             % (repr(estimator), type(estimator)))
     klass = estimator.__class__
     new_object_params = estimator._get_params(deep=False)
     for name, param in new_object_params.iteritems():
         if hasattr(param, '_get_params'):
-            new_object_params[name] = clone(param)
+            new_object_params[name] = clone(param, safe=False)
     new_object = klass(**new_object_params)
 
     return new_object
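A minimal usage sketch of the new safe semantics (the SVC import and parameter value are illustrative, assuming the scikits.learn-era paths used by the tests further down):

    from scikits.learn.base import clone
    from scikits.learn.svm import SVC

    svc = SVC(kernel='linear')
    svc2 = clone(svc)        # a new, unfitted SVC with the same parameters
    assert svc2 is not svc
    clone([svc, svc2])       # containers of estimators are cloned element-wise
    # clone(42) would raise ValueError: 42 does not implement _get_params
    clone(42, safe=False)    # non-estimators fall back to copy.deepcopy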
@@ -108,7 +129,7 @@ def _get_param_names(cls):
         args = []
         return args
 
-    def _get_params(self, deep=False):
+    def _get_params(self, deep=True):
         """ Get parameters for the estimator
 
         Parameters
@@ -220,3 +241,26 @@ def score(self, X, y):
         z : float
         """
         return explained_variance(self.predict(X), y)
+
+
+################################################################################
+# XXX: Temporary solution to figure out if an estimator is a classifier
+
+def _get_sub_estimator(estimator):
+    """ Returns the final estimator if there is any.
+    """
+    if hasattr(estimator, 'estimator'):
+        # GridSearchCV and other CV-tuned estimators
+        return _get_sub_estimator(estimator.estimator)
+    if hasattr(estimator, 'steps'):
+        # Pipeline
+        return _get_sub_estimator(estimator.steps[-1][1])
+    return estimator
+
+
+def is_classifier(estimator):
+    """ Returns True if the given estimator is (probably) a classifier.
+    """
+    estimator = _get_sub_estimator(estimator)
+    return isinstance(estimator, ClassifierMixin)
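These helpers are purely duck-typed: anything exposing an estimator attribute or a steps list is unwrapped recursively until a bare estimator remains. A small sketch with stand-in classes (names hypothetical, with base.py's ClassifierMixin in scope, only to show the recursion):

    class Wrapper(object):            # stands in for GridSearchCV-like wrappers
        def __init__(self, estimator):
            self.estimator = estimator

    class SomeClassifier(ClassifierMixin):
        pass

    assert is_classifier(Wrapper(Wrapper(SomeClassifier())))  # unwraps twice
    assert not is_classifier(Wrapper(object()))               # not a classifier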

10 changes: 5 additions & 5 deletions scikits/learn/cross_val.py
@@ -9,7 +9,7 @@
 from math import ceil
 import numpy as np
 
-from .base import ClassifierMixin
+from .base import is_classifier, clone
 from .utils.extmath import factorial, combinations
 from .externals.joblib import Parallel, delayed
 
@@ -485,9 +485,7 @@ def cross_val_score(estimator, X, y=None, score_func=None, cv=None,
"""
n_samples = len(X)
if cv is None:
if y is not None and (isinstance(estimator, ClassifierMixin)
or (hasattr(estimator, 'estimator')
and isinstance(estimator.estimator, ClassifierMixin))):
if y is not None and is_classifier(estimator):
cv = StratifiedKFold(y, k=3)
else:
cv = KFold(n_samples, k=3)
@@ -497,8 +495,10 @@
"should have a 'score' method. The estimator %s "
"does not." % estimator
)
# We clone the estimator to make sure that all the folds are
# independent, and that it is pickable.
scores = Parallel(n_jobs=n_jobs, verbose=verbose)(
delayed(_cross_val_score)(estimator, X, y, score_func,
delayed(_cross_val_score)(clone(estimator), X, y, score_func,
train, test)
for train, test in cv)
return np.array(scores)
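A usage sketch of the resulting behaviour (data and imports are illustrative, assuming the scikits.learn-era paths used in this patch):

    import numpy as np
    from scikits.learn.svm import SVC
    from scikits.learn.cross_val import cross_val_score

    X = np.random.randn(30, 2)
    y = np.array([0, 1] * 15)
    svc = SVC(kernel='linear')
    # y is given and svc is a classifier, so cv defaults to StratifiedKFold(y, k=3)
    scores = cross_val_score(svc, X, y)   # array of 3 fold scores
    # svc itself is never fit here: each fold trains an independent clone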
35 changes: 22 additions & 13 deletions scikits/learn/grid_search.py
@@ -10,7 +10,7 @@

 from .externals.joblib import Parallel, delayed
 from .cross_val import KFold, StratifiedKFold
-from .base import ClassifierMixin, clone
+from .base import BaseEstimator, is_classifier, clone
 
 try:
     from itertools import product
@@ -50,9 +50,11 @@ def iter_grid(param_grid):
     if hasattr(param_grid, 'has_key'):
         param_grid = [param_grid]
     for p in param_grid:
-        keys = p.keys()
-        for v in product(*p.values()):
-            params = dict(zip(keys,v))
+        # Always sort the keys of a dictionary, for reproducibility
+        items = sorted(p.items())
+        keys, values = zip(*items)
+        for v in product(*values):
+            params = dict(zip(keys, v))
             yield params
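A quick sketch of the deterministic ordering this buys (parameter grid illustrative):

    grid = {'kernel': ['linear', 'rbf'], 'C': [1, 10]}
    list(iter_grid(grid))
    # Keys are expanded in sorted order ('C' before 'kernel'), so this is always:
    # [{'C': 1, 'kernel': 'linear'}, {'C': 1, 'kernel': 'rbf'},
    #  {'C': 10, 'kernel': 'linear'}, {'C': 10, 'kernel': 'rbf'}]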


@@ -65,7 +67,8 @@ def fit_grid_point(X, y, base_clf, clf_params, cv, loss_func, iid,
     clf = copy.deepcopy(base_clf)
     clf._set_params(**clf_params)
 
-    score = 0
+    score = 0.
+    n_test_samples = 0.
     for train, test in cv:
         clf.fit(X[train], y[train], **fit_params)
         y_test = y[test]
@@ -76,13 +79,16 @@
             this_score = clf.score(X[test], y_test)
         if iid:
             this_score *= len(y_test)
+            n_test_samples += len(y_test)
         score += this_score
+    if iid:
+        score /= n_test_samples
 
-    return clf, score
+    return score, clf
 
 
 ################################################################################
-class GridSearchCV(object):
+class GridSearchCV(BaseEstimator):
     """
     Grid search on the parameters of a classifier.
 
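In other words, with iid=True the fold scores are now combined as a test-size-weighted mean, score = sum_k(score_k * n_k) / sum_k(n_k), instead of a plain sum over folds. A tiny numeric check (made-up fold scores):

    fold_scores = [0.9, 0.8]   # per-fold accuracy
    fold_sizes = [10, 20]      # test samples per fold
    weighted = sum(s * n for s, n in zip(fold_scores, fold_sizes))
    weighted /= float(sum(fold_sizes))   # 0.833..., vs. the plain mean 0.85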
@@ -181,9 +187,7 @@ def fit(self, X, y, cv=None, **kw):
         estimator = self.estimator
         if cv is None:
             n_samples = len(X)
-            if y is not None and (isinstance(estimator, ClassifierMixin)
-                    or (hasattr(estimator, 'estimator')
-                    and isinstance(estimator.estimator, ClassifierMixin))):
+            if y is not None and is_classifier(estimator):
                 cv = StratifiedKFold(y, k=3)
             else:
                 cv = KFold(n_samples, k=3)
@@ -195,12 +199,17 @@
                 cv, self.loss_func, self.iid, **self.fit_params)
                 for clf_params in grid)
 
-        # Out is a list of pairs: estimator, score
-        key = lambda pair: pair[1]
-        best_estimator = max(out, key=key)[0]  # get maximum score
+        # Out is a list of pairs: score, estimator
+        best_estimator = max(out)[1]  # get maximum score
 
         self.best_estimator = best_estimator
         self.predict = best_estimator.predict
         self.score = best_estimator.score
+
+        # Store the computed scores
+        grid = iter_grid(self.param_grid)
+        self.grid_points_scores_ = dict((tuple(clf_params.items()), score)
+                                        for clf_params, (score, _) in zip(grid, out))
 
         return self
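A usage sketch of the fitted attributes (grid values illustrative, matching the new test in test_base.py; X and y as in the cross_val_score sketch above):

    from scikits.learn.grid_search import GridSearchCV
    from scikits.learn.svm import SVC

    search = GridSearchCV(SVC(), {'C': [0.1, 1, 10]})
    search.fit(X, y)
    search.predict(X)            # delegated to search.best_estimator
    search.grid_points_scores_  # e.g. {(('C', 0.1),): 0.7, (('C', 1),): 0.9, ...}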

2 changes: 1 addition & 1 deletion scikits/learn/pipeline.py
@@ -101,7 +101,7 @@ def __init__(self, steps):
"'%s' (type %s) doesn't)" % (estimator, type(estimator))
)

def _get_params(self, deep=False):
def _get_params(self, deep=True):
if not deep:
return super(Pipeline, self)._get_params(deep=False)
else:
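With deep=True now the default, a pipeline reports its nested parameters under the double-underscore convention exercised in test_base.py ('a__d'). A minimal sketch, reusing the one-step pipeline from the new test_is_classifier test:

    from scikits.learn.svm import SVC
    from scikits.learn.pipeline import Pipeline

    pipe = Pipeline([('svc', SVC())])
    pipe._get_params()
    # now includes both the step itself under 'svc' and its nested
    # parameters under 'svc__'-prefixed keys, e.g. 'svc__C'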
19 changes: 17 additions & 2 deletions scikits/learn/tests/test_base.py
@@ -1,6 +1,10 @@

+# Author: Gael Varoquaux
+# License: BSD
+
 from nose.tools import assert_true, assert_false, assert_equal, \
                        assert_raises
-from ..base import BaseEstimator, clone
+from ..base import BaseEstimator, clone, is_classifier
 
 ################################################################################
 # A few test classes
@@ -74,7 +78,6 @@ def test_str():
 
 
 def test_get_params():
-
     test = T(K(), K())
 
     assert_true('a__d' in test._get_params(deep=True))
@@ -84,3 +87,15 @@ def test_get_params():
     assert test.a.d == 2
     assert_raises(AssertionError, test._set_params, a__a=2)
 
+
+def test_is_classifier():
+    from ..svm import SVC
+    from ..pipeline import Pipeline
+    from ..grid_search import GridSearchCV
+    svc = SVC()
+    assert_true(is_classifier(svc))
+    assert_true(is_classifier(GridSearchCV(svc, {'C': [0.1, 1]})))
+    assert_true(is_classifier(Pipeline([('svc', svc)])))
+    assert_true(is_classifier(Pipeline([('svc_cv',
+                    GridSearchCV(svc, {'C': [0.1, 1]}))])))
+
14 changes: 12 additions & 2 deletions scikits/learn/tests/test_pipeline.py
@@ -2,7 +2,7 @@
 Test the pipeline module.
 """
 
-from nose.tools import assert_raises, assert_equal
+from nose.tools import assert_raises, assert_equal, assert_false
 
 from ..base import BaseEstimator, clone
 from ..pipeline import Pipeline
@@ -56,4 +56,14 @@ def test_pipeline_init():
 
     # Test clone
     pipe2 = clone(pipe)
-    assert_equal(pipe._get_params(), pipe2._get_params())
+    assert_false(pipe._named_steps['svc'] is pipe2._named_steps['svc'])
+
+    # Check that, apart from the estimators, the parameters are the same
+    params = pipe._get_params()
+    params2 = pipe2._get_params()
+    # Remove the estimators that were copied
+    params.pop('svc')
+    params.pop('anova')
+    params2.pop('svc')
+    params2.pop('anova')
+    assert_equal(params, params2)