8000 [MRG+1] Fix pickling bug due to multiple inheritance & __getstate__ by HolgerPeters · Pull Request #8324 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG+1] Fix pickling bug due to multiple inheritance & __getstate__ #8324

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Feb 20, 2017

Conversation

HolgerPeters
Copy link
Contributor
@HolgerPeters HolgerPeters commented Feb 9, 2017

Includes a reproducing test case.

sklearn.base.BaseEstimator now tries to use other __getstate__ methods of the class hierarchy first, before defaulting to the __dict__ attribute

Reference Issue

Fix issue #8316

@codecov
Copy link
codecov bot commented Feb 9, 2017

Codecov Report

Merging #8324 into master will increase coverage by <.01%.
The diff coverage is 100%.

@@            Coverage Diff             @@
##           master    #8324      +/-   ##
==========================================
+ Coverage   94.75%   94.75%   +<.01%     
==========================================
  Files         342      342              
  Lines       60816    60886      +70     
==========================================
+ Hits        57624    57695      +71     
+ Misses       3192     3191       -1
Impacted Files Coverage Δ
sklearn/base.py 94.94% <100%> (+0.79%)
sklearn/tests/test_base.py 97.6% <100%> (+0.8%)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update ba8771f...934efaa. Read the comment docs.

@@ -290,10 +290,11 @@ def __repr__(self):
offset=len(class_name),),)

def __getstate__(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering whether we should adopt a similar strategy for __setstate__? Quickly looking at it, it seems like we are doing startswith('sklearn.') there as well to avoid messing with classes outside scikit-learn deriving from sklearn.base.BaseEstimator.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I believe there should be parallel changes in __setstate__

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree

sklearn/base.py Outdated
@@ -290,10 +290,11 @@ def __ 8000 repr__(self):
offset=len(class_name),),)

def __getstate__(self):
if type(self).__module__.startswith('sklearn.'):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we still want this... we just want to delegate to super instead of blindly getting self.__dict__.items().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will reintroduce the check

@@ -290,10 +290,11 @@ def __repr__(self):
offset=len(class_name),),)

def __getstate__(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I believe there should be parallel changes in __setstate__

sklearn/base.py Outdated
state = super(BaseEstimator, self).__getstate__()
except AttributeError:
state = self.__dict__

if type(self).__module__.startswith('sklearn.'):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This conditional is not properly covered, because the classes I use in the test are from within the sklearn module namespace (since scikit-learn has its tests in the sklearn namespace). So basically return state.copy() is never reached in the tests. Problem is, I cannot patch type(self).__module__ without breaking the pickling mechanism (it makes a lookup for the module). So either we make the string 'sklearn.' in the BaseEstimator mock-patchable, or we need to create a test-class outside of the sklearn namespace, to get this conditional fully covered.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My idea was to write something like this

    def test_multiple_inheritance_setting_foreign_namespace(self):
        try:
            estimator = MultiInheritanceEstimator()
            old_mod = type(estimator).__module__
            type(estimator).__module__ = "notsklearn"

            serialized = pickle.dumps(estimator, protocol=2)
        finally:
            type(estimator).__module__ = old_mod

which doesn't work for the aforementioned reason.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we already test something like this. You can either test get/set_state directly (instead of pickling) or hack the pickle loading, perhaps by hacking sys.modules.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for testing getstate / setstate directly.

@HolgerPeters
Copy link
Contributor Author

Not quite sure why codecov reports a reduction in coverage. It seems all code paths are now fully tested.

screen shot 2017-02-09 at 15 25 22

@lesteve
Copy link
Member
lesteve commented Feb 9, 2017

Not quite sure why codecov reports a reduction in coverage

codecov seems to complain about a drop in coverage in sklearn/test/test_base.py looking at #8324 (comment).

Quickly looking at the diff it seems like you are not using MultiInheritanceEstimator.cache and SingleInheritanceEstimator.

@HolgerPeters
Copy link
Contributor Author

Alright, I think I have incorporated your feedback in this PR. And all CI was green. I have now squashed the commits, so it makes for a nicer patch in the history and rebased stuff on the most recent master. Would you say it is mergeable? Anything else I need to address?

@jnothman
Copy link
Member

You don't need to squash: github provides a "squash and merge" button. Also, unless there are merge conflicts, rebase is usually unnecessary (and even then, merging in the latest master suffices).

However, you generally require two full reviews and "LGTM"s before merge. We have a long backlog of reviewing. Thanks for your patience.

sklearn/base.py Outdated
else:
return dict(self.__dict__.items())
return state.copy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

surely it should only be necessary to copy in the .__dict__ case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, an object's state need not be a dictionary and this line will break given some other types.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 13c74fc

try:
super(BaseEstimator, self).__setstate__(state)
except AttributeError:
self.__dict__.update(state)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is okay, but FYI, pickle doesn't directly use update, in order to ensure all strings are interned: https://github.com/python/cpython/blob/master/Lib/pickle.py#L1522.

return self._cache


class TestPicklingConstraints(object):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. We usually don't use test classes, but I'm personally okay with this as a way of grouping the code.
  2. It might be good to mention that this test is about BaseEstimator somewhere (though I see there's a lack of that in this file).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Addressed in 0a55414
  2. since the tests are about estimators, isn't it (implicitly) clear, that BaseEstimator is involved. Also, the module being sklearn.base.


def test_singleinheritance_clone(self):
estimator = SingleInheritanceEstimator()
assert estimator.cache
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's really hard to tell from the test code that this modifies __dict__. I think I'd rather something more transparent than a cache. For example, the getstate could store a timestamp.


serialized = pickle.dumps(estimator, protocol=2)
estimator_restored = pickle.loads(serialized)
assert estimator_restored.b == 5
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I notice this test tests the basic restoration that should have also been tested in test_pickle_version_warning which should really be checking that the loaded pickled trees can still predict. (You could add that if you wish.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in 393ff64

finally:
type(estimator).__module__ = old_mod

def test_uses_object_dictionary_when_getstate_not_present(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get what you mean by "when_getstate_not_present". You're using a MultiInheritanceEstimator for which __getstate__ is present twice in the MRO.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure what I meant by this test anymore, so I removed it in a94e4d1 without decreasing coverage

Copy link
Member
@jnothman jnothman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. In addition to these nitpicks, you've not tested your modifications to __setstate__

tree = TreeBadVersion().fit(iris.data, iris.target)
tree_pickle_other = pickle.dumps(tree)
message = ("Trying to unpickle estimator TreeBadVersion from "
"version {0} when using version {1}. This might lead to "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's clearer to the maintainer if "something" is still formatted in

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hope a35832c is as you intend it to be. I really didn't like the replace calls in the old tests. I assume you favour a global message template, two over repeating it in the tests. The other option would be to duplicate the template in the tests (which I think is probably the least optimal variant).

self._cache = None

@property
def cache(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't think this makes for the test being easily read, and would rather something more explicit than a property with side effects. Even if the test did estimator._cache = "some_value" directly it would be better.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in c965511, indeed this is better than the property

@jnothman jnothman added the Bug label Feb 15, 2017
Copy link
Member
@lesteve lesteve left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few comments

return data

@property
def cache(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are still not using cache anywhere, right? If so can you please remove it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in 32cbb36


serialized = pickle.dumps(estimator, protocol=2)
estimator_restored = pickle.loads(serialized)
assert estimator_restored.b == 5
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using bare asserts with nose creates not so great error messages. This does matter on CIs. Can you please use assert_* helpers in sklearn.utils.testing in all your tests?

I know the migration to pytest is definitely on the radar, but still, I think this is the right thing to do. I'd be happy to hear different opinions.

An example of using bare asserts vs an assert_* helper when running with nose:

======================================================================
FAIL: test_nose.test
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/volatile/le243287/miniconda3/lib/python3.5/site-packages/nose/case.py", line 198, in runTest
    self.test(*self.arg)
  File "/tmp/test_nose.py", line 7, in test
    assert x > y
AssertionError

======================================================================
FAIL: test_nose.test2
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/volatile/le243287/miniconda3/lib/python3.5/site-packages/nose/case.py", line 198, in runTest
    self.test(*self.arg)
  File "/tmp/test_nose.py", line 13, in test2
    assert_greater(x, y)
AssertionError: 2 not greater than 4

----------------------------------------------------------------------
Ran 2 tests in 0.002s

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see e1417ad

using py.test locally so I wasn't aware that this is an issue with nose :)

if type(self).__module__.startswith('sklearn.'):
return dict(self.__dict__.items(), _sklearn_version=__version__)
return dict(state.items(), _sklearn_version=__version__)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually thinking about it I am a bit confused about this, should we not always have a warning mechanism even if the estimator is outside scikit-learn? I could imagine someone inheriting from say LogisticRegression with a minor modification and the warning applies to this case as well, right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Firstly that is not a problem for this PR. Secondly, we do not issue the warning if out of sklearn because such an estimator is likely to be versioned differently.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Firstly that is not a problem for this PR.

Agreed.

Secondly, we do not issue the warning if out of sklearn because such an estimator is likely to be versioned differently.

OK I trust your judgement on this. The use case I had in mind was a thin wrapper around a scikit-learn estimator (deriving from an estimator class mostly for convenience), in which case it would make sense to get the warnings.

Copy link
Member
@lesteve lesteve left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some small comments. LGTM otherwise.

def test_pickle_version_no_warning_is_issued_with_non_sklearn_estimator():
iris = datasets.load_iris()
tree = TreeNoVersion().fit(iris.data, iris.target)
tree_pickle_noversion = pickle.dumps(tree)
TreeNoVersion.__module__ = "notsklearn"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know it was like this before, but don't you need a try/finally here too to make sure that TreeNoVersion.__module__ is set to its original value?

estimator = MultiInheritanceEstimator()
estimator._cache = "this attribute should not be pickled"

serialized = pickle.dumps(estimator, protocol=2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why protocol=2? The tests pass fine with serialized = pickle.dump(estimator).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed in 7985b0b


class MultiInheritanceEstimator(BaseEstimator, DontPickleCacheMixin):
def __init__(self, b=5):
self.b = b
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like the test could be made clearer by better variable naming, e.g. attr_pickled and attr_not_pickled. If you choose to do it, change it uniformly (and also maybe the naming of DontPickleCacheMixin).

assert_warns_message(UserWarning, message, pickle.loads, tree_pickle_other)

# check that not including any version also works:

class TreeNoVersion(DecisionTreeClassifier):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My 2c: moving things around like this adds unnecessary noise in the diff thus making it harder to review without any significant benefit.

@lesteve
Copy link
Member
lesteve commented Feb 20, 2017

LGTM, could you add an entry in doc/whats_new.rst?

Copy link
Member
@lesteve lesteve left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, @jnothman do you want to have a look at this one so we can merge it? Two minor comments in the changelog.

@@ -220,6 +220,10 @@ Bug fixes
- Fix a bug in cases where `numpy.cumsum` may be numerically unstable,
raising an exception if instability is identified. :issue:`7376` and
:issue:`7331` by `Joel Nothman`_ and :user:`yangarbiter`.
- Fix a bug where :meth:`sklearn.base.BaseEstimator.__getstate__` blocked
obstructed pickling customizations of child-classes, when used in a
multiple inheritence context.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: inheritance

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed, should be fixed with the last push.

@@ -220,6 +220,10 @@ Bug fixes
- Fix a bug in cases where `numpy.cumsum` may be numerically unstable,
raising an exception if instability is identified. :issue:`7376` and
:issue:`7331` by `Joel Nothman`_ and :user:`yangarbiter`.
- Fix a bug where :meth:`sklearn.base.BaseEstimator.__getstate__` blocked
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like you could not make up your mind between blocked and obstructed ;-)

@lesteve lesteve changed the title Fix pickling bug due to multiple inheritance & __getstate__ [MRG+1] Fix pickling bug due to multiple inheritance & __getstate__ Feb 20, 2017
@jnothman
Copy link
Member

LGTM thanks @HolgerPeters

@jnothman jnothman merged commit 4493d37 into scikit-learn:master Feb 20, 2017
@lesteve
Copy link
Member
lesteve commented Feb 21, 2017

Great stuff @HolgerPeters, thanks a lot!

sergeyf pushed a commit to sergeyf/scikit-learn that referenced this pull request Feb 28, 2017
…cikit-learn#8324)

Fixes scikit-learn#8316


* Don't use test classes to group tests

* only use formatting for parts of the string that change

* Flake 8 column limit

* Make the modification of the estimator more explicit in the tests

* As suggested in code review, prefer formatting over two literals

* Also assert, that __setstate__ overwriting works in mixin

* Remove cache property

* Use assertion functions from sklearn.utils.testing

* remove the protocol argument in tests

* Rename attributes to better convey their purpose

* Revert change of module in TreeNoVersion

* Adhere to column-limit

* changelog entry

* Fix commit message
@Przemo10 Przemo10 mentioned this pull request Mar 17, 2017
Sundrique pushed a commit to Sundrique/scikit-learn that referenced this pull request Jun 14, 2017
…cikit-learn#8324)

Fixes scikit-learn#8316


* Don't use test classes to group tests

* only use formatting for parts of the string that change

* Flake 8 column limit

* Make the modification of the estimator more explicit in the tests

* As suggested in code review, prefer formatting over two literals

* Also assert, that __setstate__ overwriting works in mixin

* Remove cache property

* Use assertion functions from sklearn.utils.testing

* remove the protocol argument in tests

* Rename attributes to better convey their purpose

* Revert change of module in TreeNoVersion

* Adhere to column-limit

* changelog entry

* Fix commit message
NelleV pushed a commit to NelleV/scikit-learn that referenced this pull request Aug 11, 2017
…cikit-learn#8324)

Fixes scikit-learn#8316


* Don't use test classes to group tests

* only use formatting for parts of the string that change

* Flake 8 column limit

* Make the modification of the estimator more explicit in the tests

* As suggested in code review, prefer formatting over two literals

* Also assert, that __setstate__ overwriting works in mixin

* Remove cache property

* Use assertion functions from sklearn.utils.testing

* remove the protocol argument in tests

* Rename attributes to better convey their purpose

* Revert change of module in TreeNoVersion

* Adhere to column-limit

* changelog entry

* Fix commit message
paulha pushed a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
…cikit-learn#8324)

Fixes scikit-learn#8316


* Don't use test classes to group tests

* only use formatting for parts of the string that change

* Flake 8 column limit

* Make the modification of the estimator more explicit in the tests

* As suggested in code review, prefer formatting over two literals

* Also assert, that __setstate__ overwriting works in mixin

* Remove cache property

* Use assertion functions from sklearn.utils.testing

* remove the protocol argument in tests

* Rename attributes to better convey their purpose

* Revert change of module in TreeNoVersion

* Adhere to column-limit

* changelog entry

* Fix commit message
maskani-moh pushed a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017
…cikit-learn#8324)

Fixes scikit-learn#8316


* Don't use test classes to group tests

* only use formatting for parts of the string that change

* Flake 8 column limit

* Make the modification of the estimator more explicit in the tests

* As suggested in code review, prefer formatting over two literals

* Also assert, that __setstate__ overwriting works in mixin

* Remove cache property

* Use assertion functions from sklearn.utils.testing

* remove the protocol argument in tests

* Rename attributes to better convey their purpose

* Revert change of module in TreeNoVersion

* Adhere to column-limit

* changelog entry

* Fix commit message
lemonlaug pushed a commit to lemonlaug/scikit-learn that referenced this pull request Jan 6, 2021
…cikit-learn#8324)

Fixes scikit-learn#8316


* Don't use test classes to group tests

* only use formatting for parts of the string that change

* Flake 8 column limit

* Make the modification of the estimator more explicit in the tests

* As suggested in code review, prefer formatting over two literals

* Also assert, that __setstate__ overwriting works in mixin

* Remove cache property

* Use assertion functions from sklearn.utils.testing

* remove the protocol argument in tests

* Rename attributes to better convey their purpose

* Revert change of module in TreeNoVersion

* Adhere to column-limit

* changelog entry

* Fix commit message
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants
0