8000 Fix allowlisting of builtins for weights_only unpickler · pytorch/pytorch@cc99c01 · GitHub
[go: up one dir, main page]

Skip to content

Commit cc99c01

Browse files
Fix allowlisting of builtins for weights_only unpickler
ghstack-source-id: de329c7 Pull Request resolved: #129244
1 parent 0acd09a commit cc99c01

File tree

2 files changed

+43
-4
lines changed

2 files changed

+43
-4
lines changed

test/test_serialization.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1040,8 +1040,14 @@ def __reduce__(self):
10401040
self.assertIsNone(torch.load(f, weights_only=False))
10411041
f.seek(0)
10421042
# Safe load should assert
1043-
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL __builtin__.print"):
1043+
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL builtins.print"):
10441044
torch.load(f, weights_only=True)
1045+
try:
1046+
torch.serialization.add_safe_globals([print])
1047+
f.seek(0)
1048+
torch.load(f, weights_only=True)
1049+
finally:
1050+
torch.serialization.clear_safe_globals()
10451051

10461052
@parametrize('weights_only', (False, True))
10471053
def test_serialization_math_bits(self, weights_only):

torch/_weights_only_unpickler.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
# weights = torch.load(buf, weights_only = True)
2424

2525
import functools as _functools
26+
import warnings
2627
from collections import Counter, OrderedDict
2728
from pickle import (
2829
APPEND,
@@ -67,6 +68,16 @@
6768
from sys import maxsize
6869
from typing import Any, Dict, List
6970

71+
try:
72+
# We rely on this module in private cPython which provides dicts of
73+
# modules/functions that had their names changed from Python 2 to 3
74+
has_compat_pickle = True
75+
from _compat_pickle import IMPORT_MAPPING, NAME_MAPPING
76+
except ImportError:
77+
# To prevent warning on import torch, we warn in the Unpickler.load below
78+
has_compat_pickle = False
79+
IMPORT_MAPPING, NAME_MAPPING = dict(), dict()
80+
7081
import torch
7182

7283
_marked_safe_globals_list: List[Any] = []
@@ -97,7 +108,8 @@ def _clear_safe_globals():
97108
def _get_user_allowed_globals():
98109
rc: Dict[str, Any] = {}
99110
for f in _marked_safe_globals_list:
100-
rc[f"{f.__module__}.{f.__name__}"] = f
111+
module, name = f.__module__, f.__name__
112+
rc[f"{module}.{name}"] = f
101113
return rc
102114

103115

@@ -170,12 +182,20 @@ def __init__(self, file, *, encoding: str = "bytes"):
170182
self.readline = file.readline
171183
self.read = file.read
172184
self.memo: Dict[int, Any] = {}
185+
self.proto: int = -1
173186

174187
def load(self):
175188
"""Read a pickled object representation from the open file.
176189
177190
Return the reconstituted object hierarchy specified in the file.
178191
"""
192+
if not has_compat_pickle:
193+
warnings.warn(
194+
"Could not import IMPORT_MAPPING and NAME_MAPPING from _compat_pickle. "
195+
"If the default `pickle_protocol` was used at `torch.save` time, any functions or "
196+
"classes that are in these maps might not behave correctly if allowlisted via "
197+
"`torch.serialization.add_safe_globals()`."
198+
)
179199
self.metastack = []
180200
self.stack: List[Any] = []
181201
self.append = self.stack.append
@@ -190,6 +210,13 @@ def load(self):
190210
if key[0] == GLOBAL[0]:
191211
module = readline()[:-1].decode("utf-8")
192212
name = readline()[:-1].decode("utf-8")
213+
# Patch since torch.save default protocol is 2
214+
# users will be running this code in python > 3
215+
if self.proto == 2 and has_compat_pickle:
216+
if (module, name) in NAME_MAPPING:
217+
module, name = NAME_MAPPING[(module, name)]
218+
elif module in IMPORT_MAPPING:
219+
module = IMPORT_MAPPING[module]
193220
full_path = f"{module}.{name}"
194221
if full_path in _get_allowed_globals():
195222
self.append(_get_allowed_globals()[full_path])
@@ -334,8 +361,14 @@ def load(self):
334361
self.append(decode_long(data))
335362
# First and last deserializer ops
336363
elif key[0] == PROTO[0]:
337-
# Read and ignore proto version
338-
read(1)[0]
364+
self.proto = read(1)[0]
365+
if self.proto != 2:
366+
warnings.warn(
367+
f"Detected pickle protocol {self.proto} in the checkpoint, which was "
368+
"not the default pickle protocol used by `torch.load` (2). The weights_only "
369+
"Unpickler might not support all instructions implemented by this protocol, "
370+
"please file an issue for adding support if you encounter this."
371+
)
339372
elif key[0] == STOP[0]:
340373
rc = self.stack.pop()
341374
return rc

0 commit comments

Comments
 (0)
0