[WIP] ENH estimator freezing to stop it being cloned/refit #8374
Changes from all commits: efa9a86, 4d06fb0, ecefd05, bcd4eae, 3398026, 705416b, 347c8ef
sklearn/base.py

```diff
@@ -3,8 +3,9 @@
 # Author: Gael Varoquaux <gael.varoquaux@normalesup.org>
 # License: BSD 3 clause

-import copy
+from copy import deepcopy
 import warnings
+import functools

 import numpy as np
 from scipy import sparse
@@ -45,11 +46,13 @@ def clone(estimator, safe=True):
     """
     estimator_type = type(estimator)
     # XXX: not handling dictionaries
+    if isinstance(getattr(estimator, 'fit', None), _FrozenFit):
+        return estimator
     if estimator_type in (list, tuple, set, frozenset):
         return estimator_type([clone(e, safe=safe) for e in estimator])
     elif not hasattr(estimator, 'get_params'):
         if not safe:
-            return copy.deepcopy(estimator)
+            return deepcopy(estimator)
         else:
             raise TypeError("Cannot clone object '%s' (type %s): "
                             "it does not seem to be a scikit-learn estimator "
@@ -578,3 +581,51 @@ def is_regressor(estimator):
         True if estimator is a regressor and False otherwise.
     """
     return getattr(estimator, "_estimator_type", None) == "regressor"
+
+
+class _FrozenFit(object):
+    # We use a class as this allows isinstance check in clone
+    def __init__(self, obj):
+        self.obj = obj
+
+    def __call__(self, *args, **kwargs):
+        return self.obj
+
+
+def _frozen_fit_method(obj, method, X, *args, **kwargs):
+    return getattr(obj, method)(X)
+
+
+def freeze(estimator, copy=False):
+    """Copies estimator and freezes it
+
+    Frozen estimators:
+        * have ``fit(self, *args, **kwargs)`` merely return ``self``
+        * have ``fit_transform`` merely perform ``transform``
+        * have ``fit_predict`` merely perform ``predict``
+
+    Parameters
+    ----------
+    estimator : estimator
+    copy : bool
+
+    Returns
+    -------
+    frozen_estimator : estimator
+        A frozen copy of the input estimator, if ``copy``; otherwise, the
+        estimator is mutated to a frozen version of itself.
+
+    Notes
+    -----
+    Only works on estimators with ``__dict__``.
+    """
+    if copy:
+        estimator = deepcopy(estimator)
+    estimator.fit = _FrozenFit(estimator)
+    if hasattr(estimator, 'fit_transform'):
+        estimator.fit_transform = functools.partial(_frozen_fit_method,
+                                                    estimator, 'transform')
+    if hasattr(estimator, 'fit_predict'):
+        estimator.fit_predict = functools.partial(_frozen_fit_method,
+                                                  estimator, 'predict')
+    return estimator
```
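
To make the intended behaviour concrete, here is a minimal usage sketch (not part of the diff) of the `freeze` API proposed above, mirroring the `SelectKBest` use in the tests. Note that `freeze` exists only on this PR's branch, not in released scikit-learn, so the import below assumes that branch is installed.

```python
import numpy as np

from sklearn.base import clone, freeze  # freeze is only available on this PR's branch
from sklearn.datasets import load_iris
from sklearn.feature_selection import SelectKBest

X, y = load_iris(return_X_y=True)

# Fit once, then freeze so that meta-estimators cannot refit the selector.
selector = freeze(SelectKBest(k=2).fit(X, y))

# clone() now short-circuits and returns the frozen estimator itself.
assert clone(selector) is selector

# fit() becomes a no-op that returns the estimator; learned state is untouched.
scores_before = selector.scores_.copy()
selector.fit(X[:10], y[:10])
np.testing.assert_array_equal(selector.scores_, scores_before)

# fit_transform() merely applies the already-learned transform.
X_reduced = selector.fit_transform(X, y)
print(X_reduced.shape)  # (150, 2)
```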
sklearn/tests/test_base.py

```diff
@@ -1,6 +1,9 @@
 # Author: Gael Varoquaux
 # License: BSD 3 clause

+import sys
+import pickle
+
 import numpy as np
 import scipy.sparse as sp

@@ -23,12 +26,13 @@

 from sklearn.tree import DecisionTreeClassifier
 from sklearn.tree import DecisionTreeRegressor
+from sklearn.feature_selection import SelectKBest
 from sklearn import datasets
 from sklearn.utils import deprecated

 from sklearn.base import TransformerMixin
+from sklearn.base import freeze
 from sklearn.utils.mocking import MockDataFrame
-import pickle


 #############################################################################
@@ -370,6 +374,43 @@ def test_pickle_version_warning_is_issued_when_no_version_info_in_pickle():
                          tree_pickle_noversion)


+def test_freeze():
+    X, y = datasets.load_iris(return_X_y=True)
+
+    for copy in [False, True]:
+        est = SelectKBest(k=1).fit(X, y)
+
+        frozen_est = freeze(est, copy=copy)
+        if copy:
+            assert_false(est is frozen_est)
+        else:
+            assert_true(est is frozen_est)
+        assert_array_equal(est.scores_, frozen_est.scores_)
+        assert_true(isinstance(frozen_est, SelectKBest))
+
+        dumped = pickle.dumps(frozen_est)
+        frozen_est2 = pickle.loads(dumped)
+        assert_false(frozen_est is frozen_est2)
+        assert_array_equal(est.scores_, frozen_est2.scores_)
+
+        # scores should be unaffected by new fit
+        assert_true(frozen_est2.fit() is frozen_est2)
+        assert_array_equal(est.scores_, frozen_est2.scores_)
+
+        # Test fit_transform where expected
+        assert_true(hasattr(est, 'fit_transform'))
+        assert_true(hasattr(frozen_est, 'fit_transform'))
+        assert_false(est.fit_transform is frozen_est.fit_transform)
+        frozen_est.fit_transform([np.arange(X.shape[1])], [0])
+        # scores should be unaffected by new fit_transform
+        assert_array_equal(est.scores_, frozen_est.scores_)
+
+    # Test fit_transform not set when not needed
+    est = DecisionTreeClassifier().fit(X, y)
+    frozen_est = freeze(est)
+    assert_false(hasattr(est, 'fit_transform'))
+    assert_false(hasattr(frozen_est, 'fit_transform'))
+
+
 def test_pickle_version_no_warning_is_issued_with_non_sklearn_estimator():
     iris = datasets.load_iris()
     tree = TreeNoVersion().fit(iris.data, iris.target)
```

Review comment on `assert_true(frozen_est2.fit() is frozen_est2)`:

> This is always true, right? Well, I guess you're testing that fit can be called without arguments?

```diff
@@ -449,4 +490,4 @@ def test_pickling_works_when_getstate_is_overwritten_in_the_child_class():
     serialized = pickle.dumps(estimator)
     estimator_restored = pickle.loads(serialized)
     assert_equal(estimator_restored.attribute_pickled, 5)
-    assert_equal(estimator_restored._attribute_not_pickled, None)
+    assert_equal(estimator_restored._attribute_not_pickled, None)
```
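
A side note on the pickling round-trip exercised in `test_freeze` (an observation about the implementation, not something stated in the PR): the frozen methods survive pickling because `_FrozenFit` is an ordinary class instance and `functools.partial` wraps the module-level `_frozen_fit_method`; under Python 3 both are picklable, whereas a lambda or locally defined closure in their place would not be. A minimal sketch of that difference, using a hypothetical module-level helper:

```python
import functools
import pickle


def _identity(value):
    # Stand-in for a module-level helper such as _frozen_fit_method.
    return value


# A partial over a module-level function round-trips through pickle.
restored = pickle.loads(pickle.dumps(functools.partial(_identity, 42)))
assert restored() == 42

# An equivalent lambda cannot be pickled by the standard pickle module.
try:
    pickle.dumps(lambda: 42)
except (pickle.PicklingError, AttributeError, TypeError) as exc:
    print("lambda not picklable:", exc)
```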
Review discussion:

> I would just remove `fit_transform` and `fit_predict` as well. Downstream users should be able to duck-type around using these.

> Or did you consider them required API? The transformer and cluster base classes have those, but they are not really part of the API contract imho.
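
If `fit_transform` and `fit_predict` were dropped from frozen estimators as suggested above, downstream code could duck-type around their absence along these lines (a hypothetical sketch, not something proposed in the PR; the helper name is illustrative):

```python
def fit_transform_or_fallback(estimator, X, y=None, **fit_params):
    """Prefer fit_transform when present, otherwise fall back to fit().transform()."""
    if hasattr(estimator, 'fit_transform'):
        return estimator.fit_transform(X, y, **fit_params)
    return estimator.fit(X, y, **fit_params).transform(X)
```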