BUG Adds attributes back to check_is_fitted by thomasjpfan · Pull Request #15947 · scikit-learn/scikit-learn

Merged: 14 commits, Dec 27, 2019
8 changes: 8 additions & 0 deletions doc/whats_new/v0.22.rst
@@ -60,6 +60,14 @@ Changelog
- |Fix| :func:`utils.check_array` now correctly converts pandas DataFrame with
  boolean columns to floats. :pr:`15797` by `Thomas Fan`_.

- |Fix| :func:`utils.check_is_fitted` accepts back an explicit ``attributes``
  argument to check for specific attributes as explicit markers of a fitted
Comment (Member): attributes and all_or_any

Comment (Member):
> The ``attributes`` argument to check for specific attributes.

I think that's correct syntax highlighting; there should be no need to write each occurrence of attributes as attribute.

Comment (Member): OK, I see your point, added a mention about all_or_any to what's new.

  estimator. When no explicit ``attributes`` are provided, only the attributes
  ending with a single "_" are used as "fitted" markers. The ``all_or_any``
Comment (Member): this should be "that ends with an underscore and does not start with a double underscore". I'll push directly.

  argument is also no longer deprecated. This change is made to
  restore some backward compatibility with the behavior of this utility in
  version 0.21. :pr:`15947` by `Thomas Fan`_.

.. _changes_0_22:

Version 0.22.0
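To make the restored behavior concrete, a minimal usage sketch (not part of the diff; the choice of estimator is illustrative, assuming scikit-learn >= 0.22.1):

from sklearn.linear_model import LinearRegression
from sklearn.utils.validation import check_is_fitted

reg = LinearRegression().fit([[0.0], [1.0]], [0.0, 1.0])

# Explicit markers: with the default all_or_any=all, every listed
# attribute must be present.
check_is_fitted(reg, attributes=["coef_", "intercept_"])

# With all_or_any=any, a single matching attribute is enough.
check_is_fitted(reg, attributes=["coef_", "does_not_exist_"], all_or_any=any)

# Without explicit attributes, any attribute that ends with a single "_"
# (and does not start with "__") marks the estimator as fitted.
check_is_fitted(reg)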
1 change: 1 addition & 0 deletions sklearn/feature_extraction/tests/test_text.py
@@ -1099,6 +1099,7 @@ def test_vectorizer_string_object_as_input(Vectorizer):
    assert_raise_message(
        ValueError, message, vec.fit_transform, "hello world!")
    assert_raise_message(ValueError, message, vec.fit, "hello world!")
    vec.fit(["some text", "some other text"])
    assert_raise_message(ValueError, message, vec.transform, "hello world!")


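Presumably the new `vec.fit(...)` call is needed because, with the stricter fitted check, calling `transform` on an unfitted vectorizer would now raise `NotFittedError` before reaching the string-input `ValueError` this test asserts; fitting first keeps the assertion on the intended error.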
8 changes: 6 additions & 2 deletions sklearn/feature_extraction/text.py
@@ -1498,7 +1498,11 @@ def transform(self, X, copy=True):
            X.data += 1

        if self.use_idf:
-            check_is_fitted(self, msg='idf vector is not fitted')
+            # idf_ being a property, the automatic attributes detection
+            # does not work as usual and we need to specify the attribute
+            # name:
+            check_is_fitted(self, attributes=["idf_"],
+                            msg='idf vector is not fitted')

            expected_n_features = self._idf_diag.shape[0]
            if n_features != expected_n_features:
@@ -1883,7 +1887,7 @@ def transform(self, raw_documents, copy="deprecated"):
        X : sparse matrix, [n_samples, n_features]
            Tf-idf-weighted document-term matrix.
        """
-        check_is_fitted(self, msg='The tfidf vector is not fitted')
+        check_is_fitted(self, msg='The TF-IDF vectorizer is not fitted')

        # FIXME Remove copy parameter support in 0.24
        if copy != "deprecated":
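The comment in the hunk above deserves a closer look: the automatic detection scans `vars(estimator)` (the instance `__dict__`), while a property such as `idf_` is defined on the class and never appears there; the underlying `_idf_diag` does not end with a single `_` either. A minimal sketch of the situation (the class is a simplified stand-in, not the real TfidfTransformer):

class TfidfLike:
    def fit(self):
        self._idf_diag = [1.0, 2.0]   # the actual fitted state
        return self

    @property
    def idf_(self):                    # derived from _idf_diag on access
        return self._idf_diag

est = TfidfLike().fit()
print(vars(est))              # {'_idf_diag': [1.0, 2.0]} -- no 'idf_' key
print(hasattr(est, "idf_"))   # True: this is what attributes=["idf_"] checks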
49 changes: 47 additions & 2 deletions sklearn/utils/tests/test_validation.py
@@ -5,6 +5,7 @@

from tempfile import NamedTemporaryFile
from itertools import product
from operator import itemgetter

import pytest
from pytest import importorskip
@@ -14,7 +15,6 @@
from sklearn.utils._testing import assert_raises
from sklearn.utils._testing import assert_raises_regex
from sklearn.utils._testing import assert_no_warnings
-from sklearn.utils._testing import assert_warns_message
from sklearn.utils._testing import assert_warns
from sklearn.utils._testing import ignore_warnings
from sklearn.utils._testing import SkipTest
@@ -50,7 +50,6 @@
import sklearn

from sklearn.exceptions import NotFittedError, PositiveSpectrumWarning
-from sklearn.exceptions import DataConversionWarning

from sklearn.utils._testing import assert_raise_message
from sklearn.utils._testing import TempMemmap
@@ -678,6 +677,52 @@ def test_check_is_fitted():
    assert check_is_fitted(svr) is None


def test_check_is_fitted_attributes():
    class MyEstimator():
        def fit(self, X, y):
            return self

    msg = "not fitted"
    est = MyEstimator()

    with pytest.raises(NotFittedError, match=msg):
        check_is_fitted(est, attributes=["a_", "b_"])
    with pytest.raises(NotFittedError, match=msg):
        check_is_fitted(est, attributes=["a_", "b_"], all_or_any=all)
    with pytest.raises(NotFittedError, match=msg):
        check_is_fitted(est, attributes=["a_", "b_"], all_or_any=any)

    est.a_ = "a"
    with pytest.raises(NotFittedError, match=msg):
        check_is_fitted(est, attributes=["a_", "b_"])
    with pytest.raises(NotFittedError, match=msg):
        check_is_fitted(est, attributes=["a_", "b_"], all_or_any=all)
    check_is_fitted(est, attributes=["a_", "b_"], all_or_any=any)

    est.b_ = "b"
    check_is_fitted(est, attributes=["a_", "b_"])
    check_is_fitted(est, attributes=["a_", "b_"], all_or_any=all)
    check_is_fitted(est, attributes=["a_", "b_"], all_or_any=any)


@pytest.mark.parametrize("wrap",
                         [itemgetter(0), list, tuple],
                         ids=["single", "list", "tuple"])
def test_check_is_fitted_with_attributes(wrap):
    ard = ARDRegression()
    with pytest.raises(NotFittedError, match="is not fitted yet"):
        check_is_fitted(ard, wrap(["coef_"]))

    ard.fit(*make_blobs())

    # Does not raise
    check_is_fitted(ard, wrap(["coef_"]))

    # Raises when using attribute that is not defined
    with pytest.raises(NotFittedError, match="is not fitted yet"):
        check_is_fitted(ard, wrap(["coef_bad_"]))


def test_check_consistent_length():
    check_consistent_length([1], [2], [3], [4], [5])
    check_consistent_length([[1, 2], [[1, 2]]], [1, 2], ['a', 'b'])
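A note on the `wrap` parametrization above: `itemgetter(0)` unwraps the one-element list, so the "single" case hands `check_is_fitted` a bare string and exercises the branch that wraps a non-list/tuple value. Roughly:

from operator import itemgetter

itemgetter(0)(["coef_"])   # -> "coef_"       (the "single" case)
list(["coef_"])            # -> ["coef_"]     (the "list" case)
tuple(["coef_"])           # -> ("coef_",)    (the "tuple" case)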
26 changes: 22 additions & 4 deletions sklearn/utils/validation.py
@@ -852,18 +852,29 @@ def check_symmetric(array, tol=1E-10, raise_warning=True,
    return array


-def check_is_fitted(estimator, msg=None):
+def check_is_fitted(estimator, attributes=None, msg=None, all_or_any=all):
"""Perform is_fitted validation for estimator.

Checks if the estimator is fitted by verifying the presence of
fitted attributes (ending with a trailing underscore) and otherwise
raises a NotFittedError with the given message.

This utility is meant to be used internally by estimators themselves,
typically in their own predict / transform methods.

Parameters
----------
estimator : estimator instance.
estimator instance for which the check is performed.

attributes : str, list or tuple of str, default=None
Attribute name(s) given as string or a list/tuple of strings
Eg.: ``["coef_", "estimator_", ...], "coef_"``

If `None`, `estimator` is considered fitted if there exist an
attribute that ends with a underscore and does not start with double
underscore.

msg : string
The default error message is, "This %(name)s instance is not fitted
yet. Call 'fit' with appropriate arguments before using this
@@ -874,6 +885,9 @@ def check_is_fitted(estimator, msg=None):

        Eg. : "Estimator, %(name)s, must be fitted before sparsifying".

    all_or_any : callable, {all, any}, default all
        Specify whether all or any of the given attributes must exist.

    Returns
    -------
    None
@@ -892,9 +906,13 @@
    if not hasattr(estimator, 'fit'):
        raise TypeError("%s is not an estimator instance." % (estimator))

-    attrs = [v for v in vars(estimator)
-             if (v.endswith("_") or v.startswith("_"))
-             and not v.startswith("__")]
+    if attributes is not None:
+        if not isinstance(attributes, (list, tuple)):
+            attributes = [attributes]
+        attrs = all_or_any([hasattr(estimator, attr) for attr in attributes])
+    else:
+        attrs = [v for v in vars(estimator)
+                 if v.endswith("_") and not v.startswith("__")]
Comment (Member): why do we keep `not v.startswith("__")`?

Comment (Member): Because all dunder attributes end with `_` and are not fit parameters.

Comment (Member): And more generally, anything that starts with `__` is weird / reserved in Python, so scikit-learn attributes should never start with `__`.

Comment (Member):
> And more generally, anything that starts with `__` is weird / reserved in Python, so scikit-learn attributes should never start with `__`.

You're right, but perhaps it's better to remove it, because we do not forbid users to create attributes that start with `__`.

Comment (Member): If we remove it we will have false positives.

Comment (Member): perhaps a clearer solution is `if v.endswith("_") and not v.endswith("__")`? @jnothman

Comment (Member): +1 to keep excluding based on `v.startswith("__")`. I have seen users use private variables in `__init__` in their estimators; it's not good, but we shouldn't count those as fit attributes either.

Comment (Member):
> I have seen users use private variables in `__init__` in their estimators; it's not good, but we shouldn't count those as fit attributes either.

but @rth we're talking about double underscore, not single underscore (private variables).

Comment (Member): Anyway I guess this is not so important.

Comment (Member):
> but @rth we're talking about double underscore, not single underscore (private variables).

I mean technically private variables as in `self.__a`; not very frequent, but I have seen it. Anyway, yes, it's not too critical either way. I just wanted to merge this quickly to fix CI on scikit-learn-extra. We can always come back to these details at a later time if needed.
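For context on the `self.__a` point: names with a `__` prefix assigned inside a class body are name-mangled, so in practice they reach `vars()` with a single leading underscore. A quick illustration (not part of the diff; the class is hypothetical):

class MyEstimator:
    def __init__(self):
        self.__cache = {}    # name-mangled: stored as '_MyEstimator__cache'

print(vars(MyEstimator()))   # {'_MyEstimator__cache': {}}

Only attributes assigned with a literal `__` prefix from outside a class body (or injected via `__dict__`) actually start with `__` in `vars()`.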


    if not attrs:
        raise NotFittedError(msg % {'name': type(estimator).__name__})
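Putting the default detection rule together, a minimal sketch of how the heuristic and the `__` exclusion discussed above behave (the `Demo` class is hypothetical, written against scikit-learn >= 0.22.1):

from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

class Demo:
    def fit(self):
        self.coef_ = [1.0]        # trailing "_": counts as a fitted marker
        return self

est = Demo()
est.__lookup_table_ = {}          # literal "__" prefix: ignored by the scan

try:
    check_is_fitted(est)          # raises: no qualifying fitted attribute
except NotFittedError as exc:
    print(exc)

check_is_fitted(est.fit())        # passes once coef_ is set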