diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index 5e11f46eccdb8..3d3def795a7ba 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -5,9 +5,9 @@ Developing scikit-learn estimators ================================== Whether you are proposing an estimator for inclusion in scikit-learn, -developing a separate package compatible with scikit-learn, or -implementing custom components for your own projects, this chapter -details how to develop objects that safely interact with scikit-learn +developing a separate package compatible with scikit-learn, or +implementing custom components for your own projects, this chapter +details how to develop objects that safely interact with scikit-learn Pipelines and model selection tools. .. currentmodule:: sklearn @@ -266,7 +266,7 @@ the name of the underlying estimator and check in the test name. This allows `pytest -k` to be used to specify which tests to run. .. code-block: bash - + pytest test_check_estimators.py -k check_estimators_fit_returns_self Before detailing the required interface below, we describe two ways to achieve @@ -388,16 +388,23 @@ trailing ``_`` is used to check if the estimator has been fitted. Cloning ------- -For use with the :mod:`model_selection` module, -an estimator must support the ``base.clone`` function to replicate an estimator. -This can be done by providing a ``get_params`` method. -If ``get_params`` is present, then ``clone(estimator)`` will be an instance of -``type(estimator)`` on which ``set_params`` has been called with clones of -the result of ``estimator.get_params()``. +For use with the :mod:`model_selection` module, an estimator must support the +:func:`base.clone` function to replicate an estimator. This can be done by +providing a ``get_params`` method. If ``get_params`` is present, then +``clone(estimator)`` will be an instance of ``type(estimator)`` on which +``set_params`` has been called with clones of the result of +``estimator.get_params()``. + +Objects that do not provide this method will either: -Objects that do not provide this method will be deep-copied -(using the Python standard function ``copy.deepcopy``) -if ``safe=False`` is passed to ``clone``. +* be deep-copied (using the Python standard function ``copy.deepcopy``) if + ``safe=False`` is passed to :func:`base.clone`. + +* be reassigned if the object is an estimator parameter declared to be + immutable using the estimator tag `immutable_params` (refer to + :ref:`estimator_tags` below). This behavior is interesting when a parameter + is known to be constant (i.e. not be changed during `fit`) and that a + deepcopy will be memory or computationally expensive. Pipeline compatibility ---------------------- @@ -488,6 +495,11 @@ binary_only (default=``False``) whether estimator supports binary classification but lacks multi-class classification support. +immutable_params (default=None) + list of estimator parameters which are declared to be immutable. These + parameters will not be deep copied when calling :func:`base.clone` but only + reassigned. + multilabel (default=``False``) whether the estimator supports multilabel output @@ -567,10 +579,10 @@ closed-form solutions. Coding guidelines ================= -The following are some guidelines on how new code should be written for -inclusion in scikit-learn, and which may be appropriate to adopt in external -projects. Of course, there are special cases and there will be exceptions to -these rules. However, following these rules when submitting new code makes +The following are some guidelines on how new code should be written for +inclusion in scikit-learn, and which may be appropriate to adopt in external +projects. Of course, there are special cases and there will be exceptions to +these rules. However, following these rules when submitting new code makes the review easier so new code can be integrated in less time. Uniformly formatted code makes it easier to share code ownership. The diff --git a/sklearn/base.py b/sklearn/base.py index be329c196abb5..2170fe6bd3a0c 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -33,26 +33,97 @@ '_skip_test': False, 'multioutput_only': False, 'binary_only': False, - 'requires_fit': True} + 'requires_fit': True, + 'immutable_params': None, +} -def clone(estimator, safe=True): +def clone(estimator, safe=True, deepcopy=True): """Constructs a new estimator with the same parameters. - Clone does a deep copy of the model in an estimator - without actually copying attached data. It yields a new estimator - with the same parameters that has not been fit on any data. + Clone does a deep copy of the model in an estimator without actually + copying attached data. It yields a new estimator with the same parameters + that has not been fit on any data. + + In case an estimator declared some parameters to be immutable (via the + `'immutable_params'` tag), these paremeters will not be deep copied but + only reassigned. Parameters ---------- - estimator : estimator object, or list, tuple or set of objects - The estimator or group of estimators to be cloned + estimator : {list, tuple, set, frozenset} of estimator object or \ + estimator object + The estimator or group of estimators to be cloned. safe : bool, default=True - If safe is false, clone will fall back to a deep copy on objects - that are not estimators. + If `safe` is `False`, objects that are not estimators will be either + deep copied if `deepcopy=True` or just be passed by assignment if + `deepcopy=False`. + + deepcopy : bool, default=True + Whether or not to trigger a deep copy of object which are not + estimator. + Returns + ------- + new_object : estimator object + A new unfitted estimator. + + Examples + -------- + >>> from sklearn.base import clone + >>> from sklearn.exceptions import NotFittedError + >>> from sklearn.utils.validation import check_is_fitted + >>> from sklearn.datasets import make_classification + >>> from sklearn.linear_model import LogisticRegression + >>> X, y = make_classification(random_state=42) + + We create a logistic regression and train it on some data. Once trained, + we can check that the classifier is fitted and that he has an attribute + `coef_`. + + >>> clf = LogisticRegression().fit(X, y) + >>> check_is_fitted(clf) + >>> hasattr(clf, "coef_") + True + + Now, we create a clone of this classifier. The role :func:`~clone` is + to make a copy of `clf`, removing all attributes which have been created + during `fit`. + + >>> clf_clone = clone(clf) + + In addition, `clf_clone` and its parameters will be copied from the + original `clf`. + + >>> clf_clone is clf + False + >>> try: + ... check_is_fitted(clf_clone) + ... except NotFittedError as e: + ... print(e) + This LogisticRegression instance is not fitted yet. Call 'fit' with + appropriate arguments before using this estimator. + + To avoid making a copy of the parameter estimator which are immutable + parameters, you can use the estimator tag `immutable_params`. For instance, + :class:`sklearn.feature_extraction.textTfidfVectorizer` declared the + parameter `vocabulary` to be immutable. Thus, this parameter will not be + copied when calling :func:`~clone`: + + >>> from sklearn.feature_extraction.text import TfidfVectorizer + >>> vectorizer = TfidfVectorizer( + ... vocabulary={'g': 0, 'a': 1, 't': 2, 'c': 3}, + ... stop_words=["and", "maybe"] + ... ) + >>> vectorizer_clone = clone(vectorizer) + >>> vectorizer_clone.vocabulary is vectorizer.vocabulary + True + >>> vectorizer_clone.stop_words is vectorizer.stop_words + False """ + if not safe and not deepcopy: + return estimator estimator_type = type(estimator) # XXX: not handling dictionaries if estimator_type in (list, tuple, set, frozenset): @@ -62,20 +133,29 @@ def clone(estimator, safe=True): return copy.deepcopy(estimator) else: if isinstance(estimator, type): - raise TypeError("Cannot clone object. " + - "You should provide an instance of " + - "scikit-learn estimator instead of a class.") + raise TypeError( + "Cannot clone object. You should provide an instance of " + "scikit-learn estimator instead of a class." + ) else: - raise TypeError("Cannot clone object '%s' (type %s): " - "it does not seem to be a scikit-learn " - "estimator as it does not implement a " - "'get_params' method." - % (repr(estimator), type(estimator))) + raise TypeError( + f"Cannot clone object '{repr(estimator)}' (type " + f"{type(estimator)}): it does not seem to be a " + "scikit-learn estimator as it does not implement a " + "'get_params' method." + ) klass = estimator.__class__ new_object_params = estimator.get_params(deep=False) + tag = "immutable_params" + immutable_params = getattr( + estimator, "_get_tags", lambda: _DEFAULT_TAGS)().get( + tag, _DEFAULT_TAGS[tag]) for name, param in new_object_params.items(): - new_object_params[name] = clone(param, safe=False) + if immutable_params is not None and name in immutable_params: + new_object_params[name] = clone(param, safe=False, deepcopy=False) + else: + new_object_params[name] = clone(param, safe=False, deepcopy=True) new_object = klass(**new_object_params) params_set = new_object.get_params(deep=False) @@ -84,9 +164,10 @@ def clone(estimator, safe=True): param1 = new_object_params[name] param2 = params_set[name] if param1 is not param2: - raise RuntimeError('Cannot clone object %s, as the constructor ' - 'either does not set or modifies parameter %s' % - (estimator, name)) + raise RuntimeError( + f'Cannot clone object {estimator}, as the constructor ' + f'either does not set or modifies parameter {name}' + ) return new_object diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py index 4954329728d5e..2389e05632268 100644 --- a/sklearn/feature_extraction/text.py +++ b/sklearn/feature_extraction/text.py @@ -1296,7 +1296,7 @@ def get_feature_names(self): key=itemgetter(1))] def _more_tags(self): - return {'X_types': ['string']} + return {'X_types': ['string'], 'immutable_params': ['vocabulary']} def _make_int_array(): diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index 95f7b01f27058..e054c6f956ecd 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -104,6 +104,25 @@ def __init__(self, *vargs): pass +class EstimatorImmutableParams(BaseEstimator): + """Estimator which will declare parameters as immutable.""" + def __init__(self, mutable_list, mutable_dict, mutable_estimator, + immutable_list, immutable_dict, immutable_estimator): + self.mutable_list = mutable_list + self.mutable_dict = mutable_dict + self.mutable_estimator = mutable_estimator + + self.immutable_list = immutable_list + self.immutable_dict = immutable_dict + self.immutable_estimator = immutable_estimator + + def _more_tags(self): + return { + "immutable_params": + ["immutable_list", "immutable_dict", "immutable_estimator"] + } + + ############################################################################# # The tests @@ -204,6 +223,31 @@ def test_clone_class_rather_than_instance(): clone(MyEstimator) +def test_clone_immutable_params(): + estimator = EstimatorImmutableParams( + mutable_list=[0, 1, 2], + mutable_dict={'a': 0, 'b': 1, 'c': 2}, + mutable_estimator=MyEstimator(), + immutable_list=[0, 1, 2], + immutable_dict={'a': 0, 'b': 1, 'c': 2}, + immutable_estimator=MyEstimator(), + ) + estimator_2 = clone(estimator) + + mutable_params_name = list( + set(estimator.get_params(deep=False).keys()) - + set(estimator._get_tags()["immutable_params"]) + ) + immutable_params_name = estimator._get_tags()["immutable_params"] + + # no copy should be done for the declared immutable params + for param in immutable_params_name: + assert getattr(estimator, param) is getattr(estimator_2, param) + # otherwise we should take a copy or deep copy + for param in mutable_params_name: + assert getattr(estimator, param) is not getattr(estimator_2, param) + + def test_repr(): # Smoke test the repr of the base estimator. my_estimator = MyEstimator()