-
-
Notifications
You must be signed in to change notification settings - Fork 26k
[MRG+1] Fix pickling bug due to multiple inheritance & __getstate__ #8324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov Report
@@ Coverage Diff @@
## master #8324 +/- ##
==========================================
+ Coverage 94.75% 94.75% +<.01%
==========================================
Files 342 342
Lines 60816 60886 +70
==========================================
+ Hits 57624 57695 +71
+ Misses 3192 3191 -1
Continue to review full report at Codecov.
|
@@ -290,10 +290,11 @@ def __repr__(self): | |||
offset=len(class_name),),) | |||
|
|||
def __getstate__(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just wondering whether we should adopt a similar strategy for __setstate__
? Quickly looking at it, it seems like we are doing startswith('sklearn.')
there as well to avoid messing with classes outside scikit-learn deriving from sklearn.base.BaseEstimator
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I believe there should be parallel changes in __setstate__
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree
sklearn/base.py
Outdated
@@ -290,10 +290,11 @@ def __ 8000 repr__(self): | |||
offset=len(class_name),),) | |||
|
|||
def __getstate__(self): | |||
if type(self).__module__.startswith('sklearn.'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we still want this... we just want to delegate to super
instead of blindly getting self.__dict__.items()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will reintroduce the check
@@ -290,10 +290,11 @@ def __repr__(self): | |||
offset=len(class_name),),) | |||
|
|||
def __getstate__(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I believe there should be parallel changes in __setstate__
sklearn/base.py
Outdated
state = super(BaseEstimator, self).__getstate__() | ||
except AttributeError: | ||
state = self.__dict__ | ||
|
||
if type(self).__module__.startswith('sklearn.'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This conditional is not properly covered, because the classes I use in the test are from within the sklearn module namespace (since scikit-learn has its tests in the sklearn namespace). So basically return state.copy()
is never reached in the tests. Problem is, I cannot patch type(self).__module__
without breaking the pickling mechanism (it makes a lookup for the module). So either we make the string 'sklearn.'
in the BaseEstimator
mock-patchable, or we need to create a test-class outside of the sklearn namespace, to get this conditional fully covered.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My idea was to write something like this
def test_multiple_inheritance_setting_foreign_namespace(self):
try:
estimator = MultiInheritanceEstimator()
old_mod = type(estimator).__module__
type(estimator).__module__ = "notsklearn"
serialized = pickle.dumps(estimator, protocol=2)
finally:
type(estimator).__module__ = old_mod
which doesn't work for the aforementioned reason.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we already test something like this. You can either test get/set_state directly (instead of pickling) or hack the pickle loading, perhaps by hacking sys.modules
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 for testing getstate / setstate directly.
codecov seems to complain about a drop in coverage in Quickly looking at the diff it seems like you are not using |
Alright, I think I have incorporated your feedback in this PR. And all CI was green. I have now squashed the commits, so it makes for a nicer patch in the history and rebased stuff on the most recent master. Would you say it is mergeable? Anything else I need to address? |
You don't need to squash: github provides a "squash and merge" button. Also, unless there are merge conflicts, rebase is usually unnecessary (and even then, merging in the latest master suffices). However, you generally require two full reviews and "LGTM"s before merge. We have a long backlog of reviewing. Thanks for your patience. |
sklearn/base.py
Outdated
else: | ||
return dict(self.__dict__.items()) | ||
return state.copy() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
surely it should only be necessary to copy in the .__dict__
case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed, an object's state need not be a dictionary and this line will break given some other types.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in 13c74fc
try: | ||
super(BaseEstimator, self).__setstate__(state) | ||
except AttributeError: | ||
self.__dict__.update(state) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is okay, but FYI, pickle doesn't directly use update, in order to ensure all strings are interned: https://github.com/python/cpython/blob/master/Lib/pickle.py#L1522.
sklearn/tests/test_base.py
Outdated
return self._cache | ||
|
||
|
||
class TestPicklingConstraints(object): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- We usually don't use test classes, but I'm personally okay with this as a way of grouping the code.
- It might be good to mention that this test is about
BaseEstimator
somewhere (though I see there's a lack of that in this file).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Addressed in 0a55414
- since the tests are about estimators, isn't it (implicitly) clear, that
BaseEstimator
is involved. Also, the module beingsklearn.base
.
sklearn/tests/test_base.py
Outdated
|
||
def test_singleinheritance_clone(self): | ||
estimator = SingleInheritanceEstimator() | ||
assert estimator.cache |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's really hard to tell from the test code that this modifies __dict__
. I think I'd rather something more transparent than a cache. For example, the getstate could store a timestamp.
sklearn/tests/test_base.py
Outdated
|
||
serialized = pickle.dumps(estimator, protocol=2) | ||
estimator_restored = pickle.loads(serialized) | ||
assert estimator_restored.b == 5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I notice this test tests the basic restoration that should have also been tested in test_pickle_version_warning
which should really be checking that the loaded pickled trees can still predict. (You could add that if you wish.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added in 393ff64
sklearn/tests/test_base.py
Outdated
finally: | ||
type(estimator).__module__ = old_mod | ||
|
||
def test_uses_object_dictionary_when_getstate_not_present(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't get what you mean by "when_getstate_not_present". You're using a MultiInheritanceEstimator
for which __getstate__
is present twice in the MRO.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure what I meant by this test anymore, so I removed it in a94e4d1 without decreasing coverage
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. In addition to these nitpicks, you've not tested your modifications to __setstate__
sklearn/tests/test_base.py
Outdated
tree = TreeBadVersion().fit(iris.data, iris.target) | ||
tree_pickle_other = pickle.dumps(tree) | ||
message = ("Trying to unpickle estimator TreeBadVersion from " | ||
"version {0} when using version {1}. This might lead to " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's clearer to the maintainer if "something" is still formatted in
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hope a35832c is as you intend it to be. I really didn't like the replace
calls in the old tests. I assume you favour a global message template, two over repeating it in the tests. The other option would be to duplicate the template in the tests (which I think is probably the least optimal variant).
sklearn/tests/test_base.py
Outdated
self._cache = None | ||
|
||
@property | ||
def cache(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I still don't think this makes for the test being easily read, and would rather something more explicit than a property with side effects. Even if the test did estimator._cache = "some_value"
directly it would be better.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addressed in c965511, indeed this is better than the property
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few comments
sklearn/tests/test_base.py
Outdated
return data | ||
|
||
@property | ||
def cache(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are still not using cache anywhere, right? If so can you please remove it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed in 32cbb36
sklearn/tests/test_base.py
Outdated
|
||
serialized = pickle.dumps(estimator, protocol=2) | ||
estimator_restored = pickle.loads(serialized) | ||
assert estimator_restored.b == 5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using bare asserts with nose creates not so great error messages. This does matter on CIs. Can you please use assert_*
helpers in sklearn.utils.testing
in all your tests?
I know the migration to pytest is definitely on the radar, but still, I think this is the right thing to do. I'd be happy to hear different opinions.
An example of using bare asserts vs an assert_*
helper when running with nose:
======================================================================
FAIL: test_nose.test
----------------------------------------------------------------------
Traceback (most recent call last):
File "/volatile/le243287/miniconda3/lib/python3.5/site-packages/nose/case.py", line 198, in runTest
self.test(*self.arg)
File "/tmp/test_nose.py", line 7, in test
assert x > y
AssertionError
======================================================================
FAIL: test_nose.test2
----------------------------------------------------------------------
Traceback (most recent call last):
File "/volatile/le243287/miniconda3/lib/python3.5/site-packages/nose/case.py", line 198, in runTest
self.test(*self.arg)
File "/tmp/test_nose.py", line 13, in test2
assert_greater(x, y)
AssertionError: 2 not greater than 4
----------------------------------------------------------------------
Ran 2 tests in 0.002s
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see e1417ad
using py.test locally so I wasn't aware that this is an issue with nose :)
if type(self).__module__.startswith('sklearn.'): | ||
return dict(self.__dict__.items(), _sklearn_version=__version__) | ||
return dict(state.items(), _sklearn_version=__version__) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually thinking about it I am a bit confused about this, should we not always have a warning mechanism even if the estimator is outside scikit-learn? I could imagine someone inheriting from say LogisticRegression
with a minor modification and the warning applies to this case as well, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Firstly that is not a problem for this PR. Secondly, we do not issue the warning if out of sklearn because such an estimator is likely to be versioned differently.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Firstly that is not a problem for this PR.
Agreed.
Secondly, we do not issue the warning if out of sklearn because such an estimator is likely to be versioned differently.
OK I trust your judgement on this. The use case I had in mind was a thin wrapper around a scikit-learn estimator (deriving from an estimator class mostly for convenience), in which case it would make sense to get the warnings.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some small comments. LGTM otherwise.
sklearn/tests/test_base.py
Outdated
def test_pickle_version_no_warning_is_issued_with_non_sklearn_estimator(): | ||
iris = datasets.load_iris() | ||
tree = TreeNoVersion().fit(iris.data, iris.target) | ||
tree_pickle_noversion = pickle.dumps(tree) | ||
TreeNoVersion.__module__ = "notsklearn" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know it was like this before, but don't you need a try/finally
here too to make sure that TreeNoVersion.__module__
is set to its original value?
sklearn/tests/test_base.py
Outdated
estimator = MultiInheritanceEstimator() | ||
estimator._cache = "this attribute should not be pickled" | ||
|
||
serialized = pickle.dumps(estimator, protocol=2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why protocol=2? The tests pass fine with serialized = pickle.dump(estimator)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed in 7985b0b
sklearn/tests/test_base.py
Outdated
|
||
class MultiInheritanceEstimator(BaseEstimator, DontPickleCacheMixin): | ||
def __init__(self, b=5): | ||
self.b = b |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like the test could be made clearer by better variable naming, e.g. attr_pickled
and attr_not_pickled
. If you choose to do it, change it uniformly (and also maybe the naming of DontPickleCacheMixin).
assert_warns_message(UserWarning, message, pickle.loads, tree_pickle_other) | ||
|
||
# check that not including any version also works: | ||
|
||
class TreeNoVersion(DecisionTreeClassifier): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My 2c: moving things around like this adds unnecessary noise in the diff thus making it harder to review without any significant benefit.
LGTM, could you add an entry in doc/whats_new.rst? |
…ate__ Includes a reproducing test case. sklearn.base.BaseEstimator now tries to use other __getstate__ methods of the class hierarchy first, before defaulting to the __dict__ attribute
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, @jnothman do you want to have a look at this one so we can merge it? Two minor comments in the changelog.
doc/whats_new.rst
Outdated
@@ -220,6 +220,10 @@ Bug fixes | |||
- Fix a bug in cases where `numpy.cumsum` may be numerically unstable, | |||
raising an exception if instability is identified. :issue:`7376` and | |||
:issue:`7331` by `Joel Nothman`_ and :user:`yangarbiter`. | |||
- Fix a bug where :meth:`sklearn.base.BaseEstimator.__getstate__` blocked | |||
obstructed pickling customizations of child-classes, when used in a | |||
multiple inheritence context. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo: inheritance
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indeed, should be fixed with the last push.
doc/whats_new.rst
Outdated
@@ -220,6 +220,10 @@ Bug fixes | |||
- Fix a bug in cases where `numpy.cumsum` may be numerically unstable, | |||
raising an exception if instability is identified. :issue:`7376` and | |||
:issue:`7331` by `Joel Nothman`_ and :user:`yangarbiter`. | |||
- Fix a bug where :meth:`sklearn.base.BaseEstimator.__getstate__` blocked |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like you could not make up your mind between blocked and obstructed ;-)
LGTM thanks @HolgerPeters |
Great stuff @HolgerPeters, thanks a lot! |
…cikit-learn#8324) Fixes scikit-learn#8316 * Don't use test classes to group tests * only use formatting for parts of the string that change * Flake 8 column limit * Make the modification of the estimator more explicit in the tests * As suggested in code review, prefer formatting over two literals * Also assert, that __setstate__ overwriting works in mixin * Remove cache property * Use assertion functions from sklearn.utils.testing * remove the protocol argument in tests * Rename attributes to better convey their purpose * Revert change of module in TreeNoVersion * Adhere to column-limit * changelog entry * Fix commit message
…cikit-learn#8324) Fixes scikit-learn#8316 * Don't use test classes to group tests * only use formatting for parts of the string that change * Flake 8 column limit * Make the modification of the estimator more explicit in the tests * As suggested in code review, prefer formatting over two literals * Also assert, that __setstate__ overwriting works in mixin * Remove cache property * Use assertion functions from sklearn.utils.testing * remove the protocol argument in tests * Rename attributes to better convey their purpose * Revert change of module in TreeNoVersion * Adhere to column-limit * changelog entry * Fix commit message
…cikit-learn#8324) Fixes scikit-learn#8316 * Don't use test classes to group tests * only use formatting for parts of the string that change * Flake 8 column limit * Make the modification of the estimator more explicit in the tests * As suggested in code review, prefer formatting over two literals * Also assert, that __setstate__ overwriting works in mixin * Remove cache property * Use assertion functions from sklearn.utils.testing * remove the protocol argument in tests * Rename attributes to better convey their purpose * Revert change of module in TreeNoVersion * Adhere to column-limit * changelog entry * Fix commit message
…cikit-learn#8324) Fixes scikit-learn#8316 * Don't use test classes to group tests * only use formatting for parts of the string that change * Flake 8 column limit * Make the modification of the estimator more explicit in the tests * As suggested in code review, prefer formatting over two literals * Also assert, that __setstate__ overwriting works in mixin * Remove cache property * Use assertion functions from sklearn.utils.testing * remove the protocol argument in tests * Rename attributes to better convey their purpose * Revert change of module in TreeNoVersion * Adhere to column-limit * changelog entry * Fix commit message
…cikit-learn#8324) Fixes scikit-learn#8316 * Don't use test classes to group tests * only use formatting for parts of the string that change * Flake 8 column limit * Make the modification of the estimator more explicit in the tests * As suggested in code review, prefer formatting over two literals * Also assert, that __setstate__ overwriting works in mixin * Remove cache property * Use assertion functions from sklearn.utils.testing * remove the protocol argument in tests * Rename attributes to better convey their purpose * Revert change of module in TreeNoVersion * Adhere to column-limit * changelog entry * Fix commit message
…cikit-learn#8324) Fixes scikit-learn#8316 * Don't use test classes to group tests * only use formatting for parts of the string that change * Flake 8 column limit * Make the modification of the estimator more explicit in the tests * As suggested in code review, prefer formatting over two literals * Also assert, that __setstate__ overwriting works in mixin * Remove cache property * Use assertion functions from sklearn.utils.testing * remove the protocol argument in tests * Rename attributes to better convey their purpose * Revert change of module in TreeNoVersion * Adhere to column-limit * changelog entry * Fix commit message
Includes a reproducing test case.
sklearn.base.BaseEstimator now tries to use other
__getstate__
methods of the class hierarchy first, before defaulting to the__dict__
attributeReference Issue
Fix issue #8316