Cherry pick #129244, #129251, #129239, 129396 into release/2.4 by mikaylagawarecki · Pull Request #129478 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Cherry pick #129244, #129251, #129239, 129396 into release/2.4 #129478

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions test/distributed/_tensor/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,16 @@ def test_dtensor_save_load(self):
buffer.seek(0)
reloaded_st = torch.load(buffer)
self.assertEqual(sharded_tensor, reloaded_st)
# Test weights_only load
try:
torch.serialization.add_safe_globals(
[DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]
)
buffer.seek(0)
reloaded_st = torch.load(buffer, weights_only=True)
self.assertEqual(sharded_tensor, reloaded_st)
finally:
torch.serialization.clear_safe_globals()


class DTensorMeshTest(DTensorTestBase):
Expand Down
17 changes: 13 additions & 4 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,26 +1801,35 @@ def test_parameterlistdict_pickle(self):
m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
with warnings.catch_warnings(record=True) as w:
m = pickle.loads(pickle.dumps(m))
self.assertTrue(len(w) == 0)
# warning from torch.load call in _load_from_bytes
num_warnings = 2 if torch._dynamo.is_compiling() else 1
self.assertTrue(len(w) == num_warnings)
self.assertEqual(w[0].category, FutureWarning)

# Test whether loading from older checkpoints works without triggering warnings
m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
with warnings.catch_warnings(record=True) as w:
m = pickle.loads(pickle.dumps(m))
self.assertTrue(len(w) == 0)
# warning from torch.load call in _load_from_bytes
self.assertTrue(len(w) == 1)
self.assertEqual(w[0].category, FutureWarning)

m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
with warnings.catch_warnings(record=True) as w:
m = pickle.loads(pickle.dumps(m))
self.assertTrue(len(w) == 0)
# warning from torch.load call in _load_from_bytes
self.assertTrue(len(w) == 1)
self.assertEqual(w[0].category, FutureWarning)

# Test whether loading from older checkpoints works without triggering warnings
m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
with warnings.catch_warnings(record=True) as w:
m = pickle.loads(pickle.dumps(m))
self.assertTrue(len(w) == 0)
# warning from torch.load call in _load_from_bytes
self.assertTrue(len(w) == 1)
self.assertEqual(w[0].category, FutureWarning)

def test_weight_norm_pickle(self):
m = torch.nn.utils.weight_norm(nn.Linear(5, 7))
Expand Down
78 changes: 72 additions & 6 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import shutil
import pathlib
import platform
from collections import OrderedDict
from collections import namedtuple, OrderedDict
from copy import deepcopy
from itertools import product

Expand Down Expand Up @@ -804,6 +804,17 @@ def wrapper(*args, **kwargs):
def __exit__(self, *args, **kwargs):
torch.save = self.torch_save

Point = namedtuple('Point', ['x', 'y'])

class ClassThatUsesBuildInstruction:
    """Test helper whose reduction carries a state item.

    ``__reduce_ex__`` returns a three-item tuple (callable, args, state).
    """

    def __init__(self, num):
        self.num = num

    def __reduce_ex__(self, proto):
        # ``proto`` is not consulted: the same reduction is returned regardless.
        constructor_args = (self.num,)
        # Third item, state here will cause pickle to push a BUILD instruction
        state = {'foo': 'bar'}
        return ClassThatUsesBuildInstruction, constructor_args, state


@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
class TestBothSerialization(TestCase):
@parametrize("weights_only", (True, False))
Expand All @@ -826,7 +837,6 @@ def test(f_new, f_old):
test(f_new, f_old)
self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")


class TestOldSerialization(TestCase, SerializationMixin):
# unique_key is necessary because on Python 2.7, if a warning passed to
# the warning module is the same, it is not raised again.
Expand Down Expand Up @@ -854,7 +864,8 @@ def import_module(name, filename):
loaded = torch.load(checkpoint)
self.assertTrue(isinstance(loaded, module.Net))
if can_retrieve_source:
self.assertEqual(len(w), 0)
self.assertEqual(len(w), 1)
self.assertEqual(w[0].category, FutureWarning)

# Replace the module with different source
fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
Expand All @@ -865,8 +876,8 @@ def import_module(name, filename):
loaded = torch.load(checkpoint)
self.assertTrue(isinstance(loaded, module.Net))
if can_retrieve_source:
self.assertEqual(len(w), 1)
self.assertTrue(w[0].category, 'SourceChangeWarning')
self.assertEqual(len(w), 2)
self.assertTrue(w[1].category, 'SourceChangeWarning')

def test_serialization_container(self):
self._test_serialization_container('file', tempfile.NamedTemporaryFile)
Expand Down Expand Up @@ -1040,8 +1051,63 @@ def __reduce__(self):
self.assertIsNone(torch.load(f, weights_only=False))
f.seek(0)
# Safe load should assert
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL __builtin__.print"):
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL builtins.print"):
torch.load(f, weights_only=True)
try:
torch.serialization.add_safe_globals([print])
f.seek(0)
torch.load(f, weights_only=True)
finally:
torch.serialization.clear_safe_globals()

def test_weights_only_safe_globals_newobj(self):
    """Check that a namedtuple loads under ``weights_only=True`` only after allowlisting.

    The load is first expected to raise ``pickle.UnpicklingError``; after
    ``Point`` is added via ``add_safe_globals`` the same buffer must load and
    compare equal to the saved value. Safe globals are always cleared afterwards.
    """
    # This will use NEWOBJ
    point = Point(x=1, y=2)
    expected_error = "GLOBAL __main__.Point was not an allowed global by default"
    with BytesIOContext() as f:
        torch.save(point, f)
        f.seek(0)
        with self.assertRaisesRegex(pickle.UnpicklingError, expected_error):
            torch.load(f, weights_only=True)
        # Rewind so the second load reads the checkpoint from the start.
        f.seek(0)
        try:
            torch.serialization.add_safe_globals([Point])
            restored = torch.load(f, weights_only=True)
            self.assertEqual(restored, point)
        finally:
            # Do not leak the allowlist entry into other tests.
            torch.serialization.clear_safe_globals()

def test_weights_only_safe_globals_build(self):
    """Exercise both BUILD paths of the ``weights_only`` unpickler for an allowlisted class.

    1. Without allowlisting, loading must raise ``pickle.UnpicklingError``.
    2. With the class allowlisted and no ``__setstate__``, the saved state is
       applied to the instance (``foo`` becomes an attribute).
    3. With ``__setstate__`` patched to a counting stub, the stub is called
       once and the state is not applied (``foo`` is absent).
    """
    counter = 0

    def fake_set_state(obj, *args):
        # Counts invocations and deliberately ignores the state.
        nonlocal counter
        counter += 1

    c = ClassThatUsesBuildInstruction(2)
    with BytesIOContext() as f:
        torch.save(c, f)
        f.seek(0)
        with self.assertRaisesRegex(pickle.UnpicklingError,
                                    "GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"):
            torch.load(f, weights_only=True)
        try:
            torch.serialization.add_safe_globals([ClassThatUsesBuildInstruction])
            # Test dict update path
            f.seek(0)
            loaded_c = torch.load(f, weights_only=True)
            self.assertEqual(loaded_c.num, 2)
            self.assertEqual(loaded_c.foo, 'bar')
            # Test setstate path
            ClassThatUsesBuildInstruction.__setstate__ = fake_set_state
            f.seek(0)
            loaded_c = torch.load(f, weights_only=True)
            self.assertEqual(loaded_c.num, 2)
            self.assertEqual(counter, 1)
            self.assertFalse(hasattr(loaded_c, 'foo'))
        finally:
            torch.serialization.clear_safe_globals()
            # Remove the patched hook instead of assigning None: a leftover
            # ``__setstate__ = None`` still satisfies ``hasattr`` and would be
            # called (and fail) by any later load of this class in the same
            # process. The guard covers a failure before the patch was applied.
            if "__setstate__" in vars(ClassThatUsesBuildInstruction):
                del ClassThatUsesBuildInstruction.__setstate__

@parametrize('weights_only', (False, True))
def test_serialization_math_bits(self, weights_only):
Expand Down
53 changes: 47 additions & 6 deletions torch/_weights_only_unpickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# weights = torch.load(buf, weights_only = True)

import functools as _functools
import warnings
from collections import Counter, OrderedDict
from pickle import (
APPEND,
Expand Down Expand Up @@ -67,6 +68,16 @@
from sys import maxsize
from typing import Any, Dict, List

try:
# We rely on this module in private cPython which provides dicts of
# modules/functions that had their names changed from Python 2 to 3
has_compat_pickle = True
from _compat_pickle import IMPORT_MAPPING, NAME_MAPPING
except ImportError:
# To prevent warning on import torch, we warn in the Unpickler.load below
has_compat_pickle = False
IMPORT_MAPPING, NAME_MAPPING = dict(), dict()

import torch

_marked_safe_globals_list: List[Any] = []
Expand Down Expand Up @@ -97,7 +108,8 @@ def _clear_safe_globals():
def _get_user_allowed_globals():
    """Return a dict mapping ``"<module>.<name>"`` to each user-marked safe global.

    Keys are built from every entry's ``__module__`` and ``__name__``; if two
    entries share a key, the later one in ``_marked_safe_globals_list`` wins.
    """
    allowed: Dict[str, Any] = {
        f"{obj.__module__}.{obj.__name__}": obj
        for obj in _marked_safe_globals_list
    }
    return allowed


Expand Down Expand Up @@ -170,12 +182,20 @@ def __init__(self, file, *, encoding: str = "bytes"):
self.readline = file.readline
self.read = file.read
self.memo: Dict[int, Any] = {}
self.proto: int = -1

def load(self):
"""Read a pickled object representation from the open file.

Return the reconstituted object hierarchy specified in the file.
"""
if not has_compat_pickle:
warnings.warn(
"Could not import IMPORT_MAPPING and NAME_MAPPING from _compat_pickle. "
"If the default `pickle_protocol` was used at `torch.save` time, any functions or "
"classes that are in these maps might not behave correctly if allowlisted via "
"`torch.serialization.add_safe_globals()`."
)
self.metastack = []
self.stack: List[Any] = []
self.append = self.stack.append
Expand All @@ -190,6 +210,13 @@ def load(self):
if key[0] == GLOBAL[0]:
module = readline()[:-1].decode("utf-8")
name = readline()[:-1].decode("utf-8")
# Patch since torch.save default protocol is 2
# users will be running this code in python > 3
if self.proto == 2 and has_compat_pickle:
if (module, name) in NAME_MAPPING:
module, name = NAME_MAPPING[(module, name)]
elif module in IMPORT_MAPPING:
module = IMPORT_MAPPING[module]
full_path = f"{module}.{name}"
if full_path in _get_allowed_globals():
self.append(_get_allowed_globals()[full_path])
Expand All @@ -204,9 +231,12 @@ def load(self):
elif key[0] == NEWOBJ[0]:
args = self.stack.pop()
cls = self.stack.pop()
if cls is not torch.nn.Parameter:
if cls is torch.nn.Parameter:
self.append(torch.nn.Parameter(*args))
elif cls in _get_user_allowed_globals().values():
self.append(cls.__new__(cls, *args))
else:
raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
self.append(torch.nn.Parameter(*args))
elif key[0] == REDUCE[0]:
args = self.stack.pop()
func = self.stack[-1]
Expand All @@ -228,9 +258,14 @@ def load(self):
inst.__setstate__(state)
elif type(inst) is OrderedDict:
inst.__dict__.update(state)
elif type(inst) in _get_user_allowed_globals().values():
if hasattr(inst, "__setstate__"):
inst.__setstate__(state)
else:
inst.__dict__.update(state)
else:
raise RuntimeError(
f"Can only build Tensor, parameter or dict objects, but got {type(inst)}"
f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}"
)
# Stack manipulation
elif key[0] == APPEND[0]:
Expand Down Expand Up @@ -334,8 +369,14 @@ def load(self):
self.append(decode_long(data))
# First and last deserializer ops
elif key[0] == PROTO[0]:
# Read and ignore proto version
read(1)[0]
self.proto = read(1)[0]
if self.proto != 2:
warnings.warn(
f"Detected pickle protocol {self.proto} in the checkpoint, which was "
"not the default pickle protocol used by `torch.load` (2). The weights_only "
"Unpickler might not support all instructions implemented by this protocol, "
"please file an issue for adding support if you encounter this."
)
elif key[0] == STOP[0]:
rc = self.stack.pop()
return rc
Expand Down
45 changes: 41 additions & 4 deletions torch/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,30 @@
return _weights_only_unpickler._get_safe_globals()

def add_safe_globals(safe_globals: List[Any]) -> None:
'''
Marks the given globals as safe for ``weights_only`` load.
"""
Marks the given globals as safe for ``weights_only`` load. For example, functions
added to this list can be called during unpickling, classes could be instantiated
and have state set.

Args:
safe_globals (List[Any]): list of globals to mark as safe
'''

Example:
>>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
>>> import tempfile
>>> class MyTensor(torch.Tensor):
... pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
... torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
... torch.serialization.add_safe_globals([MyTensor])
... torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
# [-0.8234, 2.0500, -0.3657]])
"""
_weights_only_unpickler._add_safe_globals(safe_globals)

def _is_zipfile(f) -> bool:
Expand Down Expand Up @@ -867,12 +885,12 @@
zip_file.write_record(name, storage, num_bytes)


def load(

Check warning on line 888 in torch/serialization.py

View workflow job for this annotation

GitHub Actions / bc_linter

Function load: weights_only changed from bool to Optional[bool]
f: FILE_LIKE,
map_location: MAP_LOCATION = None,
pickle_module: Any = None,
*,
weights_only: bool = False,
weights_only: Optional[bool] = None,
mmap: Optional[bool] = None,
**pickle_load_args: Any
) -> Any:
Expand Down Expand Up @@ -982,6 +1000,11 @@
" with `weights_only` please check the recommended steps in the following error message."
" WeightsUnpickler error: "
)
if weights_only is None:
weights_only, warn_weights_only = False, True
else:
warn_weights_only = False

# Add ability to force safe only weight loads via environment variable
if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
weights_only = True
Expand All @@ -991,6 +1014,20 @@
raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
else:
if pickle_module is None:
if warn_weights_only:
warnings.warn(
"You are using `torch.load` with `weights_only=False` (the current default value), which uses "
"the default pickle module implicitly. It is possible to construct malicious pickle data "
"which will execute arbitrary code during unpickling (See "
"https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). "
"In a future release, the default value for `weights_only` will be flipped to `True`. This "
"limits the functions that could be executed during unpickling. Arbitrary objects will no "
"longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the "
"user via `torch.serialization.add_safe_globals`. We recommend you start setting "
"`weights_only=True` for any use case where you don't have full control of the loaded file. "
"Please open an issue on GitHub for any issues related to this experimental feature.",
FutureWarning,
)
pickle_module = pickle

# make flipping default BC-compatible
Expand Down
Loading
0