[MRG+1] Fix pickling bug due to multiple inheritance & __getstate__ … · scikit-learn/scikit-learn@4493d37 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4493d37

Browse files
HolgerPeters authored and jnothman committed
[MRG+1] Fix pickling bug due to multiple inheritance & __getstate__ (#8324)
Fixes #8316
* Don't use test classes to group tests
* Only use formatting for parts of the string that change
* Flake8 column limit
* Make the modification of the estimator more explicit in the tests
* As suggested in code review, prefer formatting over two literals
* Also assert that __setstate__ overwriting works in mixin
* Remove cache property
* Use assertion functions from sklearn.utils.testing
* Remove the protocol argument in tests
* Rename attributes to better convey their purpose
* Revert change of module in TreeNoVersion
* Adhere to column limit
* Changelog entry
* Fix commit message
1 parent fb65a0a commit 4493d37

File tree

3 files changed

+130
-28
lines changed

3 files changed

+130
-28
lines changed

doc/whats_new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,10 @@ Bug fixes
223223
- Fix a bug in cases where `numpy.cumsum` may be numerically unstable,
224224
raising an exception if instability is identified. :issue:`7376` and
225225
:issue:`7331` by `Joel Nothman`_ and :user:`yangarbiter`.
226+
- Fix a bug where :meth:`sklearn.base.BaseEstimator.__getstate__`
227+
obstructed pickling customizations of child classes when used in a
228+
multiple inheritance context.
229+
:issue:`8316` by :user:`Holger Peters <HolgerPeters>`.
226230

227231
API changes summary
228232
-------------------

sklearn/base.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,10 +290,15 @@ def __repr__(self):
290290
offset=len(class_name),),)
291291

292292
def __getstate__(self):
293+
try:
294+
state = super(BaseEstimator, self).__getstate__()
295+
except AttributeError:
296+
state = self.__dict__.copy()
297+
293298
if type(self).__module__.startswith('sklearn.'):
294-
return dict(self.__dict__.items(), _sklearn_version=__version__)
299+
return dict(state.items(), _sklearn_version=__version__)
295300
else:
296-
return dict(self.__dict__.items())
301+
return state
297302

298303
def __setstate__(self, state):
299304
if type(self).__module__.startswith('sklearn.'):
@@ -305,7 +310,11 @@ def __setstate__(self, state):
305310
"invalid results. Use at your own risk.".format(
306311
self.__class__.__name__, pickle_version, __version__),
307312
UserWarning)
308-
self.__dict__.update(state)
313+
try:
314+
super(BaseEstimator, self).__setstate__(state)
315+
except AttributeError:
316+
self.__dict__.update(state)
317+
309318

310319

311320
###############################################################################

sklearn/tests/test_base.py

Lines changed: 114 additions & 25 deletions
-
def test_pickle_version_warning():
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# Author: Gael Varoquaux
22
# License: BSD 3 clause
33

4-
import sys
5-
64
import numpy as np
75
import scipy.sparse as sp
86

@@ -15,6 +13,7 @@
1513
from sklearn.utils.testing import assert_raises
1614
from sklearn.utils.testing import assert_no_warnings
1715
from sklearn.utils.testing import assert_warns_message
16+
from sklearn.utils.testing import assert_dict_equal
1817

1918
from sklearn.base import BaseEstimator, clone, is_classifier
2019
from sklearn.svm import SVC
@@ -314,48 +313,138 @@ def transform(self, X, y=None):
314313
assert_equal(e.scalar_param, cloned_e.scalar_param)
315314

316315

317-
class TreeNoVersion(DecisionTreeClassifier):
318-
def __getstate__(self):
319-
return self.__dict__
316+
def test_pickle_version_warning_is_not_raised_with_matching_version():
317+
iris = datasets.load_iris()
318+
tree = DecisionTreeClassifier().fit(iris.data, iris.target)
319+
tree_pickle = pickle.dumps(tree)
320+
assert_true(b"version" in tree_pickle)
321+
tree_restored = assert_no_warnings(pickle.loads, tree_pickle)
322+
323+
# test that we can predict with the restored decision tree classifier
324+
score_of_original = tree.score(iris.data, iris.target)
325+
score_of_restored = tree_restored.score(iris.data, iris.target)
326+
assert_equal(score_of_original, score_of_restored)
320327

321328

322329
class TreeBadVersion(DecisionTreeClassifier):
323330
def __getstate__(self):
324331
return dict(self.__dict__.items(), _sklearn_version="something")
325332

326333

327
328-
# check that warnings are raised when unpickling in a different version
334+
pickle_error_message = (
335+
"Trying to unpickle estimator {estimator} from "
336+
"version {old_version} when using version "
337+
"{current_version}. This might "
338+
"lead to breaking code or invalid results. "
339+
"Use at your own risk.")
329340

330-
# first, check no warning when in the same version:
331-
iris = datasets.load_iris()
332-
tree = DecisionTreeClassifier().fit(iris.data, iris.target)
333-
tree_pickle = pickle.dumps(tree)
334-
assert_true(b"version" in tree_pickle)
335-
assert_no_warnings(pickle.loads, tree_pickle)
336341

337-
# check that warning is raised on different version
342+
def test_pickle_version_warning_is_issued_upon_different_version():
343+
iris = datasets.load_iris()
338344
tree = TreeBadVersion().fit(iris.data, iris.target)
339345
tree_pickle_other = pickle.dumps(tree)
340-
message = ("Trying to unpickle estimator TreeBadVersion from "
341-
"version {0} when using version {1}. This might lead to "
342-
"breaking code or invalid results. "
343-
"Use at your own risk.".format("something",
344-
sklearn.__version__))
346+
message = pickle_error_message.format(estimator="TreeBadVersion",
347+
old_version="something",
348+
current_version=sklearn.__version__)
345349
assert_warns_message(UserWarning, message, pickle.loads, tree_pickle_other)
346350

347-
# check that not including any version also works:
351+
352+
class TreeNoVersion(DecisionTreeClassifier):
353+
def __getstate__(self):
354+
return self.__dict__
355+
356+
357+
def test_pickle_version_warning_is_issued_when_no_version_info_in_pickle():
358+
iris = datasets.load_iris()
348359
# TreeNoVersion's __getstate__ omits the version entry, like pre-0.18 pickles
349360
tree = TreeNoVersion().fit(iris.data, iris.target)
350361

351362
tree_pickle_noversion = pickle.dumps(tree)
352363
assert_false(b"version" in tree_pickle_noversion)
353-
message = message.replace("something", "pre-0.18")
354-
message = message.replace("TreeBadVersion", "TreeNoVersion")
364+
message = pickle_error_message.format(estimator="TreeNoVersion",
365+
old_version="pre-0.18",
366+
current_version=sklearn.__version__)
355367
# check we got the warning about using pre-0.18 pickle
356368
assert_warns_message(UserWarning, message, pickle.loads,
357369
tree_pickle_noversion)
358370

359-
# check that no warning is raised for external estimators
360-
TreeNoVersion.__module__ = "notsklearn"
361-
assert_no_warnings(pickle.loads, tree_pickle_noversion)
371+
372+
def test_pickle_version_no_warning_is_issued_with_non_sklearn_estimator():
373+
iris = datasets.load_iris()
374+
tree = TreeNoVersion().fit(iris.data, iris.target)
375+
tree_pickle_noversion = pickle.dumps(tree)
376+
try:
377+
module_backup = TreeNoVersion.__module__
378+
TreeNoVersion.__module__ = "notsklearn"
379+
assert_no_warnings(pickle.loads, tree_pickle_noversion)
380+
finally:
381+
TreeNoVersion.__module__ = module_backup
382+
383+
384+
class DontPickleAttributeMixin(object):
385+
def __getstate__(self):
386+
data = self.__dict__.copy()
387+
data["_attribute_not_pickled"] = None
388+
return data
389+
390+
def __setstate__(self, state):
391+
state["_restored"] = True
392+
self.__dict__.update(state)
393+
394+
395+
class MultiInheritanceEstimator(BaseEstimator, DontPickleAttributeMixin):
396+
def __init__(self, attribute_pickled=5):
397+
self.attribute_pickled = attribute_pickled
398+
self._attribute_not_pickled = None
399+
400+
401+
def test_pickling_when_getstate_is_overwritten_by_mixin():
402+
estimator = MultiInheritanceEstimator()
403+
estimator._attribute_not_pickled = "this attribute should not be pickled"
404+
405+
serialized = pickle.dumps(estimator)
406+
estimator_restored = pickle.loads(serialized)
407+
assert_equal(estimator_restored.attribute_pickled, 5)
408+
assert_equal(estimator_restored._attribute_not_pickled, None)
409+
assert_true(estimator_restored._restored)
410+
411+
412+
def test_pickling_when_getstate_is_overwritten_by_mixin_outside_of_sklearn():
413+
try:
414+
estimator = MultiInheritanceEstimator()
415+
text = "this attribute should not be pickled"
416+
estimator._attribute_not_pickled = text
417+
old_mod = type(estimator).__module__
418+
type(estimator).__module__ = "notsklearn"
419+
420+
serialized = estimator.__getstate__()
421+
assert_dict_equal(serialized, {'_attribute_not_pickled': None,
422+
'attribute_pickled': 5})
423+
424+
serialized['attribute_pickled'] = 4
425+
estimator.__setstate__(serialized)
426+
assert_equal(estimator.attribute_pickled, 4)
427+
assert_true(estimator._restored)
428+
finally:
429+
type(estimator).__module__ = old_mod
430+
431+
432+
class SingleInheritanceEstimator(BaseEstimator):
433+
def __init__(self, attribute_pickled=5):
434+
self.attribute_pickled = attribute_pickled
435+
self._attribute_not_pickled = None
436+
437+
def __getstate__(self):
438+
data = self.__dict__.copy()
439+
data["_attribute_not_pickled"] = None
440+
return data
441+
442+
443+
def test_pickling_works_when_getstate_is_overwritten_in_the_child_class():
444+
estimator = SingleInheritanceEstimator()
445+
estimator._attribute_not_pickled = "this attribute should not be pickled"
446+
447+
serialized = pickle.dumps(estimator)
448+
estimator_restored = pickle.loads(serialized)
449+
assert_equal(estimator_restored.attribute_pickled, 5)
450+
assert_equal(estimator_restored._attribute_not_pickled, None)

0 commit comments

Comments
 (0)
0