From efa9a860a15646cfb9d1014dc817599724c26535 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Fri, 17 Feb 2017 00:32:20 +1100 Subject: [PATCH 1/6] ENH add freeze method which stops an estimator being cloned/refit --- sklearn/base.py | 35 +++++++++++++++++++++++++++++++++++ sklearn/tests/test_base.py | 22 +++++++++++++++++++++- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/sklearn/base.py b/sklearn/base.py index 1b79841746677..6e94ad45a2a4f 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -5,6 +5,7 @@ import copy import warnings +import functools import numpy as np from scipy import sparse @@ -45,6 +46,8 @@ def clone(estimator, safe=True): """ estimator_type = type(estimator) # XXX: not handling dictionaries + if isinstance(getattr(estimator, 'fit'), _FrozenFit): + return estimator if estimator_type in (list, tuple, set, frozenset): return estimator_type([clone(e, safe=safe) for e in estimator]) elif not hasattr(estimator, 'get_params'): @@ -523,3 +526,35 @@ def is_classifier(estimator): def is_regressor(estimator): """Returns True if the given estimator is (probably) a regressor.""" return getattr(estimator, "_estimator_type", None) == "regressor" + + +class _FrozenFit(object): + # We use a class as this allows isinstance check in clone + def __init__(self, obj): + self.obj = obj + + def __call__(self, *args, **kwargs): + return self.obj + + +def _frozen_fit_method(obj, method, X, *args, **kwargs): + return getattr(obj, method)(args[0]) + + +def freeze(estimator): + """Copies estimator and freezes it + + Frozen estimators: + * have ``fit(self, *args, **kwargs)`` merely return ``self`` + * have ``fit_transform`` merely perform ``transform`` + + Parameters + ---------- + estimator : estimator + """ + estimator = copy.deepcopy(estimator) + estimator.fit = _FrozenFit(estimator) + if hasattr(estimator, 'fit_transform'): + estimator.fit_transform = functools.partial(_frozen_fit_method, + 'transform') + return estimator diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index 9983dbdc486bd..a40c9737451f6 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -2,6 +2,7 @@ # License: BSD 3 clause import sys +import pickle import numpy as np import scipy.sparse as sp @@ -23,12 +24,13 @@ from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeRegressor +from sklearn.linear_model import LogisticRegression from sklearn import datasets from sklearn.utils import deprecated from sklearn.base import TransformerMixin +from sklearn.base import freeze from sklearn.utils.mocking import MockDataFrame -import pickle ############################################################################# @@ -359,3 +361,21 @@ def test_pickle_version_warning(): # check that no warning is raised for external estimators TreeNoVersion.__module__ = "notsklearn" assert_no_warnings(pickle.loads, tree_pickle_noversion) + + +def test_freeze(): + X, y = datasets.load_iris(return_X_y=True) + est = LogisticRegression().fit(X, y) + frozen_est = freeze(est) + + assert_false(est is frozen_est) + assert_array_equal(est.coef_, frozen_est.coef_) + assert_true(isinstance(frozen_est, LogisticRegression)) + + dumped = pickle.dumps(frozen_est) + frozen_est2 = pickle.loads(dumped) + assert_false(frozen_est is frozen_est2) + assert_array_equal(est.coef_, frozen_est2.coef_) + + assert_true(frozen_est2.fit() is frozen_est2) + assert_array_equal(est.coef_, frozen_est2.coef_) From 4d06fb0ced424842df9da03d236af2b546534486 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Fri, 17 Feb 2017 01:55:31 +1100 Subject: [PATCH 2/6] ENH/DOC copy param and caveats --- sklearn/base.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 6e94ad45a2a4f..3cb01d7d41513 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -541,7 +541,7 @@ def _frozen_fit_method(obj, method, X, *args, **kwargs): return getattr(obj, method)(args[0]) -def freeze(estimator): +def freeze(estimator, copy=False): """Copies estimator and freezes it Frozen estimators: @@ -551,8 +551,20 @@ def freeze(estimator): Parameters ---------- estimator : estimator + copy : bool + + Returns + ------- + frozen_estimator : estimator + A frozen copy of the input estimator, if ``copy``; otherwise, the + estimator is mutated to a frozen version of itself. + + Notes + ----- + Only works on estimators with ``__dict__``. """ - estimator = copy.deepcopy(estimator) + if copy: + estimator = copy.deepcopy(estimator) estimator.fit = _FrozenFit(estimator) if hasattr(estimator, 'fit_transform'): estimator.fit_transform = functools.partial(_frozen_fit_method, From ecefd05345d6ea735ae40f7dcbc665b97fc9b935 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Fri, 17 Feb 2017 09:01:59 +1100 Subject: [PATCH 3/6] FIX case where non-estimator in clone --- sklearn/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/base.py b/sklearn/base.py index 3cb01d7d41513..fbde35dde555b 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -46,7 +46,7 @@ def clone(estimator, safe=True): """ estimator_type = type(estimator) # XXX: not handling dictionaries - if isinstance(getattr(estimator, 'fit'), _FrozenFit): + if isinstance(getattr(estimator, 'fit', None), _FrozenFit): return estimator if estimator_type in (list, tuple, set, frozenset): return estimator_type([clone(e, safe=safe) for e in estimator]) From bcd4eae91b7509272f98e1d9a20f351111269677 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Fri, 17 Feb 2017 10:36:01 +1100 Subject: [PATCH 4/6] TST/FIX copy param --- sklearn/tests/test_base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index a40c9737451f6..3c7af14231127 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -366,12 +366,17 @@ def test_pickle_version_warning(): def test_freeze(): X, y = datasets.load_iris(return_X_y=True) est = LogisticRegression().fit(X, y) - frozen_est = freeze(est) + frozen_est = freeze(est, copy=True) assert_false(est is frozen_est) assert_array_equal(est.coef_, frozen_est.coef_) assert_true(isinstance(frozen_est, LogisticRegression)) + frozen_est = freeze(est) + assert_true(est is frozen_est) + assert_array_equal(est.coef_, frozen_est.coef_) + assert_true(isinstance(frozen_est, LogisticRegression)) + dumped = pickle.dumps(frozen_est) frozen_est2 = pickle.loads(dumped) assert_false(frozen_est is frozen_est2) From 3398026844b38e8b9b845b45dd5938de5a5d1ada Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Fri, 17 Feb 2017 10:53:33 +1100 Subject: [PATCH 5/6] FIX Avoid naming conflicts --- sklearn/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index fbde35dde555b..9d27c7505b0e4 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -3,7 +3,7 @@ # Author: Gael Varoquaux # License: BSD 3 clause -import copy +from copy import deepcopy import warnings import functools @@ -52,7 +52,7 @@ def clone(estimator, safe=True): return estimator_type([clone(e, safe=safe) for e in estimator]) elif not hasattr(estimator, 'get_params'): if not safe: - return copy.deepcopy(estimator) + return deepcopy(estimator) else: raise TypeError("Cannot clone object '%s' (type %s): " "it does not seem to be a scikit-learn estimator " @@ -564,7 +564,7 @@ def freeze(estimator, copy=False): Only works on estimators with ``__dict__``. """ if copy: - estimator = copy.deepcopy(estimator) + estimator = deepcopy(estimator) estimator.fit = _FrozenFit(estimator) if hasattr(estimator, 'fit_transform'): estimator.fit_transform = functools.partial(_frozen_fit_method, From 705416b006f69d76460f4bbda4c13ea125d89856 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Fri, 17 Feb 2017 11:40:35 +1100 Subject: [PATCH 6/6] TST/FIX fit_transform in freeze --- sklearn/base.py | 8 ++++-- sklearn/tests/test_base.py | 51 ++++++++++++++++++++++++-------------- 2 files changed, 39 insertions(+), 20 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 9d27c7505b0e4..a0c6db1e7b0d3 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -538,7 +538,7 @@ def __call__(self, *args, **kwargs): def _frozen_fit_method(obj, method, X, *args, **kwargs): - return getattr(obj, method)(args[0]) + return getattr(obj, method)(X) def freeze(estimator, copy=False): @@ -547,6 +547,7 @@ def freeze(estimator, copy=False): Frozen estimators: * have ``fit(self, *args, **kwargs)`` merely return ``self`` * have ``fit_transform`` merely perform ``transform`` + * have ``fit_predict`` merely perform ``predict`` Parameters ---------- @@ -568,5 +569,8 @@ def freeze(estimator, copy=False): estimator.fit = _FrozenFit(estimator) if hasattr(estimator, 'fit_transform'): estimator.fit_transform = functools.partial(_frozen_fit_method, - 'transform') + estimator, 'transform') + if hasattr(estimator, 'fit_predict'): + estimator.fit_predict = functools.partial(_frozen_fit_method, + estimator, 'predict') return estimator diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index 3c7af14231127..4dea54f385740 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -24,7 +24,7 @@ from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeRegressor -from sklearn.linear_model import LogisticRegression +from sklearn.feature_selection import SelectKBest from sklearn import datasets from sklearn.utils import deprecated @@ -365,22 +365,37 @@ def test_pickle_version_warning(): def test_freeze(): X, y = datasets.load_iris(return_X_y=True) - est = LogisticRegression().fit(X, y) - - frozen_est = freeze(est, copy=True) - assert_false(est is frozen_est) - assert_array_equal(est.coef_, frozen_est.coef_) - assert_true(isinstance(frozen_est, LogisticRegression)) + for copy in [False, True]: + est = SelectKBest(k=1).fit(X, y) + + frozen_est = freeze(est, copy=copy) + if copy: + assert_false(est is frozen_est) + else: + assert_true(est is frozen_est) + assert_array_equal(est.scores_, frozen_est.scores_) + assert_true(isinstance(frozen_est, SelectKBest)) + + dumped = pickle.dumps(frozen_est) + frozen_est2 = pickle.loads(dumped) + assert_false(frozen_est is frozen_est2) + assert_array_equal(est.scores_, frozen_est2.scores_) + + # scores should be unaffected by new fit + assert_true(frozen_est2.fit() is frozen_est2) + assert_array_equal(est.scores_, frozen_est2.scores_) + + # Test fit_transform where expected + assert_true(hasattr(est, 'fit_transform')) + assert_true(hasattr(frozen_est, 'fit_transform')) + assert_false(est.fit_transform is frozen_est.fit_transform) + frozen_est.fit_transform([np.arange(X.shape[1])], [0]) + # scores should be unaffected by new fit_transform + assert_array_equal(est.scores_, frozen_est.scores_) + + # Test fit_transform not set when not needed + est = DecisionTreeClassifier().fit(X, y) frozen_est = freeze(est) - assert_true(est is frozen_est) - assert_array_equal(est.coef_, frozen_est.coef_) - assert_true(isinstance(frozen_est, LogisticRegression)) - - dumped = pickle.dumps(frozen_est) - frozen_est2 = pickle.loads(dumped) - assert_false(frozen_est is frozen_est2) - assert_array_equal(est.coef_, frozen_est2.coef_) - - assert_true(frozen_est2.fit() is frozen_est2) - assert_array_equal(est.coef_, frozen_est2.coef_) + assert_false(hasattr(est, 'fit_transform')) + assert_false(hasattr(frozen_est, 'fit_transform'))