From f0bde90e073fdfd132bfc1c1a2ed0221ef276b9b Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Wed, 26 Jun 2024 10:55:04 -0400 Subject: [PATCH] Revert "Cherry pick #129244, #129251, #129239, 129396 into release/2.4 (#129478)" This reverts commit 22a4d46e2b4d5404e7df374e8ecb21026feb373e. --- test/distributed/_tensor/test_dtensor.py | 10 --- test/test_nn.py | 17 ++---- test/test_serialization.py | 78 ++---------------------- torch/_weights_only_unpickler.py | 53 ++-------------- torch/serialization.py | 45 ++------------ 5 files changed, 20 insertions(+), 183 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py index e2758e3216d70c..17a6aebd8f9361 100644 --- a/test/distributed/_tensor/test_dtensor.py +++ b/test/distributed/_tensor/test_dtensor.py @@ -536,16 +536,6 @@ def test_dtensor_save_load(self): buffer.seek(0) reloaded_st = torch.load(buffer) self.assertEqual(sharded_tensor, reloaded_st) - # Test weights_only load - try: - torch.serialization.add_safe_globals( - [DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta] - ) - buffer.seek(0) - reloaded_st = torch.load(buffer, weights_only=True) - self.assertEqual(sharded_tensor, reloaded_st) - finally: - torch.serialization.clear_safe_globals() class DTensorMeshTest(DTensorTestBase): diff --git a/test/test_nn.py b/test/test_nn.py index ad54d4126b5840..6dfac4f7ca1bc2 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1801,35 +1801,26 @@ def test_parameterlistdict_pickle(self): m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)])) with warnings.catch_warnings(record=True) as w: m = pickle.loads(pickle.dumps(m)) - # warning from torch.load call in _load_from_bytes - num_warnings = 2 if torch._dynamo.is_compiling() else 1 - self.assertTrue(len(w) == num_warnings) - self.assertEqual(w[0].category, FutureWarning) + self.assertTrue(len(w) == 0) # Test whether loading from older checkpoints works without triggering warnings m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)])) del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set with warnings.catch_warnings(record=True) as w: m = pickle.loads(pickle.dumps(m)) - # warning from torch.load call in _load_from_bytes - self.assertTrue(len(w) == 1) - self.assertEqual(w[0].category, FutureWarning) + self.assertTrue(len(w) == 0) m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))}) with warnings.catch_warnings(record=True) as w: m = pickle.loads(pickle.dumps(m)) - # warning from torch.load call in _load_from_bytes - self.assertTrue(len(w) == 1) - self.assertEqual(w[0].category, FutureWarning) + self.assertTrue(len(w) == 0) # Test whether loading from older checkpoints works without triggering warnings m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))}) del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set with warnings.catch_warnings(record=True) as w: m = pickle.loads(pickle.dumps(m)) - # warning from torch.load call in _load_from_bytes - self.assertTrue(len(w) == 1) - self.assertEqual(w[0].category, FutureWarning) + self.assertTrue(len(w) == 0) def test_weight_norm_pickle(self): m = torch.nn.utils.weight_norm(nn.Linear(5, 7)) diff --git a/test/test_serialization.py b/test/test_serialization.py index 31136c63639e4a..f22331831c39d8 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -15,7 +15,7 @@ import shutil import pathlib import platform -from collections import namedtuple, OrderedDict +from collections import OrderedDict from copy import deepcopy from itertools import product @@ -804,17 +804,6 @@ def wrapper(*args, **kwargs): def __exit__(self, *args, **kwargs): torch.save = self.torch_save -Point = namedtuple('Point', ['x', 'y']) - -class ClassThatUsesBuildInstruction: - def __init__(self, num): - self.num = num - - def __reduce_ex__(self, proto): - # Third item, state here will cause pickle to push a BUILD instruction - return ClassThatUsesBuildInstruction, (self.num,), {'foo': 'bar'} - - @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") class TestBothSerialization(TestCase): @parametrize("weights_only", (True, False)) @@ -837,6 +826,7 @@ def test(f_new, f_old): test(f_new, f_old) self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}") + class TestOldSerialization(TestCase, SerializationMixin): # unique_key is necessary because on Python 2.7, if a warning passed to # the warning module is the same, it is not raised again. @@ -864,8 +854,7 @@ def import_module(name, filename): loaded = torch.load(checkpoint) self.assertTrue(isinstance(loaded, module.Net)) if can_retrieve_source: - self.assertEqual(len(w), 1) - self.assertEqual(w[0].category, FutureWarning) + self.assertEqual(len(w), 0) # Replace the module with different source fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing', @@ -876,8 +865,8 @@ def import_module(name, filename): loaded = torch.load(checkpoint) self.assertTrue(isinstance(loaded, module.Net)) if can_retrieve_source: - self.assertEqual(len(w), 2) - self.assertTrue(w[1].category, 'SourceChangeWarning') + self.assertEqual(len(w), 1) + self.assertTrue(w[0].category, 'SourceChangeWarning') def test_serialization_container(self): self._test_serialization_container('file', tempfile.NamedTemporaryFile) @@ -1051,63 +1040,8 @@ def __reduce__(self): self.assertIsNone(torch.load(f, weights_only=False)) f.seek(0) # Safe load should assert - with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL builtins.print"): - torch.load(f, weights_only=True) - try: - torch.serialization.add_safe_globals([print]) - f.seek(0) - torch.load(f, weights_only=True) - finally: - torch.serialization.clear_safe_globals() - - def test_weights_only_safe_globals_newobj(self): - # This will use NEWOBJ - p = Point(x=1, y=2) - with BytesIOContext() as f: - torch.save(p, f) - f.seek(0) - with self.assertRaisesRegex(pickle.UnpicklingError, - "GLOBAL __main__.Point was not an allowed global by default"): + with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL __builtin__.print"): torch.load(f, weights_only=True) - f.seek(0) - try: - torch.serialization.add_safe_globals([Point]) - loaded_p = torch.load(f, weights_only=True) - self.assertEqual(loaded_p, p) - finally: - torch.serialization.clear_safe_globals() - - def test_weights_only_safe_globals_build(self): - counter = 0 - - def fake_set_state(obj, *args): - nonlocal counter - counter += 1 - - c = ClassThatUsesBuildInstruction(2) - with BytesIOContext() as f: - torch.save(c, f) - f.seek(0) - with self.assertRaisesRegex(pickle.UnpicklingError, - "GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"): - torch.load(f, weights_only=True) - try: - torch.serialization.add_safe_globals([ClassThatUsesBuildInstruction]) - # Test dict update path - f.seek(0) - loaded_c = torch.load(f, weights_only=True) - self.assertEqual(loaded_c.num, 2) - self.assertEqual(loaded_c.foo, 'bar') - # Test setstate path - ClassThatUsesBuildInstruction.__setstate__ = fake_set_state - f.seek(0) - loaded_c = torch.load(f, weights_only=True) - self.assertEqual(loaded_c.num, 2) - self.assertEqual(counter, 1) - self.assertFalse(hasattr(loaded_c, 'foo')) - finally: - torch.serialization.clear_safe_globals() - ClassThatUsesBuildInstruction.__setstate__ = None @parametrize('weights_only', (False, True)) def test_serialization_math_bits(self, weights_only): diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index ba131d20478430..2ca07d15136cd8 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -23,7 +23,6 @@ # weights = torch.load(buf, weights_only = True) import functools as _functools -import warnings from collections import Counter, OrderedDict from pickle import ( APPEND, @@ -68,16 +67,6 @@ from sys import maxsize from typing import Any, Dict, List -try: - # We rely on this module in private cPython which provides dicts of - # modules/functions that had their names changed from Python 2 to 3 - has_compat_pickle = True - from _compat_pickle import IMPORT_MAPPING, NAME_MAPPING -except ImportError: - # To prevent warning on import torch, we warn in the Unpickler.load below - has_compat_pickle = False - IMPORT_MAPPING, NAME_MAPPING = dict(), dict() - import torch _marked_safe_globals_list: List[Any] = [] @@ -108,8 +97,7 @@ def _clear_safe_globals(): def _get_user_allowed_globals(): rc: Dict[str, Any] = {} for f in _marked_safe_globals_list: - module, name = f.__module__, f.__name__ - rc[f"{module}.{name}"] = f + rc[f"{f.__module__}.{f.__name__}"] = f return rc @@ -182,20 +170,12 @@ def __init__(self, file, *, encoding: str = "bytes"): self.readline = file.readline self.read = file.read self.memo: Dict[int, Any] = {} - self.proto: int = -1 def load(self): """Read a pickled object representation from the open file. Return the reconstituted object hierarchy specified in the file. """ - if not has_compat_pickle: - warnings.warn( - "Could not import IMPORT_MAPPING and NAME_MAPPING from _compat_pickle. " - "If the default `pickle_protocol` was used at `torch.save` time, any functions or " - "classes that are in these maps might not behave correctly if allowlisted via " - "`torch.serialization.add_safe_globals()`." - ) self.metastack = [] self.stack: List[Any] = [] self.append = self.stack.append @@ -210,13 +190,6 @@ def load(self): if key[0] == GLOBAL[0]: module = readline()[:-1].decode("utf-8") name = readline()[:-1].decode("utf-8") - # Patch since torch.save default protocol is 2 - # users will be running this code in python > 3 - if self.proto == 2 and has_compat_pickle: - if (module, name) in NAME_MAPPING: - module, name = NAME_MAPPING[(module, name)] - elif module in IMPORT_MAPPING: - module = IMPORT_MAPPING[module] full_path = f"{module}.{name}" if full_path in _get_allowed_globals(): self.append(_get_allowed_globals()[full_path]) @@ -231,12 +204,9 @@ def load(self): elif key[0] == NEWOBJ[0]: args = self.stack.pop() cls = self.stack.pop() - if cls is torch.nn.Parameter: - self.append(torch.nn.Parameter(*args)) - elif cls in _get_user_allowed_globals().values(): - self.append(cls.__new__(cls, *args)) - else: + if cls is not torch.nn.Parameter: raise RuntimeError(f"Trying to instantiate unsupported class {cls}") + self.append(torch.nn.Parameter(*args)) elif key[0] == REDUCE[0]: args = self.stack.pop() func = self.stack[-1] @@ -258,14 +228,9 @@ def load(self): inst.__setstate__(state) elif type(inst) is OrderedDict: inst.__dict__.update(state) - elif type(inst) in _get_user_allowed_globals().values(): - if hasattr(inst, "__setstate__"): - inst.__setstate__(state) - else: - inst.__dict__.update(state) else: raise RuntimeError( - f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}" + f"Can only build Tensor, parameter or dict objects, but got {type(inst)}" ) # Stack manipulation elif key[0] == APPEND[0]: @@ -369,14 +334,8 @@ def load(self): self.append(decode_long(data)) # First and last deserializer ops elif key[0] == PROTO[0]: - self.proto = read(1)[0] - if self.proto != 2: - warnings.warn( - f"Detected pickle protocol {self.proto} in the checkpoint, which was " - "not the default pickle protocol used by `torch.load` (2). The weights_only " - "Unpickler might not support all instructions implemented by this protocol, " - "please file an issue for adding support if you encounter this." - ) + # Read and ignore proto version + read(1)[0] elif key[0] == STOP[0]: rc = self.stack.pop() return rc diff --git a/torch/serialization.py b/torch/serialization.py index 1163e8efc43461..1cab9b92c5501d 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -165,30 +165,12 @@ def get_safe_globals() -> List[Any]: return _weights_only_unpickler._get_safe_globals() def add_safe_globals(safe_globals: List[Any]) -> None: - """ - Marks the given globals as safe for ``weights_only`` load. For example, functions - added to this list can be called during unpickling, classes could be instantiated - and have state set. + ''' + Marks the given globals as safe for ``weights_only`` load. Args: safe_globals (List[Any]): list of globals to mark as safe - - Example: - >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization") - >>> import tempfile - >>> class MyTensor(torch.Tensor): - ... pass - >>> t = MyTensor(torch.randn(2, 3)) - >>> with tempfile.NamedTemporaryFile() as f: - ... torch.save(t, f.name) - # Running `torch.load(f.name, weights_only=True)` will fail with - # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. - # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. - ... torch.serialization.add_safe_globals([MyTensor]) - ... torch.load(f.name, weights_only=True) - # MyTensor([[-0.5024, -1.8152, -0.5455], - # [-0.8234, 2.0500, -0.3657]]) - """ + ''' _weights_only_unpickler._add_safe_globals(safe_globals) def _is_zipfile(f) -> bool: @@ -890,7 +872,7 @@ def load( map_location: MAP_LOCATION = None, pickle_module: Any = None, *, - weights_only: Optional[bool] = None, + weights_only: bool = False, mmap: Optional[bool] = None, **pickle_load_args: Any ) -> Any: @@ -1000,11 +982,6 @@ def load( " with `weights_only` please check the recommended steps in the following error message." " WeightsUnpickler error: " ) - if weights_only is None: - weights_only, warn_weights_only = False, True - else: - warn_weights_only = False - # Add ability to force safe only weight loads via environment variable if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']: weights_only = True @@ -1014,20 +991,6 @@ def load( raise RuntimeError("Can not safely load weights when explicit pickle_module is specified") else: if pickle_module is None: - if warn_weights_only: - warnings.warn( - "You are using `torch.load` with `weights_only=False` (the current default value), which uses " - "the default pickle module implicitly. It is possible to construct malicious pickle data " - "which will execute arbitrary code during unpickling (See " - "https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). " - "In a future release, the default value for `weights_only` will be flipped to `True`. This " - "limits the functions that could be executed during unpickling. Arbitrary objects will no " - "longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the " - "user via `torch.serialization.add_safe_globals`. We recommend you start setting " - "`weights_only=True` for any use case where you don't have full control of the loaded file. " - "Please open an issue on GitHub for any issues related to this experimental feature.", - FutureWarning, - ) pickle_module = pickle # make flipping default BC-compatible