10BC0 FIX fix pickling for empty object with Python 3.11+ (#25188) · scikit-learn/scikit-learn@9017c70 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9017c70

Browse files
FIX fix pickling for empty object with Python 3.11+ (#25188)
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Python 3.11 introduces `__getstate__` on the `object` level, which breaks our existing `__getstate__` code for objects w/o any attributes. This fixes the issue.
1 parent b0bf231 commit 9017c70

File tree

3 files changed

+63
-0
lines changed

3 files changed

+63
-0
lines changed

doc/whats_new/v1.2.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ Version 1.2.1
1212
Changelog
1313
---------
1414

15+
:mod:`sklearn.base`
16+
...................
17+
18+
- |Fix| Fix a regression in `BaseEstimator.__getstate__` that would prevent
19+
certain estimators to be pickled when using Python 3.11. :pr:`25188` by
20+
:user:`Benjamin Bossan <BenjaminBossan>`.
21+
1522
:mod:`sklearn.utils`
1623
....................
1724

sklearn/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,9 +271,20 @@ def __repr__(self, N_CHAR_MAX=700):
271271
return repr_
272272

273273
def __getstate__(self):
274+
if getattr(self, "__slots__", None):
275+
raise TypeError(
276+
"You cannot use `__slots__` in objects inheriting from "
277+
"`sklearn.base.BaseEstimator`."
278+
)
279+
274280
try:
275281
state = super().__getstate__()
282+
if state is None:
283+
# For Python 3.11+, empty instance (no `__slots__`,
284+
# and `__dict__`) will return a state equal to `None`.
285+
state = self.__dict__.copy()
276286
except AttributeError:
287+
# Python < 3.11
277288
state = self.__dict__.copy()
278289

279290
if type(self).__module__.startswith("sklearn."):

sklearn/tests/test_base.py

Lines changed: 45 additions & 0 10BC0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,3 +675,48 @@ def test_clone_keeps_output_config():
675675
ss_clone = clone(ss)
676676
config_clone = _get_output_config("transform", ss_clone)
677677
assert config == config_clone
678+
679+
680+
class _Empty:
681+
pass
682+
683+
684+
class EmptyEstimator(_Empty, BaseEstimator):
685+
pass
686+
687+
688+
@pytest.mark.parametrize("estimator", [BaseEstimator(), EmptyEstimator()])
689+
def test_estimator_empty_instance_dict(estimator):
690+
"""Check that ``__getstate__`` returns an empty ``dict`` with an empty
691+
instance.
692+
693+
Python 3.11+ changed behaviour by returning ``None`` instead of raising an
694+
``AttributeError``. Non-regression test for gh-25188.
695+
"""
696+
state = estimator.__getstate__()
697+
expected = {"_sklearn_version": sklearn.__version__}
698+
assert state == expected
699+
700+
# this should not raise
701+
pickle.loads(pickle.dumps(BaseEstimator()))
702+
703+
704+
def test_estimator_getstate_using_slots_error_message():
705+
"""Using a `BaseEstimator` with `__slots__` is not supported."""
706+
707+
class WithSlots:
708+
__slots__ = ("x",)
709+
710+
class Estimator(BaseEstimator, WithSlots):
711+
pass
712+
713+
msg = (
714+
"You cannot use `__slots__` in objects inheriting from "
715+
"`sklearn.base.BaseEstimator`"
716+
)
717+
718+
with pytest.raises(TypeError, match=msg):
719+
Estimator().__getstate__()
720+
721+
with pytest.raises(TypeError, match=msg):
722+
pickle.dumps(Estimator())

0 commit comments

Comments
 (0)
0