BUG Adds attributes back to check_is_fitted (#15947) · scikit-learn/scikit-learn@9accce5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9accce5

Browse files
thomasjpfan authored and ogrisel committed
BUG Adds attributes back to check_is_fitted (#15947)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent d163d5a commit 9accce5

File tree

5 files changed

+84
-8
lines changed

5 files changed

+84
-8
lines changed

doc/whats_new/v0.22.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ Changelog
6868
- |Fix| :func:`utils.check_array` now correctly converts pandas DataFrame with
6969
boolean columns to floats. :pr:`15797` by `Thomas Fan`_.
7070

71+
- |Fix| :func:`utils.check_is_fitted` once again accepts an explicit ``attributes``
72+
argument to check for specific attributes as explicit markers of a fitted
73+
estimator. When no explicit ``attributes`` are provided, only the attributes
74+
ending with a single "_" are used as "fitted" markers. The ``all_or_any``
75+
argument is also no longer deprecated. This change is made to
76+
restore some backward compatibility with the behavior of this utility in
77+
version 0.21. :pr:`15947` by `Thomas Fan`_.
78+
7179
.. _changes_0_22:
7280

7381
Version 0.22.0

sklearn/feature_extraction/tests/test_text.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,6 +1099,7 @@ def test_vectorizer_string_object_as_input(Vectorizer):
10991099
assert_raise_message(
11001100
ValueError, message, vec.fit_transform, "hello world!")
11011101
assert_raise_message(ValueError, message, vec.fit, "hello world!")
1102+
vec.fit(["some text", "some other text"])
11021103
assert_raise_message(ValueError, message, vec.transform, "hello world!")
11031104

11041105

sklearn/feature_extraction/text.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,7 +1498,11 @@ def transform(self, X, copy=True):
14981498
X.data += 1
14991499

15001500
if self.use_idf:
1501-
check_is_fitted(self, msg='idf vector is not fitted')
1501+
# idf_ being a property, the automatic attributes detection
1502+
# does not work as usual and we need to specify the attribute
1503+
# name:
1504+
check_is_fitted(self, attributes=["idf_"],
1505+
msg='idf vector is not fitted')
15021506

15031507
expected_n_features = self._idf_diag.shape[0]
15041508
if n_features != expected_n_features:
@@ -1883,7 +1887,7 @@ def transform(self, raw_documents, copy="deprecated"):
18831887
X : sparse matrix, [n_samples, n_features]
18841888
Tf-idf-weighted document-term matrix.
18851889
"""
1886-
check_is_fitted(self, msg='The tfidf vector is not fitted')
1890+
check_is_fitted(self, msg='The TF-IDF vectorizer is not fitted')
18871891

18881892
# FIXME Remove copy parameter support in 0.24
18891893
if copy != "deprecated":

sklearn/utils/tests/test_validation.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from tempfile import NamedTemporaryFile
77
from itertools import product
8+
from operator import itemgetter
89

910
import pytest
1011
from pytest import importorskip
@@ -14,7 +15,6 @@
1415
from sklearn.utils._testing import assert_raises
1516
from sklearn.utils._testing import assert_raises_regex
1617
from sklearn.utils._testing import assert_no_warnings
17-
from sklearn.utils._testing import assert_warns_message
1818
from sklearn.utils._testing import assert_warns
1919
from sklearn.utils._testing import ignore_warnings
2020
from sklearn.utils._testing import SkipTest
@@ -50,7 +50,6 @@
5050
import sklearn
5151

5252
from sklearn.exceptions import NotFittedError, PositiveSpectrumWarning
53-
from sklearn.exceptions import DataConversionWarning
5453

5554
from sklearn.utils._testing import assert_raise_message
5655
from sklearn.utils._testing import TempMemmap
@@ -678,6 +677,52 @@ def test_check_is_fitted():
678677
assert check_is_fitted(svr) is None
679678

680679

680+
def test_check_is_fitted_attributes():
681+
class MyEstimator():
682+
def fit(self, X, y):
683+
return self
684+
685+
msg = "not fitted"
686+
est = MyEstimator()
687+
688+
with pytest.raises(NotFittedError, match=msg):
689+
check_is_fitted(est, attributes=["a_", "b_"])
690+
with pytest.raises(NotFittedError, match=msg):
691+
check_is_fitted(est, attributes=["a_", "b_"], all_or_any=all)
692+
with pytest.raises(NotFittedError, match=msg):
693+
check_is_fitted(est, attributes=["a_", "b_"], all_or_any=any)
694+
695+
est.a_ = "a"
696+
with pytest.raises(NotFittedError, match=msg):
697+
check_is_fitted(est, attributes=["a_", "b_"])
698+
with pytest.raises(NotFittedError, match=msg):
699+
check_is_fitted(est, attributes=["a_", "b_"], all_or_any=all)
700+
check_is_fitted(est, attributes=["a_", "b_"], all_or_any=any)
701+
702+
est.b_ = "b"
703+
check_is_fitted(est, attributes=["a_", "b_"])
704+
check_is_fitted(est, attributes=["a_", "b_"], all_or_any=all)
705+
check_is_fitted(est, attributes=["a_", "b_"], all_or_any=any)
706+
707+
708+
@pytest.mark.parametrize("wrap",
709+
[itemgetter(0), list, tuple],
710+
ids=["single", "list", "tuple"])
711+
def test_check_is_fitted_with_attributes(wrap):
712+
ard = ARDRegression()
713+
with pytest.raises(NotFittedError, match="is not fitted yet"):
714+
check_is_fitted(ard, wrap(["coef_"]))
715+
716+
ard.fit(*make_blobs())
717+
718+
# Does not raise
719+
check_is_fitted(ard, wrap(["coef_"]))
720+
721+
# Raises when using attribute that is not defined
722+
with pytest.raises(NotFittedError, match="is not fitted yet"):
723+
check_is_fitted(ard, wrap(["coef_bad_"]))
724+
725+
681726
def test_check_consistent_length():
682727
check_consistent_length([1], [2], [3], [4], [5])
683728
check_consistent_length([[1, 2], [[1, 2]]], [1, 2], ['a', 'b'])

sklearn/utils/validation.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -852,18 +852,29 @@ def check_symmetric(array, tol=1E-10, raise_warning=True,
852852
return array
853853

854854

855-
def check_is_fitted(estimator, msg=None):
855+
def check_is_fitted(estimator, attributes=None, msg=None, all_or_any=all):
856856
"""Perform is_fitted validation for estimator.
857857
858858
Checks if the estimator is fitted by verifying the presence of
859859
fitted attributes (ending with a trailing underscore) and otherwise
860860
raises a NotFittedError with the given message.
861861
862+
This utility is meant to be used internally by estimators themselves,
863+
typically in their own predict / transform methods.
864+
862865
Parameters
863866
----------
864867
estimator : estimator instance.
865868
estimator instance for which the check is performed.
866869
870+
attributes : str, list or tuple of str, default=None
871+
Attribute name(s) given as string or a list/tuple of strings
872+
E.g.: ``["coef_", "estimator_", ...]`` or ``"coef_"``.
873+
874+
If `None`, `estimator` is considered fitted if there exists an
875+
attribute that ends with an underscore and does not start with a double
876+
underscore.
877+
867878
msg : string
868879
The default error message is, "This %(name)s instance is not fitted
869880
yet. Call 'fit' with appropriate arguments before using this
@@ -874,6 +885,9 @@ def check_is_fitted(estimator, msg=None):
874885
875886
Eg. : "Estimator, %(name)s, must be fitted before sparsifying".
876887
888+
all_or_any : callable, {all, any}, default=all
889+
Specify whether all or any of the given attributes must exist.
890+
877891
Returns
878892
-------
879893
None
@@ -892,9 +906,13 @@ def check_is_fitted(estimator, msg=None):
892906
if not hasattr(estimator, 'fit'):
893907
raise TypeError("%s is not an estimator instance." % (estimator))
894908

895-
attrs = [v for v in vars(estimator)
896-
if (v.endswith("_") or v.startswith("_"))
897-
and not v.startswith("__")]
909+
if attributes is not None:
910+
if not isinstance(attributes, (list, tuple)):
911+
attributes = [attributes]
912+
attrs = all_or_any([hasattr(estimator, attr) for attr in attributes])
913+
else:
914+
attrs = [v for v in vars(estimator)
915+
if v.endswith("_") and not v.startswith("__")]
898916

899917
if not attrs:
900918
raise NotFittedError(msg % {'name': type(estimator).__name__})

0 commit comments

Comments
 (0)
0