diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 1c1fb5c1c07a8..76bf1024a232c 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -275,6 +275,12 @@ Changelog during `transform` with no prior call to `fit` or `fit_transform`. :pr:`25190` by :user:`Vincent Maladière `. +- |API| A `FutureWarning` is now raised when instantiating a class which inherits from + a deprecated base class (i.e. decorated by :class:`utils.deprecated`) and which + overrides the `__init__` method. + :pr:`25733` by :user:`Brigitta Sipőcz ` and + :user:`Jérémie du Boisberranger `. + :mod:`sklearn.semi_supervised` .............................. diff --git a/sklearn/tests/test_docstring_parameters.py b/sklearn/tests/test_docstring_parameters.py index 274f83445ee7f..8bf3e5dd7b24a 100644 --- a/sklearn/tests/test_docstring_parameters.py +++ b/sklearn/tests/test_docstring_parameters.py @@ -109,12 +109,11 @@ def test_docstring_parameters(): "Error for __init__ of %s in %s:\n%s" % (cls, name, w[0]) ) - cls_init = getattr(cls, "__init__", None) - - if _is_deprecated(cls_init): + # Skip checks on deprecated classes + if _is_deprecated(cls.__new__): continue - elif cls_init is not None: - this_incorrect += check_docstring_parameters(cls.__init__, cdoc) + + this_incorrect += check_docstring_parameters(cls.__init__, cdoc) for method_name in cdoc.methods: method = getattr(cls, method_name) diff --git a/sklearn/utils/deprecation.py b/sklearn/utils/deprecation.py index 19d41aa1eaf85..a5a70ed699197 100644 --- a/sklearn/utils/deprecation.py +++ b/sklearn/utils/deprecation.py @@ -60,17 +60,18 @@ def _decorate_class(self, cls): if self.extra: msg += "; %s" % self.extra - # FIXME: we should probably reset __new__ for full generality - init = cls.__init__ + new = cls.__new__ - def wrapped(*args, **kwargs): + def wrapped(cls, *args, **kwargs): warnings.warn(msg, category=FutureWarning) - return init(*args, **kwargs) + if new is object.__new__: + return object.__new__(cls) + return new(cls, *args, **kwargs) - cls.__init__ = wrapped + cls.__new__ = wrapped - wrapped.__name__ = "__init__" - wrapped.deprecated_original = init + wrapped.__name__ = "__new__" + wrapped.deprecated_original = new return cls diff --git a/sklearn/utils/tests/test_deprecation.py b/sklearn/utils/tests/test_deprecation.py index b810cfb85d3f6..98c69a8abb780 100644 --- a/sklearn/utils/tests/test_deprecation.py +++ b/sklearn/utils/tests/test_deprecation.py @@ -36,6 +36,22 @@ class MockClass4: pass +class MockClass5(MockClass1): + """Inherit from deprecated class but does not call super().__init__.""" + + def __init__(self, a): + self.a = a + + +@deprecated("a message") +class MockClass6: + """A deprecated class that overrides __new__.""" + + def __new__(cls, *args, **kwargs): + assert len(args) > 0 + return super().__new__(cls) + + @deprecated() def mock_function(): return 10 @@ -48,6 +64,10 @@ def test_deprecated(): MockClass2().method() with pytest.warns(FutureWarning, match="deprecated"): MockClass3() + with pytest.warns(FutureWarning, match="qwerty"): + MockClass5(42) + with pytest.warns(FutureWarning, match="a message"): + MockClass6(42) with pytest.warns(FutureWarning, match="deprecated"): val = mock_function() assert val == 10 @@ -56,10 +76,11 @@ def test_deprecated(): def test_is_deprecated(): # Test if _is_deprecated helper identifies wrapping via deprecated # NOTE it works only for class methods and functions - assert _is_deprecated(MockClass1.__init__) + assert _is_deprecated(MockClass1.__new__) assert _is_deprecated(MockClass2().method) assert _is_deprecated(MockClass3.__init__) assert not _is_deprecated(MockClass4.__init__) + assert _is_deprecated(MockClass5.__new__) assert _is_deprecated(mock_function)