8000 [WIP] ENH estimator freezing to stop it being cloned/refit by jnothman · Pull Request #8374 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[WIP] ENH estimator freezing to stop it being cloned/refit #8374

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
# Author: Gael Varoquaux <gael.varoquaux@normalesup.org>
# License: BSD 3 clause

import copy
from copy import deepcopy
import warnings
import functools

import numpy as np
from scipy import sparse
Expand Down Expand Up @@ -45,11 +46,13 @@ def clone(estimator, safe=True):
"""
estimator_type = type(estimator)
# XXX: not handling dictionaries
if isinstance(getattr(estimator, 'fit', None), _FrozenFit):
return estimator
if estimator_type in (list, tuple, set, frozenset):
return estimator_type([clone(e, safe=safe) for e in estimator])
elif not hasattr(estimator, 'get_params'):
if not safe:
return copy.deepcopy(estimator)
return deepcopy(estimator)
else:
raise TypeError("Cannot clone object '%s' (type %s): "
"it does not seem to be a scikit-learn estimator "
Expand Down Expand Up @@ -578,3 +581,51 @@ def is_regressor(estimator):
True if estimator is a regressor and False otherwise.
"""
return getattr(estimator, "_estimator_type", None) == "regressor"


class _FrozenFit(object):
# We use a class as this allows isinstance check in clone
def __init__(self, obj):
self.obj = obj

def __call__(self, *args, **kwargs):
return self.obj


def _frozen_fit_method(obj, method, X, *args, **kwargs):
return getattr(obj, method)(X)


def freeze(estimator, copy=False):
    """Freeze an estimator so that it cannot be refit.

    Frozen estimators:

    * have ``fit(self, *args, **kwargs)`` merely return ``self``
    * have ``fit_transform`` merely perform ``transform``
    * have ``fit_predict`` merely perform ``predict``

    ``clone`` detects a frozen estimator (via the ``_FrozenFit`` marker on
    its ``fit`` attribute) and returns it unchanged, so meta-estimators
    will neither clone nor refit it.

    Parameters
    ----------
    estimator : estimator
        A fitted estimator to freeze.
    copy : bool, optional (default=False)
        If True, freeze and return a deep copy of ``estimator``;
        otherwise ``estimator`` is mutated in place.

    Returns
    -------
    frozen_estimator : estimator
        A frozen copy of the input estimator, if ``copy``; otherwise, the
        estimator is mutated to a frozen version of itself.

    Notes
    -----
    Only works on estimators with ``__dict__``, since the frozen methods
    are installed as instance attributes shadowing the class methods.
    """
    if copy:
        estimator = deepcopy(estimator)
    # Instance attribute shadows the class's fit; _FrozenFit is a class
    # (not a closure) so that ``clone`` can detect freezing.
    estimator.fit = _FrozenFit(estimator)
    # Only override fit_transform/fit_predict where the estimator provides
    # them.  The partial looks up 'transform'/'predict' lazily at call
    # time, so an estimator with fit_transform but no transform raises
    # AttributeError when fit_transform is called, not when freeze() is.
    if hasattr(estimator, 'fit_transform'):
        estimator.fit_transform = functools.partial(_frozen_fit_method,
                                                    estimator, 'transform')
    if hasattr(estimator, 'fit_predict'):
        estimator.fit_predict = functools.partial(_frozen_fit_method,
                                                  estimator, 'predict')
    return estimator
45 changes: 43 additions & 2 deletions sklearn/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Author: Gael Varoquaux
# License: BSD 3 clause

import sys
import pickle

import numpy as np
import scipy.sparse as sp

Expand All @@ -23,12 +26,13 @@

from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeRegressor
from sklearn.feature_selection import SelectKBest
from sklearn import datasets
from sklearn.utils import deprecated

from sklearn.base import TransformerMixin
from sklearn.base import freeze
from sklearn.utils.mocking import MockDataFrame
import pickle


#############################################################################
Expand Down Expand Up @@ -370,6 +374,43 @@ def test_pickle_version_warning_is_issued_when_no_version_info_in_pickle():
tree_pickle_noversion)


def test_freeze():
    # freeze() should stop refitting while preserving fitted state,
    # identity semantics (copy=False mutates in place), and picklability.
    X, y = datasets.load_iris(return_X_y=True)

    for copy in [False, True]:
        est = SelectKBest(k=1).fit(X, y)

        frozen_est = freeze(est, copy=copy)
        if copy:
            assert_false(est is frozen_est)
        else:
            assert_true(est is frozen_est)
        assert_array_equal(est.scores_, frozen_est.scores_)
        assert_true(isinstance(frozen_est, SelectKBest))

        # Freezing must survive a pickle round-trip.
        dumped = pickle.dumps(frozen_est)
        frozen_est2 = pickle.loads(dumped)
        assert_false(frozen_est is frozen_est2)
        assert_array_equal(est.scores_, frozen_est2.scores_)

        # fit is a no-op (callable even without arguments) and the
        # fitted scores are unaffected by it.
        assert_true(frozen_est2.fit() is frozen_est2)
        assert_array_equal(est.scores_, frozen_est2.scores_)

        # Test fit_transform where expected: it is replaced on the frozen
        # estimator and no longer refits.
        assert_true(hasattr(est, 'fit_transform'))
        assert_true(hasattr(frozen_est, 'fit_transform'))
        assert_false(est.fit_transform is frozen_est.fit_transform)
        frozen_est.fit_transform([np.arange(X.shape[1])], [0])
        # scores should be unaffected by new fit_transform
        assert_array_equal(est.scores_, frozen_est.scores_)

    # Test fit_transform not set when the estimator does not provide one.
    est = DecisionTreeClassifier().fit(X, y)
    frozen_est = freeze(est)
    assert_false(hasattr(est, 'fit_transform'))
    assert_false(hasattr(frozen_est, 'fit_transform'))

def test_pickle_version_no_warning_is_issued_with_non_sklearn_estimator():
iris = datasets.load_iris()
tree = TreeNoVersion().fit(iris.data, iris.target)
Expand Down Expand Up @@ -449,4 +490,4 @@ def test_pickling_works_when_getstate_is_overwritten_in_the_child_class():
serialized = pickle.dumps(estimator)
estimator_restored = pickle.loads(serialized)
assert_equal(estimator_restored.attribute_pickled, 5)
assert_equal(estimator_restored._attribute_not_pickled, None)
assert_equal(estimator_restored._attribute_not_pickled, None)
0