diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index 370f02fda49f9..a8718152de1a9 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -60,6 +60,14 @@ Changelog - |Fix| :func:`utils.check_array` now correctly converts pandas DataFrame with boolean columns to floats. :pr:`15797` by `Thomas Fan`_. +- |Fix| :func:`utils.check_is_fitted` accepts back an explicit ``attributes`` + argument to check for specific attributes as explicit markers of a fitted + estimator. When no explicit ``attributes`` are provided, only the attributes + ending with a single "_" are used as "fitted" markers. The ``all_or_any`` + argument is also no longer deprecated. This change is made to + restore some backward compatibility with the behavior of this utility in + version 0.21. :pr:`15947` by `Thomas Fan`_. + .. _changes_0_22: Version 0.22.0 diff --git a/sklearn/feature_extraction/tests/test_text.py b/sklearn/feature_extraction/tests/test_text.py index 25c353aae5276..f8f741a862594 100644 --- a/sklearn/feature_extraction/tests/test_text.py +++ b/sklearn/feature_extraction/tests/test_text.py @@ -1099,6 +1099,7 @@ def test_vectorizer_string_object_as_input(Vectorizer): assert_raise_message( ValueError, message, vec.fit_transform, "hello world!") assert_raise_message(ValueError, message, vec.fit, "hello world!") + vec.fit(["some text", "some other text"]) assert_raise_message(ValueError, message, vec.transform, "hello world!") diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py index 9771c62204444..82ba60a18da28 100644 --- a/sklearn/feature_extraction/text.py +++ b/sklearn/feature_extraction/text.py @@ -1498,7 +1498,11 @@ def transform(self, X, copy=True): X.data += 1 if self.use_idf: - check_is_fitted(self, msg='idf vector is not fitted') + # idf_ being a property, the automatic attributes detection + # does not work as usual and we need to specify the attribute + # name: + check_is_fitted(self, attributes=["idf_"], + msg='idf 
vector is not fitted') expected_n_features = self._idf_diag.shape[0] if n_features != expected_n_features: @@ -1883,7 +1887,7 @@ def transform(self, raw_documents, copy="deprecated"): X : sparse matrix, [n_samples, n_features] Tf-idf-weighted document-term matrix. """ - check_is_fitted(self, msg='The tfidf vector is not fitted') + check_is_fitted(self, msg='The TF-IDF vectorizer is not fitted') # FIXME Remove copy parameter support in 0.24 if copy != "deprecated": diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index f121f11658051..b298424267067 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -5,6 +5,7 @@ from tempfile import NamedTemporaryFile from itertools import product +from operator import itemgetter import pytest from pytest import importorskip @@ -14,7 +15,6 @@ from sklearn.utils._testing import assert_raises from sklearn.utils._testing import assert_raises_regex from sklearn.utils._testing import assert_no_warnings -from sklearn.utils._testing import assert_warns_message from sklearn.utils._testing import assert_warns from sklearn.utils._testing import ignore_warnings from sklearn.utils._testing import SkipTest @@ -50,7 +50,6 @@ import sklearn from sklearn.exceptions import NotFittedError, PositiveSpectrumWarning -from sklearn.exceptions import DataConversionWarning from sklearn.utils._testing import assert_raise_message from sklearn.utils._testing import TempMemmap @@ -678,6 +677,52 @@ def test_check_is_fitted(): assert check_is_fitted(svr) is None +def test_check_is_fitted_attributes(): + class MyEstimator(): + def fit(self, X, y): + return self + + msg = "not fitted" + est = MyEstimator() + + with pytest.raises(NotFittedError, match=msg): + check_is_fitted(est, attributes=["a_", "b_"]) + with pytest.raises(NotFittedError, match=msg): + check_is_fitted(est, attributes=["a_", "b_"], all_or_any=all) + with pytest.raises(NotFittedError, match=msg): + 
check_is_fitted(est, attributes=["a_", "b_"], all_or_any=any) + + est.a_ = "a" + with pytest.raises(NotFittedError, match=msg): + check_is_fitted(est, attributes=["a_", "b_"]) + with pytest.raises(NotFittedError, match=msg): + check_is_fitted(est, attributes=["a_", "b_"], all_or_any=all) + check_is_fitted(est, attributes=["a_", "b_"], all_or_any=any) + + est.b_ = "b" + check_is_fitted(est, attributes=["a_", "b_"]) + check_is_fitted(est, attributes=["a_", "b_"], all_or_any=all) + check_is_fitted(est, attributes=["a_", "b_"], all_or_any=any) + + +@pytest.mark.parametrize("wrap", + [itemgetter(0), list, tuple], + ids=["single", "list", "tuple"]) +def test_check_is_fitted_with_attributes(wrap): + ard = ARDRegression() + with pytest.raises(NotFittedError, match="is not fitted yet"): + check_is_fitted(ard, wrap(["coef_"])) + + ard.fit(*make_blobs()) + + # Does not raise + check_is_fitted(ard, wrap(["coef_"])) + + # Raises when using attribute that is not defined + with pytest.raises(NotFittedError, match="is not fitted yet"): + check_is_fitted(ard, wrap(["coef_bad_"])) + + def test_check_consistent_length(): check_consistent_length([1], [2], [3], [4], [5]) check_consistent_length([[1, 2], [[1, 2]]], [1, 2], ['a', 'b']) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index e08495de30af5..2248389d0b3b1 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -852,18 +852,29 @@ def check_symmetric(array, tol=1E-10, raise_warning=True, return array -def check_is_fitted(estimator, msg=None): +def check_is_fitted(estimator, attributes=None, msg=None, all_or_any=all): """Perform is_fitted validation for estimator. Checks if the estimator is fitted by verifying the presence of fitted attributes (ending with a trailing underscore) and otherwise raises a NotFittedError with the given message. + This utility is meant to be used internally by estimators themselves, + typically in their own predict / transform methods. 
+ Parameters ---------- estimator : estimator instance. estimator instance for which the check is performed. + attributes : str, list or tuple of str, default=None + Attribute name(s) given as string or a list/tuple of strings. + Eg.: ``["coef_", "estimator_", ...], "coef_"`` + + If `None`, `estimator` is considered fitted if there exists an + attribute that ends with an underscore and does not start with a double + underscore. + msg : string The default error message is, "This %(name)s instance is not fitted yet. Call 'fit' with appropriate arguments before using this @@ -874,6 +885,9 @@ def check_is_fitted(estimator, msg=None): Eg. : "Estimator, %(name)s, must be fitted before sparsifying". + all_or_any : callable, {all, any}, default=all + Specify whether all or any of the given attributes must exist. + Returns ------- None @@ -892,9 +906,13 @@ def check_is_fitted(estimator, msg=None): if not hasattr(estimator, 'fit'): raise TypeError("%s is not an estimator instance." % (estimator)) - attrs = [v for v in vars(estimator) - if (v.endswith("_") or v.startswith("_")) - and not v.startswith("__")] + if attributes is not None: + if not isinstance(attributes, (list, tuple)): + attributes = [attributes] + attrs = all_or_any([hasattr(estimator, attr) for attr in attributes]) + else: + attrs = [v for v in vars(estimator) + if v.endswith("_") and not v.startswith("__")] if not attrs: raise NotFittedError(msg % {'name': type(estimator).__name__})