From 8c5252bf28c14bb37464a05cf1870e63f87f5148 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Sat, 21 Dec 2019 18:37:41 -0500 Subject: [PATCH 01/13] ENH Adds attributes back to check_is_fitted --- sklearn/utils/tests/test_validation.py | 20 ++++++++++++++++++++ sklearn/utils/validation.py | 22 ++++++++++++++++++---- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index f121f11658051..b800b61964eff 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -5,6 +5,7 @@ from tempfile import NamedTemporaryFile from itertools import product +from operator import itemgetter import pytest from pytest import importorskip @@ -678,6 +679,25 @@ def test_check_is_fitted(): assert check_is_fitted(svr) is None +@pytest.mark.parametrize("wrap", + [itemgetter(0), list, tuple], + ids=["single", "list", "tuple"]) +def test_check_is_fitted_with_attributes(wrap): + + ard = ARDRegression() + with pytest.raises(NotFittedError, match="is not fitted yet"): + check_is_fitted(ard, wrap(["coef_"])) + + ard.fit(*make_blobs()) + + # Does not raise + check_is_fitted(ard, wrap(["coef_"])) + + # Raises when using attribute that is not defined + with pytest.raises(NotFittedError, match="is not fitted yet"): + check_is_fitted(ard, wrap(["coef_bad_"])) + + def test_check_consistent_length(): check_consistent_length([1], [2], [3], [4], [5]) check_consistent_length([[1, 2], [[1, 2]]], [1, 2], ['a', 'b']) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 5502fdd534965..22740d0f7f8cc 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -850,7 +850,7 @@ def check_symmetric(array, tol=1E-10, raise_warning=True, return array -def check_is_fitted(estimator, msg=None): +def check_is_fitted(estimator, attributes=None, msg=None): """Perform is_fitted validation for estimator. Checks if the estimator is fitted by verifying the presence of @@ -862,6 +862,15 @@ def check_is_fitted(estimator, msg=None): estimator : estimator instance. estimator instance for which the check is performed. + attributes : list or tuple of str or None, default=None + attribute name(s) given as string or a list/tuple of strings + Eg.: + ``["coef_", "estimator_", ...], "coef_"`` + + If None, `estimator` is considered fitted if there exist an attribute + that starts or ends with a underscore and does not start + with double underscore. + msg : string The default error message is, "This %(name)s instance is not fitted yet. Call 'fit' with appropriate arguments before using this @@ -890,9 +899,14 @@ def check_is_fitted(estimator, msg=None): if not hasattr(estimator, 'fit'): raise TypeError("%s is not an estimator instance." % (estimator)) - attrs = [v for v in vars(estimator) - if (v.endswith("_") or v.startswith("_")) - and not v.startswith("__")] + if attributes is not None: + if not isinstance(attributes, (list, tuple)): + attributes = [attributes] + attrs = all([hasattr(estimator, attr) for attr in attributes]) + else: + attrs = [v for v in vars(estimator) + if (v.endswith("_") or v.startswith("_")) + and not v.startswith("__")] if not attrs: raise NotFittedError(msg % {'name': type(estimator).__name__}) From 9c970faad0cb2e23ae7ee4596dfd68726b7ac2f5 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Sat, 21 Dec 2019 18:41:45 -0500 Subject: [PATCH 02/13] DOC Updates docstring --- sklearn/utils/tests/test_validation.py | 1 - sklearn/utils/validation.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index b800b61964eff..4d4ed3f0a0499 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -683,7 +683,6 @@ def test_check_is_fitted(): [itemgetter(0), list, tuple], ids=["single", "list", "tuple"]) def test_check_is_fitted_with_attributes(wrap): - ard = ARDRegression() with pytest.raises(NotFittedError, match="is not fitted yet"): check_is_fitted(ard, wrap(["coef_"])) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 22740d0f7f8cc..132c9c7561992 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -862,7 +862,7 @@ def check_is_fitted(estimator, attributes=None, msg=None): estimator : estimator instance. estimator instance for which the check is performed. - attributes : list or tuple of str or None, default=None + attributes : str, list or tuple of str or None, default=None attribute name(s) given as string or a list/tuple of strings Eg.: ``["coef_", "estimator_", ...], "coef_"`` From 699dcfe553eb38e1a5ce353cff40e61fcbc67875 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Sat, 21 Dec 2019 19:34:42 -0500 Subject: [PATCH 03/13] STY DOC Fix --- sklearn/utils/validation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 132c9c7561992..9e7c65e3b3c01 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -863,9 +863,8 @@ def check_is_fitted(estimator, attributes=None, msg=None): estimator instance for which the check is performed. attributes : str, list or tuple of str or None, default=None - attribute name(s) given as string or a list/tuple of strings - Eg.: - ``["coef_", "estimator_", ...], "coef_"`` + Attribute name(s) given as string or a list/tuple of strings + Eg.: ``["coef_", "estimator_", ...], "coef_"`` If None, `estimator` is considered fitted if there exist an attribute that starts or ends with a underscore and does not start From 9d776ff9faff869687464f987dfc5f300a1e6edb Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 24 Dec 2019 10:19:23 +0100 Subject: [PATCH 04/13] Stop looking for fit attributes starting with _ in check_is_fitted --- sklearn/utils/validation.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 9e7c65e3b3c01..3adb4622f2f04 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -862,12 +862,12 @@ def check_is_fitted(estimator, attributes=None, msg=None): estimator : estimator instance. estimator instance for which the check is performed. - attributes : str, list or tuple of str or None, default=None + attributes : str, list or tuple of str, default=None Attribute name(s) given as string or a list/tuple of strings Eg.: ``["coef_", "estimator_", ...], "coef_"`` - If None, `estimator` is considered fitted if there exist an attribute - that starts or ends with a underscore and does not start + If `None`, `estimator` is considered fitted if there exist an + attribute that starts or ends with a underscore and does not start with double underscore. msg : string @@ -904,8 +904,7 @@ def check_is_fitted(estimator, attributes=None, msg=None): attrs = all([hasattr(estimator, attr) for attr in attributes]) else: attrs = [v for v in vars(estimator) - if (v.endswith("_") or v.startswith("_")) - and not v.startswith("__")] + if v.endswith("_") and not v.startswith("__")] if not attrs: raise NotFittedError(msg % {'name': type(estimator).__name__}) From 3122da321c2e65896dc4239d7f61ffc405675eac Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 24 Dec 2019 10:20:15 +0100 Subject: [PATCH 05/13] Update TfidfVectorize/Transformer to leverage the new check_is_fitted introspection --- sklearn/feature_extraction/tests/test_text.py | 1 + sklearn/feature_extraction/text.py | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/sklearn/feature_extraction/tests/test_text.py b/sklearn/feature_extraction/tests/test_text.py index 25c353aae5276..f8f741a862594 100644 --- a/sklearn/feature_extraction/tests/test_text.py +++ b/sklearn/feature_extraction/tests/test_text.py @@ -1099,6 +1099,7 @@ def test_vectorizer_string_object_as_input(Vectorizer): assert_raise_message( ValueError, message, vec.fit_transform, "hello world!") assert_raise_message(ValueError, message, vec.fit, "hello world!") + vec.fit(["some text", "some other text"]) assert_raise_message(ValueError, message, vec.transform, "hello world!") diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py index 2d8f7d840c55b..851bc29e336d9 100644 --- a/sklearn/feature_extraction/text.py +++ b/sklearn/feature_extraction/text.py @@ -1498,7 +1498,11 @@ def transform(self, X, copy=True): X.data += 1 if self.use_idf: - check_is_fitted(self, msg='idf vector is not fitted') + # idf_ being a property, the automatic attributes detection + # does not work as usual and we need to specify the attribute + # name: + check_is_fitted(self, attributes=["idf_"], + msg='idf vector is not fitted') expected_n_features = self._idf_diag.shape[0] if n_features != expected_n_features: @@ -1883,7 +1887,7 @@ def transform(self, raw_documents, copy="deprecated"): X : sparse matrix, [n_samples, n_features] Tf-idf-weighted document-term matrix. """ - check_is_fitted(self, msg='The tfidf vector is not fitted') + check_is_fitted(self, msg='The TF-IDF vectorizer is not fitted') # FIXME Remove copy parameter support in 0.24 if copy != "deprecated": From 129d261ab8d79e845d77051528b09e16b71fac4f Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 24 Dec 2019 11:02:11 +0100 Subject: [PATCH 06/13] DOC whats new and scope info in docstring. --- doc/whats_new/v0.22.rst | 7 +++++++ sklearn/utils/validation.py | 3 +++ 2 files changed, 10 insertions(+) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index ae9cbbd74e313..84fd3f32d651a 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -28,6 +28,13 @@ Changelog - |Fix| :func:`utils.check_array` now correctly converts pandas DataFrame with boolean columns to floats. :pr:`15797` by `Thomas Fan`_. +- |Fix| :func:`utils.check_is_fifted` accepts back an explicit ``attributes`` + argument to check for specific attributes as explicit markers of a fitted + estimator. When no explicit attributes are provided, only the attributes + ending with a single "_" are used as "fitted" markers. This change is made to + restore some backward compatibility with the behavior of this utility in + version 0.21. :pr:`15947` by `Thomas Fan`_. + :mod:`sklearn.inspection` ......................... diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 3adb4622f2f04..c3a1f622a8acc 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -857,6 +857,9 @@ def check_is_fitted(estimator, attributes=None, msg=None): fitted attributes (ending with a trailing underscore) and otherwise raises a NotFittedError with the given message. + This utility is meant to be used internally by estimators them-selves, + typically in their own predict / transform methods. + Parameters ---------- estimator : estimator instance. From 6dc82d7e3a25c3548c85c9a5124acb4364980c49 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 24 Dec 2019 11:41:55 +0100 Subject: [PATCH 07/13] Remove duplicated whatsnew entry (wront merge conflict resolution) --- doc/whats_new/v0.22.rst | 7 ------- 1 file changed, 7 deletions(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index 763edf09187d1..83cf3ab531cb0 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -67,13 +67,6 @@ Changelog restore some backward compatibility with the behavior of this utility in version 0.21. :pr:`15947` by `Thomas Fan`_. -:mod:`sklearn.inspection` -......................... - -- |Fix| :func:`inspection.plot_partial_dependence` and - :meth:`inspection.PartialDependenceDisplay.plot` now consistently checks - the number of axes passed in. :pr:`15760` by `Thomas Fan`_. - .. _changes_0_22: Version 0.22.0 From 34505cbfda8dc0c39f1ce1ef1fcf658e908bb148 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 24 Dec 2019 11:42:22 +0100 Subject: [PATCH 08/13] Update docstring for the attributes parameter --- sklearn/utils/validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 9397a49a9f507..67b115314970b 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -872,8 +872,8 @@ def check_is_fitted(estimator, attributes=None, msg=None): Eg.: ``["coef_", "estimator_", ...], "coef_"`` If `None`, `estimator` is considered fitted if there exist an - attribute that starts or ends with a underscore and does not start - with double underscore. + attribute that ends with a underscore and does not start with double + underscore. msg : string The default error message is, "This %(name)s instance is not fitted From 10d8e4d0abccd5c049de4b0bf1b03a45bc39109f Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 24 Dec 2019 15:58:08 +0100 Subject: [PATCH 09/13] Re-add any_or_all --- sklearn/utils/validation.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 67b115314970b..19fa1475bc4a3 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -852,7 +852,7 @@ def check_symmetric(array, tol=1E-10, raise_warning=True, return array -def check_is_fitted(estimator, attributes=None, msg=None): +def check_is_fitted(estimator, attributes=None, msg=None, all_or_any=all): """Perform is_fitted validation for estimator. Checks if the estimator is fitted by verifying the presence of @@ -885,6 +885,9 @@ def check_is_fitted(estimator, attributes=None, msg=None): Eg. : "Estimator, %(name)s, must be fitted before sparsifying". + all_or_any : callable, {all, any}, default all + Specify whether all or any of the given attributes must exist. + Returns ------- None @@ -906,7 +909,7 @@ def check_is_fitted(estimator, attributes=None, msg=None): if attributes is not None: if not isinstance(attributes, (list, tuple)): attributes = [attributes] - attrs = all([hasattr(estimator, attr) for attr in attributes]) + attrs = all_or_any([hasattr(estimator, attr) for attr in attributes]) else: attrs = [v for v in vars(estimator) if v.endswith("_") and not v.startswith("__")] From 6208f4811ea55f79196b6e0caf959d369a547844 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 24 Dec 2019 16:19:36 +0100 Subject: [PATCH 10/13] Add missing test for check_is_fitted with attributes --- sklearn/utils/tests/test_validation.py | 30 ++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 4d4ed3f0a0499..b298424267067 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -15,7 +15,6 @@ from sklearn.utils._testing import assert_raises from sklearn.utils._testing import assert_raises_regex from sklearn.utils._testing import assert_no_warnings -from sklearn.utils._testing import assert_warns_message from sklearn.utils._testing import assert_warns from sklearn.utils._testing import ignore_warnings from sklearn.utils._testing import SkipTest @@ -51,7 +50,6 @@ import sklearn from sklearn.exceptions import NotFittedError, PositiveSpectrumWarning -from sklearn.exceptions import DataConversionWarning from sklearn.utils._testing import assert_raise_message from sklearn.utils._testing import TempMemmap @@ -679,6 +677,34 @@ def test_check_is_fitted(): assert check_is_fitted(svr) is None +def test_check_is_fitted_attributes(): + class MyEstimator(): + def fit(self, X, y): + return self + + msg = "not fitted" + est = MyEstimator() + + with pytest.raises(NotFittedError, match=msg): + check_is_fitted(est, attributes=["a_", "b_"]) + with pytest.raises(NotFittedError, match=msg): + check_is_fitted(est, attributes=["a_", "b_"], all_or_any=all) + with pytest.raises(NotFittedError, match=msg): + check_is_fitted(est, attributes=["a_", "b_"], all_or_any=any) + + est.a_ = "a" + with pytest.raises(NotFittedError, match=msg): + check_is_fitted(est, attributes=["a_", "b_"]) + with pytest.raises(NotFittedError, match=msg): + check_is_fitted(est, attributes=["a_", "b_"], all_or_any=all) + check_is_fitted(est, attributes=["a_", "b_"], all_or_any=any) + + est.b_ = "b" + check_is_fitted(est, attributes=["a_", "b_"]) + check_is_fitted(est, attributes=["a_", "b_"], all_or_any=all) + check_is_fitted(est, attributes=["a_", "b_"], all_or_any=any) + + @pytest.mark.parametrize("wrap", [itemgetter(0), list, tuple], ids=["single", "list", "tuple"]) From 24e5011627de706207ddf6e8a301670ef53c5d43 Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Fri, 27 Dec 2019 10:23:39 +0100 Subject: [PATCH 11/13] DOC Address comments --- doc/whats_new/v0.22.rst | 2 +- sklearn/utils/validation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index 83cf3ab531cb0..da2f2a956f2e7 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -62,7 +62,7 @@ Changelog - |Fix| :func:`utils.check_is_fifted` accepts back an explicit ``attributes`` argument to check for specific attributes as explicit markers of a fitted - estimator. When no explicit attributes are provided, only the attributes + estimator. When no explicit ``attributes`` are provided, only the attributes ending with a single "_" are used as "fitted" markers. This change is made to restore some backward compatibility with the behavior of this utility in version 0.21. :pr:`15947` by `Thomas Fan`_. diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 19fa1475bc4a3..2248389d0b3b1 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -859,7 +859,7 @@ def check_is_fitted(estimator, attributes=None, msg=None, all_or_any=all): fitted attributes (ending with a trailing underscore) and otherwise raises a NotFittedError with the given message. - This utility is meant to be used internally by estimators them-selves, + This utility is meant to be used internally by estimators themselves, typically in their own predict / transform methods. Parameters From 111ee041644add5b44e0522ec72739f8d16892fe Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Fri, 27 Dec 2019 10:31:47 +0100 Subject: [PATCH 12/13] DOC Mention all_or_any in what's new --- doc/whats_new/v0.22.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index da2f2a956f2e7..8f8f646676e9d 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -63,7 +63,8 @@ Changelog - |Fix| :func:`utils.check_is_fifted` accepts back an explicit ``attributes`` argument to check for specific attributes as explicit markers of a fitted estimator. When no explicit ``attributes`` are provided, only the attributes - ending with a single "_" are used as "fitted" markers. This change is made to + ending with a single "_" are used as "fitted" markers. The ``all_or_any`` + argument is also no longer deprecated. This change is made to restore some backward compatibility with the behavior of this utility in version 0.21. :pr:`15947` by `Thomas Fan`_. From be86ffabfb40ded6f4cc706b40078643f3fe4ebf Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Fri, 27 Dec 2019 10:33:13 +0100 Subject: [PATCH 13/13] DOC Another typo in what's new --- doc/whats_new/v0.22.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index 8f8f646676e9d..a8718152de1a9 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -60,7 +60,7 @@ Changelog - |Fix| :func:`utils.check_array` now correctly converts pandas DataFrame with boolean columns to floats. :pr:`15797` by `Thomas Fan`_. -- |Fix| :func:`utils.check_is_fifted` accepts back an explicit ``attributes`` +- |Fix| :func:`utils.check_is_fitted` accepts back an explicit ``attributes`` argument to check for specific attributes as explicit markers of a fitted estimator. When no explicit ``attributes`` are provided, only the attributes ending with a single "_" are used as "fitted" markers. The ``all_or_any``