8000 ENH Adds InconsistentVersionWarning when unpickling version doesn't m… · scikit-learn/scikit-learn@70c690e · GitHub
[go: up one dir, main page]

Skip to content

Commit 70c690e

Browse files
ENH Adds InconsistentVersionWarning when unpickling version doesn't match (#25297)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent c676917 commit 70c690e

File tree

6 files changed

+69
-9
lines changed

6 files changed

+69
-9
lines changed

doc/model_persistence.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,19 @@ with::
5555
available `here
5656
<https://joblib.readthedocs.io/en/latest/persistence.html>`_.
5757

58+
When an estimator is unpickled with a scikit-learn version that is inconsistent
59+
with the version the estimator was pickled with, a
60+
:class:`~sklearn.exceptions.InconsistentVersionWarning` is raised. This warning
61+
can be caught to obtain the original version the estimator was pickled with:
62+
63+
from sklearn.exceptions import InconsistentVersionWarning
64+
warnings.simplefilter("error", InconsistentVersionWarning)
65+
66+
try:
67+
est = pickle.loads("model_from_prevision_version.pickle")
68+
except InconsistentVersionWarning as w:
69+
print(w.original_sklearn_version)
70+
5871
.. _persistence_limitations:
5972

6073
Security & maintainability limitations

doc/modules/classes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ Samples generator
443443
exceptions.DataDimensionalityWarning
444444
exceptions.EfficiencyWarning
445445
exceptions.FitFailedWarning
446+
exceptions.InconsistentVersionWarning
446447
exceptions.NotFittedError
447448
exceptions.UndefinedMetricWarning
448449

doc/whats_new/v1.3.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,13 @@ Changelog
100100
out-of-bag scores via the `oob_scores_` or `oob_score_` attributes.
101101
:pr:`24882` by :user:`Ashwin Mathur <awinml>`.
102102

103+
:mod:`sklearn.exception`
104+
........................
105+
- |Feature| Added :class:`exception.InconsistentVersionWarning` which is raised
106+
when a scikit-learn estimator is unpickled with a scikit-learn version that is
107+
inconsistent with the sckit-learn verion the estimator was pickled with.
108+
:pr:`25297` by `Thomas Fan`_.
109+
103110
:mod:`sklearn.feature_extraction`
104111
.................................
105112

sklearn/base.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .utils._tags import (
2020
_DEFAULT_TAGS,
2121
)
22+
from .exceptions import InconsistentVersionWarning
2223
from .utils.validation import check_X_y
2324
from .utils.validation import check_array
2425
from .utils.validation import _check_y
@@ -297,15 +298,11 @@ def __setstate__(self, state):
297298
pickle_version = state.pop("_sklearn_version", "pre-0.18")
298299
if pickle_version != __version__:
299300
warnings.warn(
300-
"Trying to unpickle estimator {0} from version {1} when "
301-
"using version {2}. This might lead to breaking code or "
302-
"invalid results. Use at your own risk. "
303-
"For more info please refer to:\n"
304-
"https://scikit-learn.org/stable/model_persistence.html"
305-
"#security-maintainability-limitations".format(
306-
self.__class__.__name__, pickle_version, __version__
301+
InconsistentVersionWarning(
302+
estimator_name=self.__class__.__name__,
303+
current_sklearn_version=__version__,
304+
original_sklearn_version=pickle_version,
307305
),
308-
UserWarning,
309306
)
310307
try:
311308
super().__setstate__(state)

sklearn/exceptions.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,38 @@ class PositiveSpectrumWarning(UserWarning):
128128
129129
.. versionadded:: 0.22
130130
"""
131+
132+
133+
class InconsistentVersionWarning(UserWarning):
134+
"""Warning raised when an estimator is unpickled with a inconsistent version.
135+
136+
Parameters
137+
----------
138+
estimator_name : str
139+
Estimator name.
140+
141+
current_sklearn_version : str
142+
Current scikit-learn version.
143+
144+
original_sklearn_vers B41A ion : str
145+
Original scikit-learn version.
146+
"""
147+
148+
def __init__(
149+
self, *, estimator_name, current_sklearn_version, original_sklearn_version
150+
):
151+
self.estimator_name = estimator_name
152+
self.current_sklearn_version = current_sklearn_version
153+
self.original_sklearn_version = original_sklearn_version
154+
155+
def __str__(self):
156+
return (
157+
f"Trying to unpickle estimator {self.estimator_name} from version"
158+
f" {self.original_sklearn_version} when "
159+
f"using version {self.current_sklearn_version}. This might lead to breaking"
160+
" code or "
161+
"invalid results. Use at your own risk. "
162+
"For more info please refer to:\n"
163+
"https://scikit-learn.org/stable/model_persistence.html"
164+
"#security-maintainability-limitations"
165+
)

sklearn/tests/test_base.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sklearn.tree import DecisionTreeClassifier
2323
from sklearn.tree import DecisionTreeRegressor
2424
from sklearn import datasets
25+
from sklearn.exceptions import InconsistentVersionWarning
2526

2627
from sklearn.base import TransformerMixin
2728
from sklearn.utils._mocking import MockDataFrame
@@ -397,9 +398,15 @@ def test_pickle_version_warning_is_issued_upon_different_version():
397398
old_version="something",
398399
current_version=sklearn.__version__,
399400
)
400-
with pytest.warns(UserWarning, match=message):
401+
with pytest.warns(UserWarning, match=message) as warning_record:
401402
pickle.loads(tree_pickle_other)
402403

404+
message = warning_record.list[0].message
405+
assert isinstance(message, InconsistentVersionWarning)
406+
assert message.estimator_name == "TreeBadVersion"
407+
assert message.original_sklearn_version == "something"
408+
assert message.current_sklearn_version == sklearn.__version__
409+
403410

404411
class TreeNoVersion(DecisionTreeClassifier):
405412
def __getstate__(self):

0 commit comments

Comments
 (0)
0