|
1 | 1 | # Author: Gael Varoquaux
|
2 | 2 | # License: BSD 3 clause
|
3 | 3 |
|
4 |
| -import sys |
5 |
| - |
6 | 4 | import numpy as np
|
7 | 5 | import scipy.sparse as sp
|
8 | 6 |
|
|
15 | 13 | from sklearn.utils.testing import assert_raises
|
16 | 14 | from sklearn.utils.testing import assert_no_warnings
|
17 | 15 | from sklearn.utils.testing import assert_warns_message
|
| 16 | +from sklearn.utils.testing import assert_dict_equal |
18 | 17 |
|
19 | 18 | from sklearn.base import BaseEstimator, clone, is_classifier
|
20 | 19 | from sklearn.svm import SVC
|
@@ -314,48 +313,138 @@ def transform(self, X, y=None):
|
314 | 313 | assert_equal(e.scalar_param, cloned_e.scalar_param)
|
315 | 314 |
|
316 | 315 |
|
317 |
| -class TreeNoVersion(DecisionTreeClassifier): |
318 |
| - def __getstate__(self): |
319 |
| - return self.__dict__ |
| 316 | +def test_pickle_version_warning_is_not_raised_with_matching_version(): |
| 317 | + iris = datasets.load_iris() |
| 318 | + tree = DecisionTreeClassifier().fit(iris.data, iris.target) |
| 319 | + tree_pickle = pickle.dumps(tree) |
| 320 | + assert_true(b"version" in tree_pickle) |
| 321 | + tree_restored = assert_no_warnings(pickle.loads, tree_pickle) |
| 322 | + |
| 323 | + # test that we can predict with the restored decision tree classifier |
| 324 | + score_of_original = tree.score(iris.data, iris.target) |
| 325 | + score_of_restored = tree_restored.score(iris.data, iris.target) |
| 326 | + assert_equal(score_of_original, score_of_restored) |
320 | 327 |
|
321 | 328 |
|
322 | 329 | class TreeBadVersion(DecisionTreeClassifier):
|
323 | 330 | def __getstate__(self):
|
324 | 331 | return dict(self.__dict__.items(), _sklearn_version="something")
|
325 | 332 |
|
326 | 333 |
|
327 |
| -def test_pickle_version_warning():
328 |
| - # check that warnings are raised when unpickling in a different version |
| 334 | +pickle_error_message = ( |
| 335 | + "Trying to unpickle estimator {estimator} from " |
| 336 | + "version {old_version} when using version " |
| 337 | + "{current_version}. This might " |
| 338 | + "lead to breaking code or invalid results. " |
| 339 | + "Use at your own risk.") |
329 | 340 |
|
330 |
| - # first, check no warning when in the same version: |
331 |
| - iris = datasets.load_iris() |
332 |
| - tree = DecisionTreeClassifier().fit(iris.data, iris.target) |
333 |
| - tree_pickle = pickle.dumps(tree) |
334 |
| - assert_true(b"version" in tree_pickle) |
335 |
| - assert_no_warnings(pickle.loads, tree_pickle) |
336 | 341 |
|
337 |
| - # check that warning is raised on different version |
| 342 | +def test_pickle_version_warning_is_issued_upon_different_version(): |
| 343 | + iris = datasets.load_iris() |
338 | 344 | tree = TreeBadVersion().fit(iris.data, iris.target)
|
339 | 345 | tree_pickle_other = pickle.dumps(tree)
|
340 |
| - message = ("Trying to unpickle estimator TreeBadVersion from " |
341 |
| - "version {0} when using version {1}. This might lead to " |
342 |
| - "breaking code or invalid results. " |
343 |
| - "Use at your own risk.".format("something", |
344 |
| - sklearn.__version__)) |
| 346 | + message = pickle_error_message.format(estimator="TreeBadVersion", |
| 347 | + old_version="something", |
| 348 | + current_version=sklearn.__version__) |
345 | 349 | assert_warns_message(UserWarning, message, pickle.loads, tree_pickle_other)
|
346 | 350 |
|
347 |
| - # check that not including any version also works: |
| 351 | + |
| 352 | +class TreeNoVersion(DecisionTreeClassifier): |
| 353 | + def __getstate__(self): |
| 354 | + return self.__dict__ |
| 355 | + |
| 356 | + |
| 357 | +def test_pickle_version_warning_is_issued_when_no_version_info_in_pickle(): |
| 358 | + iris = datasets.load_iris() |
348 | 359 | # TreeNoVersion has no getstate, like pre-0.18
|
349 | 360 | tree = TreeNoVersion().fit(iris.data, iris.target)
|
350 | 361 |
|
351 | 362 | tree_pickle_noversion = pickle.dumps(tree)
|
352 | 363 | assert_false(b"version" in tree_pickle_noversion)
|
353 |
| - message = message.replace("something", "pre-0.18") |
354 |
| - message = message.replace("TreeBadVersion", "TreeNoVersion") |
| 364 | + message = pickle_error_message.format(estimator="TreeNoVersion", |
| 365 | + old_version="pre-0.18", |
| 366 | + current_version=sklearn.__version__) |
355 | 367 | # check we got the warning about using pre-0.18 pickle
|
356 | 368 | assert_warns_message(UserWarning, message, pickle.loads,
|
357 | 369 | tree_pickle_noversion)
|
358 | 370 |
|
359 |
| - # check that no warning is raised for external estimators |
360 |
| - TreeNoVersion.__module__ = "notsklearn" |
361 |
| - assert_no_warnings(pickle.loads, tree_pickle_noversion) |
| 371 | + |
| 372 | +def test_pickle_version_no_warning_is_issued_with_non_sklearn_estimator(): |
| 373 | + iris = datasets.load_iris() |
| 374 | + tree = TreeNoVersion().fit(iris.data, iris.target) |
| 375 | + tree_pickle_noversion = pickle.dumps(tree) |
| 376 | + try: |
| 377 | + module_backup = TreeNoVersion.__module__ |
| 378 | + TreeNoVersion.__module__ = "notsklearn" |
| 379 | + assert_no_warnings(pickle.loads, tree_pickle_noversion) |
| 380 | + finally: |
| 381 | + TreeNoVersion.__module__ = module_backup |
| 382 | + |
| 383 | + |
| 384 | +class DontPickleAttributeMixin(object): |
| 385 | + def __getstate__(self): |
| 386 | + data = self.__dict__.copy() |
| 387 | + data["_attribute_not_pickled"] = None |
| 388 | + return data |
| 389 | + |
| 390 | + def __setstate__(self, state): |
| 391 | + state["_restored"] = True |
| 392 | + self.__dict__.update(state) |
| 393 | + |
| 394 | + |
| 395 | +class MultiInheritanceEstimator(BaseEstimator, DontPickleAttributeMixin): |
| 396 | + def __init__(self, attribute_pickled=5): |
| 397 | + self.attribute_pickled = attribute_pickled |
| 398 | + self._attribute_not_pickled = None |
| 399 | + |
| 400 | + |
| 401 | +def test_pickling_when_getstate_is_overwritten_by_mixin(): |
| 402 | + estimator = MultiInheritanceEstimator() |
| 403 | + estimator._attribute_not_pickled = "this attribute should not be pickled" |
| 404 | + |
| 405 | + serialized = pickle.dumps(estimator) |
| 406 | + estimator_restored = pickle.loads(serialized) |
| 407 | + assert_equal(estimator_restored.attribute_pickled, 5) |
| 408 | + assert_equal(estimator_restored._attribute_not_pickled, None) |
| 409 | + assert_true(estimator_restored._restored) |
| 410 | + |
| 411 | + |
| 412 | +def test_pickling_when_getstate_is_overwritten_by_mixin_outside_of_sklearn(): |
| 413 | + try: |
| 414 | + estimator = MultiInheritanceEstimator() |
| 415 | + text = "this attribute should not be pickled" |
| 416 | + estimator._attribute_not_pickled = text |
| 417 | + old_mod = type(estimator).__module__ |
| 418 | + type(estimator).__module__ = "notsklearn" |
| 419 | + |
| 420 | + serialized = estimator.__getstate__() |
| 421 | + assert_dict_equal(serialized, {'_attribute_not_pickled': None, |
| 422 | + 'attribute_pickled': 5}) |
| 423 | + |
| 424 | + serialized['attribute_pickled'] = 4 |
| 425 | + estimator.__setstate__(serialized) |
| 426 | + assert_equal(estimator.attribute_pickled, 4) |
| 427 | + assert_true(estimator._restored) |
| 428 | + finally: |
| 429 | + type(estimator).__module__ = old_mod |
| 430 | + |
| 431 | + |
| 432 | +class SingleInheritanceEstimator(BaseEstimator): |
| 433 | + def __init__(self, attribute_pickled=5): |
| 434 | + self.attribute_pickled = attribute_pickled |
| 435 | + self._attribute_not_pickled = None |
| 436 | + |
| 437 | + def __getstate__(self): |
| 438 | + data = self.__dict__.copy() |
| 439 | + data["_attribute_not_pickled"] = None |
| 440 | + return data |
| 441 | + |
| 442 | + |
| 443 | +def test_pickling_works_when_getstate_is_overwritten_in_the_child_class(): |
| 444 | + estimator = SingleInheritanceEstimator() |
| 445 | + estimator._attribute_not_pickled = "this attribute should not be pickled" |
| 446 | + |
| 447 | + serialized = pickle.dumps(estimator) |
| 448 | + estimator_restored = pickle.loads(serialized) |
| 449 | + assert_equal(estimator_restored.attribute_pickled, 5) |
| 450 | + assert_equal(estimator_restored._attribute_not_pickled, None) |
0 commit comments