8000 [MRG] ENH avoid deepcopy when a parameter is declared immutable by glemaitre · Pull Request #16185 · scikit-learn/scikit-learn · GitHub
Closed
wants to merge 17 commits into from
46 changes: 29 additions & 17 deletions doc/developers/develop.rst
@@ -5,9 +5,9 @@ Developing scikit-learn estimators
==================================

Whether you are proposing an estimator for inclusion in scikit-learn,
developing a separate package compatible with scikit-learn, or
implementing custom components for your own projects, this chapter
details how to develop objects that safely interact with scikit-learn
Pipelines and model selection tools.

.. currentmodule:: sklearn
@@ -266,7 +266,7 @@ the name of the underlying estimator and check in the test name. This allows
`pytest -k` to be used to specify which tests to run.

.. code-block:: bash

pytest test_check_estimators.py -k check_estimators_fit_returns_self

Before detailing the required interface below, we describe two ways to achieve
@@ -388,16 +388,23 @@ trailing ``_`` is used to check if the estimator has been fitted.

Cloning
-------
For use with the :mod:`model_selection` module, an estimator must support the
:func:`base.clone` function to replicate an estimator. This can be done by
providing a ``get_params`` method. If ``get_params`` is present, then
``clone(estimator)`` will be an instance of ``type(estimator)`` on which
``set_params`` has been called with clones of the result of
``estimator.get_params()``.

Objects that do not provide this method will either:

* be deep-copied (using the Python standard function ``copy.deepcopy``) if
``safe=False`` is passed to :func:`base.clone`.

* be reassigned if the object is an estimator parameter declared to be
immutable using the estimator tag `immutable_params` (refer to
:ref:`estimator_tags` below). This behavior is useful when a parameter
is known to be constant (i.e. not changed during `fit`) and a deep copy
would be memory- or computationally expensive.
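The reassign-vs-deep-copy rule above can be sketched in a few lines. Everything below is illustrative (the `clone_params` helper and the `Vectorizer` class are hypothetical, not part of scikit-learn); only the `immutable_params` tag name comes from this proposal:

```python
import copy

_DEFAULT_TAGS = {'immutable_params': None}


def clone_params(estimator):
    """Sketch of the parameter-copy step of ``clone``: parameters listed
    in the 'immutable_params' tag are reassigned, all others deep-copied."""
    tags = getattr(estimator, "_get_tags", lambda: _DEFAULT_TAGS)()
    immutable = tags.get('immutable_params') or []
    return {
        name: value if name in immutable else copy.deepcopy(value)
        for name, value in estimator.get_params().items()
    }


class Vectorizer:
    """Hypothetical estimator declaring ``vocabulary`` immutable."""

    def __init__(self, vocabulary=None, stop_words=None):
        self.vocabulary = vocabulary
        self.stop_words = stop_words

    def get_params(self):
        return {'vocabulary': self.vocabulary, 'stop_words': self.stop_words}

    def _get_tags(self):
        tags = dict(_DEFAULT_TAGS)
        tags.update({'immutable_params': ['vocabulary']})
        return tags


est = Vectorizer(vocabulary={'g': 0, 'a': 1}, stop_words=['and'])
new_params = clone_params(est)
print(new_params['vocabulary'] is est.vocabulary)  # True: reassigned
print(new_params['stop_words'] is est.stop_words)  # False: deep-copied
```

Reassignment leaves the clone and the original pointing at the same `vocabulary` object, which is safe only as long as `fit` never mutates that parameter.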

Pipeline compatibility
----------------------
@@ -488,6 +495,11 @@ binary_only (default=``False``)
whether estimator supports binary classification but lacks multi-class
classification support.

immutable_params (default=None)
list of estimator parameters that are declared to be immutable. These
parameters will not be deep-copied when calling :func:`base.clone`; they
will only be reassigned.

multilabel (default=``False``)
whether the estimator supports multilabel output

@@ -567,10 +579,10 @@ closed-form solutions.
Coding guidelines
=================

The following are some guidelines on how new code should be written for
inclusion in scikit-learn, and which may be appropriate to adopt in external
projects. Of course, there are special cases and there will be exceptions to
these rules. However, following these rules when submitting new code makes
the review easier so new code can be integrated in less time.

Uniformly formatted code makes it easier to share code ownership. The
123 changes: 102 additions & 21 deletions sklearn/base.py
@@ -33,26 +33,97 @@
'_skip_test': False,
'multioutput_only': False,
'binary_only': False,
'requires_fit': True,
'immutable_params': None,
}


def clone(estimator, safe=True, deepcopy=True):
"""Constructs a new estimator with the same parameters.

Clone does a deep copy of the model in an estimator without actually
copying attached data. It yields a new estimator with the same parameters
that has not been fit on any data.

If an estimator declares some parameters to be immutable (via the
`'immutable_params'` tag), these parameters will not be deep copied but
only reassigned.

Parameters
----------
estimator : {list, tuple, set, frozenset} of estimator object or \
estimator object
The estimator or group of estimators to be cloned.

safe : bool, default=True
If `safe` is `False`, objects that are not estimators will either be
deep copied if `deepcopy=True` or passed by assignment if
`deepcopy=False`.

deepcopy : bool, default=True
Whether or not to trigger a deep copy of objects that are not
estimators.

Returns
-------
new_object : estimator object
A new unfitted estimator.

Examples
--------
>>> from sklearn.base import clone
>>> from sklearn.exceptions import NotFittedError
>>> from sklearn.utils.validation import check_is_fitted
>>> from sklearn.datasets import make_classification
>>> from sklearn.linear_model import LogisticRegression
>>> X, y = make_classification(random_state=42)

We create a logistic regression and train it on some data. Once trained,
we can check that the classifier is fitted and that it has a `coef_`
attribute.

>>> clf = LogisticRegression().fit(X, y)
>>> check_is_fitted(clf)
>>> hasattr(clf, "coef_")
True

Now, we create a clone of this classifier. The role of :func:`~clone` is
to make a copy of `clf`, removing all attributes that were created
during `fit`.

>>> clf_clone = clone(clf)

In addition, the parameters of `clf_clone` will be copied from the
original `clf`.

>>> clf_clone is clf
False
>>> try:
... check_is_fitted(clf_clone)
... except NotFittedError as e:
... print(e)
This LogisticRegression instance is not fitted yet. Call 'fit' with
appropriate arguments before using this estimator.

To avoid copying parameters that are declared immutable, you can use the
estimator tag `immutable_params`. For instance,
:class:`sklearn.feature_extraction.text.TfidfVectorizer` declares the
parameter `vocabulary` to be immutable. Thus, this parameter will not be
copied when calling :func:`~clone`:

>>> from sklearn.feature_extraction.text import TfidfVectorizer
>>> vectorizer = TfidfVectorizer(
... vocabulary={'g': 0, 'a': 1, 't': 2, 'c': 3},
... stop_words=["and", "maybe"]
... )
>>> vectorizer_clone = clone(vectorizer)
>>> vectorizer_clone.vocabulary is vectorizer.vocabulary
True
>>> vectorizer_clone.stop_words is vectorizer.stop_words
False
"""
if not safe and not deepcopy:
return estimator
estimator_type = type(estimator)
# XXX: not handling dictionaries
if estimator_type in (list, tuple, set, frozenset):
@@ -62,20 +133,29 @@ def clone(estimator, safe=True):
return copy.deepcopy(estimator)
else:
if isinstance(estimator, type):
raise TypeError("Cannot clone object. " +
"You should provide an instance of " +
"scikit-learn estimator instead of a class.")
raise TypeError(
"Cannot clone object. You should provide an instance of "
"scikit-learn estimator instead of a class."
)
else:
raise TypeError("Cannot clone object '%s' (type %s): "
"it does not seem to be a scikit-learn "
"estimator as it does not implement a "
"'get_params' method."
% (repr(estimator), type(estimator)))
raise TypeError(
f"Cannot clone object '{repr(estimator)}' (type "
f"{type(estimator)}): it does not seem to be a "
"scikit-learn estimator as it does not implement a "
"'get_params' method."
)

klass = estimator.__class__
new_object_params = estimator.get_params(deep=False)
tag = "immutable_params"
immutable_params = getattr(
estimator, "_get_tags", lambda: _DEFAULT_TAGS)().get(
tag, _DEFAULT_TAGS[tag])
for name, param in new_object_params.items():
if immutable_params is not None and name in immutable_params:
new_object_params[name] = clone(param, safe=False, deepcopy=False)
else:
new_object_params[name] = clone(param, safe=False, deepcopy=True)
new_object = klass(**new_object_params)
params_set = new_object.get_params(deep=False)

@@ -84,9 +164,10 @@ def clone(estimator, safe=True):
param1 = new_object_params[name]
param2 = params_set[name]
if param1 is not param2:
raise RuntimeError('Cannot clone object %s, as the constructor '
'either does not set or modifies parameter %s' %
(estimator, name))
raise RuntimeError(
f'Cannot clone object {estimator}, as the constructor '
f'either does not set or modifies parameter {name}'
)
return new_object
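The tag lookup in `clone` above relies on a `getattr` fallback so that objects without a `_get_tags` method still resolve to the default tags. A standalone sketch of that pattern (the class names here are illustrative):

```python
_DEFAULT_TAGS = {'immutable_params': None}


def get_tag(obj, tag):
    # Fall back to the defaults when the object does not implement
    # ``_get_tags`` (e.g. a plain parameter value or third-party object).
    tags = getattr(obj, "_get_tags", lambda: _DEFAULT_TAGS)()
    return tags.get(tag, _DEFAULT_TAGS[tag])


class WithTags:
    def _get_tags(self):
        return {'immutable_params': ['vocabulary']}


print(get_tag(WithTags(), 'immutable_params'))  # ['vocabulary']
print(get_tag(object(), 'immutable_params'))    # None
```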


2 changes: 1 addition & 1 deletion sklearn/feature_extraction/text.py
@@ -1296,7 +1296,7 @@ def get_feature_names(self):
key=itemgetter(1))]

def _more_tags(self):
return {'X_types': ['string'], 'immutable_params': ['vocabulary']}


def _make_int_array():
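The motivation for declaring `vocabulary` immutable is the cost of deep-copying a large mapping on every `clone` call. A rough, self-contained illustration (the vocabulary size is arbitrary):

```python
import copy
import time

# A large vocabulary mapping, similar in shape to the ``vocabulary``
# parameter of TfidfVectorizer.
vocabulary = {f"token_{i}": i for i in range(200_000)}

start = time.perf_counter()
copied = copy.deepcopy(vocabulary)  # what clone() pays without the tag
deepcopy_time = time.perf_counter() - start

start = time.perf_counter()
shared = vocabulary  # what the immutable_params tag enables
assign_time = time.perf_counter() - start

print(copied is vocabulary)  # False: an independent copy was allocated
print(shared is vocabulary)  # True: reassignment is effectively free
```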
44 changes: 44 additions & 0 deletions sklearn/tests/test_base.py
@@ -104,6 +104,25 @@ def __init__(self, *vargs):
pass


class EstimatorImmutableParams(BaseEstimator):
"""Estimator which will declare parameters as immutable."""
def __init__(self, mutable_list, mutable_dict, mutable_estimator,
immutable_list, immutable_dict, immutable_estimator):
self.mutable_list = mutable_list
self.mutable_dict = mutable_dict
self.mutable_estimator = mutable_estimator

self.immutable_list = immutable_list
self.immutable_dict = immutable_dict
self.immutable_estimator = immutable_estimator

def _more_tags(self):
return {
"immutable_params":
["immutable_list", "immutable_dict", "immutable_estimator"]
}


#############################################################################
# The tests

@@ -204,6 +223,31 @@ def test_clone_class_rather_than_instance():
clone(MyEstimator)


def test_clone_immutable_params():
estimator = EstimatorImmutableParams(
mutable_list=[0, 1, 2],
mutable_dict={'a': 0, 'b': 1, 'c': 2},
mutable_estimator=MyEstimator(),
immutable_list=[0, 1, 2],
immutable_dict={'a': 0, 'b': 1, 'c': 2},
immutable_estimator=MyEstimator(),
)
estimator_2 = clone(estimator)

mutable_params_name = list(
set(estimator.get_params(deep=False).keys()) -
set(estimator._get_tags()["immutable_params"])
)
immutable_params_name = estimator._get_tags()["immutable_params"]

# no copy should be done for the declared immutable params
for param in immutable_params_name:
assert getattr(estimator, param) is getattr(estimator_2, param)
# otherwise we should take a copy or deep copy
for param in mutable_params_name:
assert getattr(estimator, param) is not getattr(estimator_2, param)
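The `_get_tags` call in the test above merges `_more_tags` overrides along the class hierarchy. A simplified, standalone sketch of that aggregation (not scikit-learn's exact implementation):

```python
import inspect

_DEFAULT_TAGS = {'immutable_params': None, 'multilabel': False}


class TagMixinSketch:
    """Rough sketch of how ``BaseEstimator._get_tags`` aggregates
    ``_more_tags`` overrides along the class hierarchy."""

    def _get_tags(self):
        collected = dict(_DEFAULT_TAGS)
        # Walk the MRO from base to subclass so the most derived
        # class's _more_tags wins.
        for base in reversed(inspect.getmro(self.__class__)):
            if hasattr(base, '_more_tags'):
                collected.update(base._more_tags(self))
        return collected


class MyVectorizer(TagMixinSketch):
    def _more_tags(self):
        return {'immutable_params': ['vocabulary']}


print(MyVectorizer()._get_tags()['immutable_params'])  # ['vocabulary']
```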


def test_repr():
# Smoke test the repr of the base estimator.
my_estimator = MyEstimator()