Cherry pick #129244 #129251 #129509 (#129574) · pytorch/pytorch@1f84579 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1f84579

Browse files
* Fix allowlisting of builtins for weights_only unpickler (#129244)

Since we use [`DEFAULT_PROTOCOL=2`](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L62), some functions/classes that were renamed from python 2-->3 will be pickled with their python2 name. This PR ensures that when a `GLOBAL <python2_mod>.<python2_name>` instruction is encountered, [following the strategy used by pickle](https://github.com/python/cpython/blob/main/Lib/pickle.py#L1590C13-L1593C63) it is properly mapped to `<python3_mod>.<python3_name>`.

This fix ensures that `add_safe_globals` works properly for such functions/classes (i.e. users will allowlist the python3 func and the weights_only unpickler will do the appropriate translation when checking whether a class was allowlisted).

An example is as follows: `__builtin__` was renamed to `builtins`, see the [release notes for Python 3.0](https://docs.python.org/3/whatsnew/3.0.html)

> Renamed module `__builtin__` to [`builtins`](https://docs.python.org/3/library/builtins.html#module-builtins) (removing the underscores, adding an ‘s’). The __builtins__ variable found in most global namespaces is unchanged. To modify a builtin, you should use [builtins](https://docs.python.org/3/library/builtins.html#module-builtins), not `__builtins__`!

However, since we use [`DEFAULT_PROTOCOL=2`](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L62), builtins will be pickled with their module string as `__builtin__`.

```python
>>> import pickle
>>> import pickletools
>>> print.__module__
'builtins'
>>> with open('print.pkl', 'wb') as f:
>>>     pickle.dump(print, f, protocol=2) # 2 because this is the default protocol used by pytorch
>>> with open('print.pkl', 'rb') as f:
>>>     pickletools.dis(f)
    0: \x80 PROTO      2
    2: c    GLOBAL     '__builtin__ print' # pickle saves the module string as __builtin__ !!! :(
   21: q    BINPUT     0
   23: .    STOP
```

Pull Request resolved: #129244
Approved by: https://github.com/albanD

* Allow BUILD/NEWOBJ instruction for items added via torch.serialization.add_safe_globals (#129251)

Previously, allowlisting functions/classes via `torch.serialization.add_safe_globals(obj)` for the `weights_only` Unpickler had the following effect:

- For a [`GLOBAL`](https://github.com/python/cpython/blob/3.12/Lib/pickletools.py#L1926-L1939) instruction, `GLOBAL obj.__module__ obj.__name__` would be allowed and translated back to obj to be pushed back to the stack.
- For a [`REDUCE`](https://github.com/python/cpython/blob/3.12/Lib/pickletools.py#L1926-L1982) instruction where we expect the stack to contain `func` and `args`, `func` is allowed if it was added via `add_safe_globals`.

However, it did not have an effect on `BUILD` and `NEWOBJ` instructions.

Some classes may be rebuilt via the [`NEWOBJ`](https://github.com/python/cpython/blob/3.12/Lib/pickletools.py#L2091-L2104) instruction, which indicates that their constructor should be used to rebuild the class.

Further, a [`BUILD`](https://github.com/python/cpython/blob/3.12/Lib/pickletools.py#L1984-L2007) instruction might be used if an object's `__reduce__`/`__reduce_ex__` returns a non-None value for `state`, which indicates that `__setstate__` or `__dict__.update` should be called with that state.

**This PR makes sure that adding objects to the allowlist will also allow `NEWOBJ` and `BUILD` instructions for them.**

In particular, the update for `NEWOBJ` should unblock allowlisting of [`ScaledMMConfig`](https://github.com/pytorch-labs/float8_experimental/blob/d4ade877dff327ea7f51e91f7cc218ae956e8cfd/float8_experimental/float8_tensor.py#L26-L30) in float8_experimental @drisspg

Pull Request resolved: #129251
Approved by: https://github.com/albanD
ghstack dependencies: #129244

* Remove dependency on private _compat_pickle in CPython

ghstack-source-id: 7d6ee40
Pull Request resolved: #129509
1 parent 4d83bca commit 1f84579

File tree

5 files changed

+157
-9
lines changed

5 files changed

+157
-9
lines changed

test/distributed/_tensor/test_dtensor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,16 @@ def test_dtensor_save_load(self):
536536
buffer.seek(0)
537537
reloaded_st = torch.load(buffer)
538538
self.assertEqual(sharded_tensor, reloaded_st)
539+
# Test weights_only load
540+
try:
541+
torch.serialization.add_safe_globals(
542+
[DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]
543+
)
544+
buffer.seek(0)
545+
reloaded_st = torch.load(buffer, weights_only=True)
546+
self.assertEqual(sharded_tensor, reloaded_st)
547+
finally:
548+
torch.serialization.clear_safe_globals()
539549

540550

541551
class DTensorMeshTest(DTensorTestBase):

test/test_serialization.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import shutil
1616
import pathlib
1717
import platform
18-
from collections import OrderedDict
18+
from collections import namedtuple, OrderedDict
1919
from copy import deepcopy
2020
from itertools import product
2121

@@ -804,6 +804,17 @@ def wrapper(*args, **kwargs):
804804
def __exit__(self, *args, **kwargs):
805805
torch.save = self.torch_save
806806

807+
Point = namedtuple('Point', ['x', 'y'])
808+
809+
class ClassThatUsesBuildInstruction:
810+
def __init__(self, num):
811+
self.num = num
812+
813+
def __reduce_ex__(self, proto):
814+
# Third item, state here will cause pickle to push a BUILD instruction
815+
return ClassThatUsesBuildInstruction, (self.num,), {'foo': 'bar'}
816+
817+
807818
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
808819
class TestBothSerialization(TestCase):
809820
@parametrize("weights_only", (True, False))
@@ -1040,8 +1051,63 @@ def __reduce__(self):
10401051
self.assertIsNone(torch.load(f, weights_only=False))
10411052
f.seek(0)
10421053
# Safe load should assert
1043-
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL __builtin__.print"):
1054+
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL builtins.print"):
1055+
torch.load(f, weights_only=True)
1056+
try:
1057+
torch.serialization.add_safe_globals([print])
1058+
f.seek(0)
1059+
torch.load(f, weights_only=True)
1060+
finally:
1061+
torch.serialization.clear_safe_globals()
1062+
1063+
def test_weights_only_safe_globals_newobj(self):
1064+
# This will use NEWOBJ
1065+
p = Point(x=1, y=2)
1066+
with BytesIOContext() as f:
1067+
torch.save(p, f)
1068+
f.seek(0)
1069+
with self.assertRaisesRegex(pickle.UnpicklingError,
1070+
"GLOBAL __main__.Point was not an allowed global by default"):
1071+
torch.load(f, weights_only=True)
1072+
f.seek(0)
1073+
try:
1074+
torch.serialization.add_safe_globals([Point])
1075+
loaded_p = torch.load(f, weights_only=True)
1076+
self.assertEqual(loaded_p, p)
1077+
finally:
1078+
torch.serialization.clear_safe_globals()
1079+
1080+
def test_weights_only_safe_globals_build(self):
1081+
counter = 0
1082+
1083+
def fake_set_state(obj, *args):
1084+
nonlocal counter
1085+
counter += 1
1086+
1087+
c = ClassThatUsesBuildInstruction(2)
1088+
with BytesIOContext() as f:
1089+
torch.save(c, f)
1090+
f.seek(0)
1091+
with self.assertRaisesRegex(pickle.UnpicklingError,
1092+
"GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"):
10441093
torch.load(f, weights_only=True)
1094+
try:
1095+
torch.serialization.add_safe_globals([ClassThatUsesBuildInstruction])
1096+
# Test dict update path
1097+
f.seek(0)
1098+
loaded_c = torch.load(f, weights_only=True)
1099+
self.assertEqual(loaded_c.num, 2)
1100+
self.assertEqual(loaded_c.foo, 'bar')
1101+
# Test setstate path
1102+
ClassThatUsesBuildInstruction.__setstate__ = fake_set_state
1103+
f.seek(0)
1104+
loaded_c = torch.load(f, weights_only=True)
1105+
self.assertEqual(loaded_c.num, 2)
1106+
self.assertEqual(counter, 1)
1107+
self.assertFalse(hasattr(loaded_c, 'foo'))
1108+
finally:
1109+
torch.serialization.clear_safe_globals()
1110+
ClassThatUsesBuildInstruction.__setstate__ = None
10451111

10461112
@parametrize('weights_only', (False, True))
10471113
def test_serialization_math_bits(self, weights_only):

torch/_utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,3 +962,47 @@ def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
962962
logger.exception(
963963
"Exception in callback for %s registered with gpu trace", self.name
964964
)
965+
966+
967+
# IMPORT_MAPPING and NAME_MAPPING are adapted from https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py
968+
# for use in the weights_only Unpickler.
969+
970+
IMPORT_MAPPING = {
971+
"__builtin__": "builtins",
972+
"copy_reg": "copyreg",
973+
"Queue": "queue",
974+
"repr": "reprlib",
975+
"_abcoll": "collections.abc",
976+
# Non-mutual mappings.
977+
"UserDict": "collections",
978+
"UserList": "collections",
979+
"UserString": "collections",
980+
"whichdb": "dbm",
981+
"StringIO": "io",
982+
"cStringIO": "io",
983+
}
984+
985+
986+
# This contains rename rules that are easy to handle. We ignore the more
987+
# complex stuff (e.g. mapping the names in the urllib and types modules).
988+
# These rules should be run before import names are fixed.
989+
NAME_MAPPING = {
990+
("__builtin__", "xrange"): ("builtins", "range"),
991+
("__builtin__", "reduce"): ("functools", "reduce"),
992+
("__builtin__", "intern"): ("sys", "intern"),
993+
("__builtin__", "unichr"): ("builtins", "chr"),
994+
("__builtin__", "unicode"): ("builtins", "str"),
995+
("__builtin__", "long"): ("builtins", "int"),
996+
("itertools", "izip"): ("builtins", "zip"),
997+
("itertools", "imap"): ("builtins", "map"),
998+
("itertools", "ifilter"): ("builtins", "filter"),
999+
("itertools", "ifilterfalse"): ("itertools", "filterfalse"),
1000+
("itertools", "izip_longest"): ("itertools", "zip_longest"),
1001+
("UserDict", "IterableUserDict"): ("collections", "UserDict"),
1002+
("UserList", "UserList"): ("collections", "UserList"),
1003+
("UserString", "UserString"): ("collections", "UserString"),
1004+
# Non-mutual mappings.
1005+
("__builtin__", "basestring"): ("builtins", "str"),
1006+
("exceptions", "StandardError"): ("builtins", "Exception"),
1007+
("UserDict", "UserDict"): ("collections", "UserDict"),
1008+
}

torch/_weights_only_unpickler.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
# weights = torch.load(buf, weights_only = True)
2424

2525
import functools as _functools
26+
import warnings
2627
from collections import Counter, OrderedDict
2728
from pickle import (
2829
APPEND,
@@ -68,6 +69,8 @@
6869
from typing import Any, Dict, List
6970

7071
import torch
72+
from torch._utils import IMPORT_MAPPING, NAME_MAPPING
73+
7174

7275
_marked_safe_globals_list: List[Any] = []
7376

@@ -97,7 +100,8 @@ def _clear_safe_globals():
97100
def _get_user_allowed_globals():
98101
rc: Dict[str, Any] = {}
99102
for f in _marked_safe_globals_list:
100-
rc[f"{f.__module__}.{f.__name__}"] = f
103+
module, name = f.__module__, f.__name__
104+
rc[f"{module}.{name}"] = f
101105
return rc
102106

103107

@@ -170,6 +174,7 @@ def __init__(self, file, *, encoding: str = "bytes"):
170174
self.readline = file.readline
171175
self.read = file.read
172176
self.memo: Dict[int, Any] = {}
177+
self.proto: int = -1
173178

174179
def load(self):
175180
"""Read a pickled object representation from the open file.
@@ -190,6 +195,13 @@ def load(self):
190195
if key[0] == GLOBAL[0]:
191196
module = readline()[:-1].decode("utf-8")
192197
name = readline()[:-1].decode("utf-8")
198+
# Patch since torch.save default protocol is 2
199+
# users will be running this code in python > 3
200+
if self.proto == 2:
201+
if (module, name) in NAME_MAPPING:
202+
module, name = NAME_MAPPING[(module, name)]
203+
elif module in IMPORT_MAPPING:
204+
module = IMPORT_MAPPING[module]
193205
full_path = f"{module}.{name}"
194206
if full_path in _get_allowed_globals():
195207
self.append(_get_allowed_globals()[full_path])
@@ -204,9 +216,12 @@ def load(self):
204216
elif key[0] == NEWOBJ[0]:
205217
args = self.stack.pop()
206218
cls = self.stack.pop()
207-
if cls is not torch.nn.Parameter:
219+
if cls is torch.nn.Parameter:
220+
self.append(torch.nn.Parameter(*args))
221+
elif cls in _get_user_allowed_globals().values():
222+
self.append(cls.__new__(cls, *args))
223+
else:
208224
raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
209-
self.append(torch.nn.Parameter(*args))
210225
elif key[0] == REDUCE[0]:
211226
args = self.stack.pop()
212227
func = self.stack[-1]
@@ -228,9 +243,14 @@ def load(self):
228243
inst.__setstate__(state)
229244
elif type(inst) is OrderedDict:
230245
inst.__dict__.update(state)
246+
elif type(inst) in _get_user_allowed_globals().values():
247+
if hasattr(inst, "__setstate__"):
248+
inst.__setstate__(state)
249+
else:
250+
inst.__dict__.update(state)
231251
else:
232252
raise RuntimeError(
233-
f"Can only build Tensor, parameter or dict objects, but got {type(inst)}"
253+
f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}"
234254
)
235255
# Stack manipulation
236256
elif key[0] == APPEND[0]:
@@ -334,8 +354,14 @@ def load(self):
334354
self.append(decode_long(data))
335355
# First and last deserializer ops
336356
elif key[0] == PROTO[0]:
337-
# Read and ignore proto version
338-
read(1)[0]
357+
self.proto = read(1)[0]
358+
if self.proto != 2:
359+
warnings.warn(
360+
f"Detected pickle protocol {self.proto} in the checkpoint, which was "
361+
"not the default pickle protocol used by `torch.load` (2). The weights_only "
362+
"Unpickler might not support all instructions implemented by this protocol, "
363+
"please file an issue for adding support if you encounter this."
364+
)
339365
elif key[0] == STOP[0]:
340366
rc = self.stack.pop()
341367
return rc

torch/serialization.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,9 @@ def get_safe_globals() -> List[Any]:
166166

167167
def add_safe_globals(safe_globals: List[Any]) -> None:
168168
'''
169-
Marks the given globals as safe for ``weights_only`` load.
169+
Marks the given globals as safe for ``weights_only`` load. For example, functions
170+
added to this list can be called during unpickling, classes could be instantiated
171+
and have state set.
170172
171173
Args:
172174
safe_globals (List[Any]): list of globals to mark as safe

0 commit comments

Comments
 (0)
0