diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py index 17a6aebd8f9361..e2758e3216d70c 100644 --- a/test/distributed/_tensor/test_dtensor.py +++ b/test/distributed/_tensor/test_dtensor.py @@ -536,6 +536,16 @@ def test_dtensor_save_load(self): buffer.seek(0) reloaded_st = torch.load(buffer) self.assertEqual(sharded_tensor, reloaded_st) + # Test weights_only load + try: + torch.serialization.add_safe_globals( + [DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta] + ) + buffer.seek(0) + reloaded_st = torch.load(buffer, weights_only=True) + self.assertEqual(sharded_tensor, reloaded_st) + finally: + torch.serialization.clear_safe_globals() class DTensorMeshTest(DTensorTestBase): diff --git a/test/test_nn.py b/test/test_nn.py index 6dfac4f7ca1bc2..ad54d4126b5840 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1801,26 +1801,35 @@ def test_parameterlistdict_pickle(self): m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)])) with warnings.catch_warnings(record=True) as w: m = pickle.loads(pickle.dumps(m)) - self.assertTrue(len(w) == 0) + # warning from torch.load call in _load_from_bytes + num_warnings = 2 if torch._dynamo.is_compiling() else 1 + self.assertTrue(len(w) == num_warnings) + self.assertEqual(w[0].category, FutureWarning) # Test whether loading from older checkpoints works without triggering warnings m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)])) del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set with warnings.catch_warnings(record=True) as w: m = pickle.loads(pickle.dumps(m)) - self.assertTrue(len(w) == 0) + # warning from torch.load call in _load_from_bytes + self.assertTrue(len(w) == 1) + self.assertEqual(w[0].category, FutureWarning) m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))}) with warnings.catch_warnings(record=True) as w: m = pickle.loads(pickle.dumps(m)) - self.assertTrue(len(w) == 0) + # warning from torch.load call in _load_from_bytes + self.assertTrue(len(w) == 1) + self.assertEqual(w[0].category, FutureWarning) # Test whether loading from older checkpoints works without triggering warnings m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))}) del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set with warnings.catch_warnings(record=True) as w: m = pickle.loads(pickle.dumps(m)) - self.assertTrue(len(w) == 0) + # warning from torch.load call in _load_from_bytes + self.assertTrue(len(w) == 1) + self.assertEqual(w[0].category, FutureWarning) def test_weight_norm_pickle(self): m = torch.nn.utils.weight_norm(nn.Linear(5, 7)) diff --git a/test/test_serialization.py b/test/test_serialization.py index f22331831c39d8..31136c63639e4a 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -15,7 +15,7 @@ import shutil import pathlib import platform -from collections import OrderedDict +from collections import namedtuple, OrderedDict from copy import deepcopy from itertools import product @@ -804,6 +804,17 @@ def wrapper(*args, **kwargs): def __exit__(self, *args, **kwargs): torch.save = self.torch_save +Point = namedtuple('Point', ['x', 'y']) + +class ClassThatUsesBuildInstruction: + def __init__(self, num): + self.num = num + + def __reduce_ex__(self, proto): + # Third item, state here will cause pickle to push a BUILD instruction + return ClassThatUsesBuildInstruction, (self.num,), {'foo': 'bar'} + + @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") class TestBothSerialization(TestCase): @parametrize("weights_only", (True, False)) @@ -826,7 +837,6 @@ def test(f_new, f_old): test(f_new, f_old) self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}") - class TestOldSerialization(TestCase, SerializationMixin): # unique_key is necessary because on Python 2.7, if a warning passed to # the warning module is the same, it is not raised again. @@ -854,7 +864,8 @@ def import_module(name, filename): loaded = torch.load(checkpoint) self.assertTrue(isinstance(loaded, module.Net)) if can_retrieve_source: - self.assertEqual(len(w), 0) + self.assertEqual(len(w), 1) + self.assertEqual(w[0].category, FutureWarning) # Replace the module with different source fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing', @@ -865,8 +876,8 @@ def import_module(name, filename): loaded = torch.load(checkpoint) self.assertTrue(isinstance(loaded, module.Net)) if can_retrieve_source: - self.assertEqual(len(w), 1) - self.assertTrue(w[0].category, 'SourceChangeWarning') + self.assertEqual(len(w), 2) + self.assertTrue(w[1].category, 'SourceChangeWarning') def test_serialization_container(self): self._test_serialization_container('file', tempfile.NamedTemporaryFile) @@ -1040,8 +1051,63 @@ def __reduce__(self): self.assertIsNone(torch.load(f, weights_only=False)) f.seek(0) # Safe load should assert - with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL __builtin__.print"): + with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL builtins.print"): + torch.load(f, weights_only=True) + try: + torch.serialization.add_safe_globals([print]) + f.seek(0) + torch.load(f, weights_only=True) + finally: + torch.serialization.clear_safe_globals() + + def test_weights_only_safe_globals_newobj(self): + # This will use NEWOBJ + p = Point(x=1, y=2) + with BytesIOContext() as f: + torch.save(p, f) + f.seek(0) + with self.assertRaisesRegex(pickle.UnpicklingError, + "GLOBAL __main__.Point was not an allowed global by default"): torch.load(f, weights_only=True) + f.seek(0) + try: + torch.serialization.add_safe_globals([Point]) + loaded_p = torch.load(f, weights_only=True) + self.assertEqual(loaded_p, p) + finally: + torch.serialization.clear_safe_globals() + + def test_weights_only_safe_globals_build(self): + counter = 0 + + def fake_set_state(obj, *args): + nonlocal counter + counter += 1 + + c = ClassThatUsesBuildInstruction(2) + with BytesIOContext() as f: + torch.save(c, f) + f.seek(0) + with self.assertRaisesRegex(pickle.UnpicklingError, + "GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"): + torch.load(f, weights_only=True) + try: + torch.serialization.add_safe_globals([ClassThatUsesBuildInstruction]) + # Test dict update path + f.seek(0) + loaded_c = torch.load(f, weights_only=True) + self.assertEqual(loaded_c.num, 2) + self.assertEqual(loaded_c.foo, 'bar') + # Test setstate path + ClassThatUsesBuildInstruction.__setstate__ = fake_set_state + f.seek(0) + loaded_c = torch.load(f, weights_only=True) + self.assertEqual(loaded_c.num, 2) + self.assertEqual(counter, 1) + self.assertFalse(hasattr(loaded_c, 'foo')) + finally: + torch.serialization.clear_safe_globals() + ClassThatUsesBuildInstruction.__setstate__ = None @parametrize('weights_only', (False, True)) def test_serialization_math_bits(self, weights_only): diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 2ca07d15136cd8..ba131d20478430 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -23,6 +23,7 @@ # weights = torch.load(buf, weights_only = True) import functools as _functools +import warnings from collections import Counter, OrderedDict from pickle import ( APPEND, @@ -67,6 +68,16 @@ from sys import maxsize from typing import Any, Dict, List +try: + # We rely on this module in private cPython which provides dicts of + # modules/functions that had their names changed from Python 2 to 3 + has_compat_pickle = True + from _compat_pickle import IMPORT_MAPPING, NAME_MAPPING +except ImportError: + # To prevent warning on import torch, we warn in the Unpickler.load below + has_compat_pickle = False + IMPORT_MAPPING, NAME_MAPPING = dict(), dict() + import torch _marked_safe_globals_list: List[Any] = [] @@ -97,7 +108,8 @@ def _clear_safe_globals(): def _get_user_allowed_globals(): rc: Dict[str, Any] = {} for f in _marked_safe_globals_list: - rc[f"{f.__module__}.{f.__name__}"] = f + module, name = f.__module__, f.__name__ + rc[f"{module}.{name}"] = f return rc @@ -170,12 +182,20 @@ def __init__(self, file, *, encoding: str = "bytes"): self.readline = file.readline self.read = file.read self.memo: Dict[int, Any] = {} + self.proto: int = -1 def load(self): """Read a pickled object representation from the open file. Return the reconstituted object hierarchy specified in the file. """ + if not has_compat_pickle: + warnings.warn( + "Could not import IMPORT_MAPPING and NAME_MAPPING from _compat_pickle. " + "If the default `pickle_protocol` was used at `torch.save` time, any functions or " + "classes that are in these maps might not behave correctly if allowlisted via " + "`torch.serialization.add_safe_globals()`." + ) self.metastack = [] self.stack: List[Any] = [] self.append = self.stack.append @@ -190,6 +210,13 @@ def load(self): if key[0] == GLOBAL[0]: module = readline()[:-1].decode("utf-8") name = readline()[:-1].decode("utf-8") + # Patch since torch.save default protocol is 2 + # users will be running this code in python > 3 + if self.proto == 2 and has_compat_pickle: + if (module, name) in NAME_MAPPING: + module, name = NAME_MAPPING[(module, name)] + elif module in IMPORT_MAPPING: + module = IMPORT_MAPPING[module] full_path = f"{module}.{name}" if full_path in _get_allowed_globals(): self.append(_get_allowed_globals()[full_path]) @@ -204,9 +231,12 @@ def load(self): elif key[0] == NEWOBJ[0]: args = self.stack.pop() cls = self.stack.pop() - if cls is not torch.nn.Parameter: + if cls is torch.nn.Parameter: + self.append(torch.nn.Parameter(*args)) + elif cls in _get_user_allowed_globals().values(): + self.append(cls.__new__(cls, *args)) + else: raise RuntimeError(f"Trying to instantiate unsupported class {cls}") - self.append(torch.nn.Parameter(*args)) elif key[0] == REDUCE[0]: args = self.stack.pop() func = self.stack[-1] @@ -228,9 +258,14 @@ def load(self): inst.__setstate__(state) elif type(inst) is OrderedDict: inst.__dict__.update(state) + elif type(inst) in _get_user_allowed_globals().values(): + if hasattr(inst, "__setstate__"): + inst.__setstate__(state) + else: + inst.__dict__.update(state) else: raise RuntimeError( - f"Can only build Tensor, parameter or dict objects, but got {type(inst)}" + f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}" ) # Stack manipulation elif key[0] == APPEND[0]: @@ -334,8 +369,14 @@ def load(self): self.append(decode_long(data)) # First and last deserializer ops elif key[0] == PROTO[0]: - # Read and ignore proto version - read(1)[0] + self.proto = read(1)[0] + if self.proto != 2: + warnings.warn( + f"Detected pickle protocol {self.proto} in the checkpoint, which was " + "not the default pickle protocol used by `torch.load` (2). The weights_only " + "Unpickler might not support all instructions implemented by this protocol, " + "please file an issue for adding support if you encounter this." + ) elif key[0] == STOP[0]: rc = self.stack.pop() return rc diff --git a/torch/serialization.py b/torch/serialization.py index 1cab9b92c5501d..1163e8efc43461 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -165,12 +165,30 @@ def get_safe_globals() -> List[Any]: return _weights_only_unpickler._get_safe_globals() def add_safe_globals(safe_globals: List[Any]) -> None: - ''' - Marks the given globals as safe for ``weights_only`` load. + """ + Marks the given globals as safe for ``weights_only`` load. For example, functions + added to this list can be called during unpickling, classes could be instantiated + and have state set. Args: safe_globals (List[Any]): list of globals to mark as safe - ''' + + Example: + >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization") + >>> import tempfile + >>> class MyTensor(torch.Tensor): + ... pass + >>> t = MyTensor(torch.randn(2, 3)) + >>> with tempfile.NamedTemporaryFile() as f: + ... torch.save(t, f.name) + # Running `torch.load(f.name, weights_only=True)` will fail with + # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. + # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. + ... torch.serialization.add_safe_globals([MyTensor]) + ... torch.load(f.name, weights_only=True) + # MyTensor([[-0.5024, -1.8152, -0.5455], + # [-0.8234, 2.0500, -0.3657]]) + """ _weights_only_unpickler._add_safe_globals(safe_globals) def _is_zipfile(f) -> bool: @@ -872,7 +890,7 @@ def load( map_location: MAP_LOCATION = None, pickle_module: Any = None, *, - weights_only: bool = False, + weights_only: Optional[bool] = None, mmap: Optional[bool] = None, **pickle_load_args: Any ) -> Any: @@ -982,6 +1000,11 @@ def load( " with `weights_only` please check the recommended steps in the following error message." " WeightsUnpickler error: " ) + if weights_only is None: + weights_only, warn_weights_only = False, True + else: + warn_weights_only = False + # Add ability to force safe only weight loads via environment variable if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']: weights_only = True @@ -991,6 +1014,20 @@ def load( raise RuntimeError("Can not safely load weights when explicit pickle_module is specified") else: if pickle_module is None: + if warn_weights_only: + warnings.warn( + "You are using `torch.load` with `weights_only=False` (the current default value), which uses " + "the default pickle module implicitly. It is possible to construct malicious pickle data " + "which will execute arbitrary code during unpickling (See " + "https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). " + "In a future release, the default value for `weights_only` will be flipped to `True`. This " + "limits the functions that could be executed during unpickling. Arbitrary objects will no " + "longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the " + "user via `torch.serialization.add_safe_globals`. We recommend you start setting " + "`weights_only=True` for any use case where you don't have full control of the loaded file. " + "Please open an issue on GitHub for any issues related to this experimental feature.", + FutureWarning, + ) pickle_module = pickle # make flipping default BC-compatible