8000 Revert "Cherry pick #129244, #129251, #129239, 129396 into release/2.… · pytorch/pytorch@4d83bca · GitHub
[go: up one dir, main page]

Skip to content

Commit 4d83bca

Browse files
authored
Revert "Cherry pick #129244, #129251, #129239, 129396 into release/2.4" (#129571)
Revert "Cherry pick #129244, #129251, #129239, 129396 into release/2.4 (#129478)" This reverts commit 22a4d46.
1 parent 04339ee commit 4d83bca

File tree

5 files changed

+20
-183
lines changed

5 files changed

+20
-183
lines changed

test/distributed/_tensor/test_dtensor.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -536,16 +536,6 @@ def test_dtensor_save_load(self):
536536
buffer.seek(0)
537537
reloaded_st = torch.load(buffer)
538538
self.assertEqual(sharded_tensor, reloaded_st)
539-
# Test weights_only load
540-
try:
541-
torch.serialization.add_safe_globals(
542-
[DTensor, DeviceMesh, Shard, DTensorSpec, 8000 TensorMeta]
543-
)
544-
buffer.seek(0)
545-
reloaded_st = torch.load(buffer, weights_only=True)
546-
self.assertEqual(sharded_tensor, reloaded_st)
547-
finally:
548-
torch.serialization.clear_safe_globals()
549539

550540

551541
class DTensorMeshTest(DTensorTestBase):

test/test_nn.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,35 +1801,26 @@ def test_parameterlistdict_pickle(self):
18011801
m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
18021802
with warnings.catch_warnings(record=True) as w:
18031803
m = pickle.loads(pickle.dumps(m))
1804-
# warning from torch.load call in _load_from_bytes
1805-
num_warnings = 2 if torch._dynamo.is_compiling() else 1
1806-
self.assertTrue(len(w) == num_warnings)
1807-
self.assertEqual(w[0].category, FutureWarning)
1804+
self.assertTrue(len(w) == 0)
18081805

18091806
# Test whether loading from older checkpoints works without triggering warnings
18101807
m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
18111808
del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
18121809
with warnings.catch_warnings(record=True) as w:
18131810
m = pickle.loads(pickle.dumps(m))
1814-
# warning from torch.load call in _load_from_bytes
1815-
self.assertTrue(len(w) == 1)
1816-
self.assertEqual(w[0].category, FutureWarning)
1811+
self.assertTrue(len(w) == 0)
18171812

18181813
m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
18191814
with warnings.catch_warnings(record=True) as w:
18201815
m = pickle.loads(pickle.dumps(m))
1821-
# warning from torch.load call in _load_from_bytes
1822-
self.assertTrue(len(w) == 1)
1823-
self.assertEqual(w[0].category, FutureWarning)
1816+
self.assertTrue(len(w) == 0)
18241817

18251818
# Test whether loading from older checkpoints works without triggering warnings
18261819
m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
18271820
del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
18281821
with warnings.catch_warnings(record=True) as w:
18291822
m = pickle.loads(pickle.dumps(m))
1830-
# warning from torch.load call in _load_from_bytes
1831-
self.assertTrue(len(w) == 1)
1832-
self.assertEqual(w[0].category, FutureWarning)
1823+
self.assertTrue(len(w) == 0)
18331824

18341825
def test_weight_norm_pickle(self):
18351826
m = torch.nn.utils.weight_norm(nn.Linear(5, 7))

test/test_serialization.py

Lines changed: 6 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import shutil
1616
import pathlib
1717
import platform
18-
from collections import namedtuple, OrderedDict
18+
from collections import OrderedDict
1919
from copy import deepcopy
2020
from itertools import product
2121

@@ -804,17 +804,6 @@ def wrapper(*args, **kwargs):
804804
def __exit__(self, *args, **kwargs):
805805
torch.save = self.torch_save
806806

807-
Point = namedtuple('Point', ['x', 'y'])
808-
809-
class ClassThatUsesBuildInstruction:
810-
def __init__(self, num):
811-
self.num = num
812-
813-
def __reduce_ex__(self, proto):
814-
# Third item, state here will cause pickle to push a BUILD instruction
815-
return ClassThatUsesBuildInstruction, (self.num,), {'foo': 'bar'}
816-
817-
818807
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
819808
class TestBothSerialization(TestCase):
820809
@parametrize("weights_only", (True, False))
@@ -837,6 +826,7 @@ def test(f_new, f_old):
837826
test(f_new, f_old)
838827
self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")
839828

829+
840830
class TestOldSerialization(TestCase, SerializationMixin):
841831
# unique_key is necessary because on Python 2.7, if a warning passed to
842832
# the warning module is the same, it is not raised again.
@@ -864,8 +854,7 @@ def import_module(name, filename):
864854
loaded = torch.load(checkpoint)
865855
self.assertTrue(isinstance(loaded, module.Net))
866856
if can_retrieve_source:
867-
self.assertEqual(len(w), 1)
868-
self.assertEqual(w[0].category, FutureWarning)
857+
self.assertEqual(len(w), 0)
869858

870859
# Replace the module with different source
871860
fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
@@ -876,8 +865,8 @@ def import_module(name, filename):
876865
loaded = torch.load(checkpoint)
877866
self.assertTrue(isinstance(loaded, module.Net))
878867
if can_retrieve_source:
879-
self.assertEqual(len(w), 2)
880-
self.assertTrue(w[1].category, 'SourceChangeWarning')
868+
self.assertEqual(len(w), 1)
869+
self.assertTrue(w[0].category, 'SourceChangeWarning')
881870

882871
def test_serialization_container(self):
883872
self._test_serialization_container('file', tempfile.NamedTemporaryFile)
@@ -1051,63 +1040,8 @@ def __reduce__(self):
10511040
self.assertIsNone(torch.load(f, weights_only=False))
10521041
f.seek(0)
10531042
# Safe load should assert
1054-
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL builtins.print"):
1055-
torch.load(f, weights_only=True)
1056-
try:
1057-
torch.serialization.add_safe_globals([print])
1058-
f.seek(0)
1059-
torch.load(f, weights_only=True)
1060-
finally:
1061-
torch.serialization.clear_safe_globals()
1062-
1063-
def test_weights_only_safe_globals_newobj(self):
1064-
# This will use NEWOBJ
1065-
p = Point(x=1, y=2)
1066-
with BytesIOContext() as f:
1067-
torch.save(p, f)
1068-
f.seek(0)
1069-
with self.assertRaisesRegex(pickle.UnpicklingError,
1070-
"GLOBAL __main__.Point was not an allowed global by default"):
1043+
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL __builtin__.print"):
10711044
torch.load(f, weights_only=True)
1072-
f.seek(0)
1073-
try:
1074-
torch.serialization.add_safe_globals([Point])
1075-
loaded_p = torch.load(f, weights_only=True)
1076-
self.assertEqual(loaded_p, p)
1077-
finally:
1078-
torch.serialization.clear_safe_globals()
1079-
1080-
def test_weights_only_safe_globals_build(self):
1081-
counter = 0
1082-
1083-
def fake_set_state(obj, *args):
1084-
nonlocal counter
1085-
counter += 1
1086-
1087-
c = ClassThatUsesBuildInstruction(2)
1088-
with BytesIOContext() as f:
1089-
torch.save(c, f)
1090-
f.seek(0)
1091-
with self.assertRaisesRegex(pickle.UnpicklingError,
1092-
"GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"):
1093-
torch.load(f, weights_only=True)
1094-
try:
1095-
torch.serialization.add_safe_globals([ClassThatUsesBuildInstruction])
1096-
# Test dict update path
1097-
f.seek(0)
1098-
loaded_c = torch.load(f, weights_only=True)
1099-
self.assertEqual(loaded_c.num, 2)
1100-
self.assertEqual(loaded_c.foo, 'bar')
1101-
# Test setstate path
1102-
ClassThatUsesBuildInstruction.__setstate__ = fake_set_state
1103-
f.seek(0)
1104-
loaded_c = torch.load(f, weights_only=True)
1105-
self.assertEqual(loaded_c.num, 2)
1106-
self.assertEqual(counter, 1)
1107-
self.assertFalse(hasattr(loaded_c, 'foo'))
1108-
finally:
1109-
torch.serialization.clear_safe_globals()
1110-
ClassThatUsesBuildInstruction.__setstate__ = None
11111045

11121046
@parametrize('weights_only', (False, True))
11131047
def test_serialization_math_bits(self, weights_only):

torch/_weights_only_unpickler.py

Lines changed: 6 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
# weights = torch.load(buf, weights_only = True)
2424

2525
import functools as _functools
26-
import warnings
2726
from collections import Counter, OrderedDict
2827
from pickle import (
2928
APPEND,
@@ -68,16 +67,6 @@
6867
from sys import maxsize
6968
from typing import Any, Dict, List
7069

71-
try:
72-
# We rely on this module in private cPython which provides dicts of
73-
# modules/functions that had their names changed from Python 2 to 3
74-
has_compat_pickle = True
75-
from _compat_pickle import IMPORT_MAPPING, NAME_MAPPING
76-
except ImportError:
77-
# To prevent warning on import torch, we warn in the Unpickler.load below
78-
has_compat_pickle = False
79-
IMPORT_MAPPING, NAME_MAPPING = dict(), dict()
80-
8170
import torch
8271

8372
_marked_safe_globals_list: List[Any] = []
@@ -108,8 +97,7 @@ def _clear_safe_globals():
10897
def _get_user_allowed_globals():
10998
rc: Dict[str, Any] = {}
11099
for f in _marked_safe_globals_list:
111-
module, name = f.__module__, f.__name__
112-
rc[f"{module}.{name}"] = f
100+
rc[f"{f.__module__}.{f.__name__}"] = f
113101
return rc
114102

115103

@@ -182,20 +170,12 @@ def __init__(self, file, *, encoding: str = "bytes"):
182170
self.readline = file.readline
183171
self.read = file.read
184172
self.memo: Dict[int, Any] = {}
185-
self.proto: int = -1
186173

187174
def load(self):
188175
"""Read a pickled object representation from the open file.
189176
190177
Return the reconstituted object hierarchy specified in the file.
191178
"""
192-
if not has_compat_pickle:
193-
warnings.warn(
194-
"Could not import IMPORT_MAPPING and NAME_MAPPING from _compat_pickle. "
195-
"If the default `pickle_protocol` was used at `torch.save` time, any functions or "
196-
"classes that are in these maps might not behave correctly if allowlisted via "
197-
"`torch.serialization.add_safe_globals()`."
198-
)
199179
self.metastack = []
200180
self.stack: List[Any] = []
201181
self.append = self.stack.append
@@ -210,13 +190,6 @@ def load(self):
210190
if key[0] == GLOBAL[0]:
211191
module = readline()[:-1].decode("utf-8")
212192
name = readline()[:-1].decode("utf-8")
213-
# Patch since torch.save default protocol is 2
214-
# users will be running this code in python > 3
215-
if self.proto == 2 and has_compat_pickle:
216-
if (module, name) in NAME_MAPPING:
217-
module, name = NAME_MAPPING[(module, name)]
218-
elif module in IMPORT_MAPPING:
219-
module = IMPORT_MAPPING[module]
220193
full_path = f"{module}.{name}"
221194
if full_path in _get_allowed_globals():
222195
self.append(_get_allowed_globals()[full_path])
@@ -231,12 +204,9 @@ def load(self):
231204
elif key[0] == NEWOBJ[0]:
232205
args = self.stack.pop()
233206
cls = self.stack.pop()
234-
if cls is torch.nn.Parameter:
235-
self.append(torch.nn.Parameter(*args))
236-
elif cls in _get_user_allowed_globals().values():
237-
self.append(cls.__new__(cls, *args))
238-
else:
207+
if cls is not torch.nn.Parameter:
239208
raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
209+
self.append(torch.nn.Parameter(*args))
240210
elif key[0] == REDUCE[0]:
241211
args = self.stack.pop()
242212
func = self.stack[-1]
@@ -258,14 +228,9 @@ def load(self):
258228
inst.__setstate__(state)
259229
elif type(inst) is OrderedDict:
260230
inst.__dict__.update(state)
261-
elif type(inst) in _get_user_allowed_globals().values():
262-
if hasattr(inst, "__setstate__"):
263-
inst.__setstate__(state)
264-
else:
265-
inst.__dict__.update(state)
266231
else:
267232
raise RuntimeError(
268-
f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}"
233+
f"Can only build Tensor, parameter or dict objects, but got {type(inst)}"
269234
)
270235
# Stack manipulation
271236
elif key[0] == APPEND[0]:
@@ -369,14 +334,8 @@ def load(self):
369334
self.append(decode_long(data))
370335
# First and last deserializer ops
371336
elif key[0] == PROTO[0]:
372-
self.proto = read(1)[0]
373-
if self.proto != 2:
374-
warnings.warn(
375-
f"Detected pickle protocol {self.proto} in the checkpoint, which was "
376-
"not the default pickle protocol used by `torch.load` (2). The weights_only "
377-
"Unpickler might not support all instructions implemented by this protocol, "
378-
"please file an issue for adding support if you encounter this."
379-
)
337+
# Read and ignore proto version
338+
read(1)[0]
380339
elif key[0] == STOP[0]:
381340
rc = self.stack.pop()
382341
return rc

torch/serialization.py

Lines changed: 4 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -165,30 +165,12 @@ def get_safe_globals() -> List[Any]:
165165
return _weights_only_unpickler._get_safe_globals()
166166

167167
def add_safe_globals(safe_globals: List[Any]) -> None:
168-
"""
169-
Marks the given globals as safe for ``weights_only`` load. For example, functions
170-
added to this list can be called during unpickling, classes could be instantiated
171-
and have state set.
168+
'''
169+
Marks the given globals as safe for ``weights_only`` load.
172170
173171
Args:
174172
safe_globals (List[Any]): list of globals to mark as safe
175-
176-
Example:
177-
>>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
178-
>>> import tempfile
179-
>>> class MyTensor(torch.Tensor):
180-
... pass
181-
>>> t = MyTensor(torch.randn(2, 3))
182-
>>> with tempfile.NamedTemporaryFile() as f:
183-
... torch.save(t, f.name)
184-
# Running `torch.load(f.name, weights_only=True)` will fail with
185-
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
186-
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
187-
... torch.serialization.add_safe_globals([MyTensor])
188-
... torch.load(f.name, weights_only=True)
189-
# MyTensor([[-0.5024, -1.8152, -0.5455],
190-
# [-0.8234, 2.0500, -0.3657]])
191-
"""
173+
'''
192174
_weights_only_unpickler._add_safe_globals(safe_globals)
193175

194176
def _is_zipfile(f) -> bool:
@@ -890,7 +872,7 @@ def load(
890872
map_location: MAP_LOCATION = None,
891873
pickle_module: Any = None,
892874
*,
893-
weights_only: Optional[bool] = None,
875+
weights_only: bool = False,
894876
mmap: Optional[bool] = None,
895877
**pickle_load_args: Any
896878
) -> Any:
@@ -1000,11 +982,6 @@ def load(
1000982
" with `weights_only` please check the recommended steps in the following error message."
1001983
" WeightsUnpickler error: "
1002984
)
1003-
if weights_only is None:
1004-
weights_only, warn_weights_only = False, True
1005-
else:
1006-
warn_weights_only = False
1007-
1008985
# Add ability to force safe only weight loads via environment variable
1009986
if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
1010987
weights_only = True
@@ -1014,20 +991,6 @@ def load(
1014991
raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
1015992
else:
1016993
if pickle_module is None:
1017-
if warn_weights_only:
1018-
warnings.warn(
1019-
"You are using `torch.load` with `weights_only=False` (the current default value), which uses "
1020-
"the default pickle module implicitly. It is possible to construct malicious pickle data "
1021-
"which will execute arbitrary code during unpickling (See "
1022-
"https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). "
1023-
"In a future release, the default value for `weights_only` will be flipped to `True`. This "
1024-
"limits the functions that could be executed during unpickling. Arbitrary objects will no "
1025-
"longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the "
1026-
"user via `torch.serialization.add_safe_globals`. We recommend you start setting "
1027-
"`weights_only=True` for any use case where you don't have full control of the loaded file. "
1028-
"Please open an issue on GitHub for any issues related to this experimental feature.",
1029-
FutureWarning,
1030-
)
1031994
pickle_module = pickle
1032995

1033996
# make flipping default BC-compatible

0 commit comments

Comments
 (0)
0