Cherry pick #129244 #129251 #129509 by mikaylagawarecki · Pull Request #129574 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Cherry pick #129244 #129251 #129509 #129574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions test/distributed/_tensor/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,16 @@ def test_dtensor_save_load(self):
buffer.seek(0)
reloaded_st = torch.load(buffer)
self.assertEqual(sharded_tensor, reloaded_st)
# Test weights_only load
try:
torch.serialization.add_safe_globals(
[DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]
)
buffer.seek(0)
reloaded_st = torch.load(buffer, weights_only=True)
self.assertEqual(sharded_tensor, reloaded_st)
finally:
torch.serialization.clear_safe_globals()


class DTensorMeshTest(DTensorTestBase):
Expand Down
70 changes: 68 additions & 2 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import shutil
import pathlib
import platform
from collections import OrderedDict
from collections import namedtuple, OrderedDict
from copy import deepcopy
from itertools import product

Expand Down Expand Up @@ -804,6 +804,17 @@ def wrapper(*args, **kwargs):
def __exit__(self, *args, **kwargs):
torch.save = self.torch_save

# Plain namedtuple: namedtuples pickle via the NEWOBJ opcode, which the
# weights_only safe-globals tests below exercise.
Point = namedtuple('Point', ['x', 'y'])

class ClassThatUsesBuildInstruction:
    """Helper class whose pickling emits a BUILD instruction.

    ``__reduce_ex__`` returns a three-item tuple; the third item (the state
    dict) is what causes pickle to push a BUILD opcode after reconstruction.
    """

    def __init__(self, num):
        self.num = num

    def __reduce_ex__(self, proto):
        # (callable, args, state): the non-None state triggers BUILD.
        constructor = ClassThatUsesBuildInstruction
        state = {'foo': 'bar'}
        return constructor, (self.num,), state


@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
class TestBothSerialization(TestCase):
@parametrize("weights_only", (True, False))
Expand Down Expand Up @@ -1040,8 +1051,63 @@ def __reduce__(self):
self.assertIsNone(torch.load(f, weights_only=False))
f.seek(0)
# Safe load should assert
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL __builtin__.print"):
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL builtins.print"):
torch.load(f, weights_only=True)
try:
torch.serialization.add_safe_globals([print])
f.seek(0)
torch.load(f, weights_only=True)
finally:
torch.serialization.clear_safe_globals()

def test_weights_only_safe_globals_newobj(self):
    """Allowlisted classes reconstructed via NEWOBJ load under weights_only.

    Without allowlisting, ``torch.load(..., weights_only=True)`` must
    fail-closed with UnpicklingError; after ``add_safe_globals`` the same
    buffer must round-trip.
    """
    # namedtuples are pickled with the NEWOBJ opcode
    point = Point(x=1, y=2)
    with BytesIOContext() as f:
        torch.save(point, f)
        f.seek(0)
        expected = "GLOBAL __main__.Point was not an allowed global by default"
        with self.assertRaisesRegex(pickle.UnpicklingError, expected):
            torch.load(f, weights_only=True)
        f.seek(0)
        try:
            torch.serialization.add_safe_globals([Point])
            reloaded = torch.load(f, weights_only=True)
            self.assertEqual(reloaded, point)
        finally:
            # Never leak the allowlist into other tests.
            torch.serialization.clear_safe_globals()

def test_weights_only_safe_globals_build(self):
    """BUILD opcode handling for allowlisted classes under weights_only=True.

    ``ClassThatUsesBuildInstruction.__reduce_ex__`` returns (callable, args,
    state), so pickle emits BUILD; the weights_only unpickler must apply the
    state via ``__dict__.update`` when no ``__setstate__`` is defined, and
    via ``__setstate__`` when one is.
    """
    counter = 0

    def fake_set_state(obj, *args):
        # Records that state was routed through __setstate__ (and swallows
        # it, so 'foo' is never applied to the instance).
        nonlocal counter
        counter += 1

    c = ClassThatUsesBuildInstruction(2)
    with BytesIOContext() as f:
        torch.save(c, f)
        f.seek(0)
        # Fail-closed: the class is not an allowed global by default.
        with self.assertRaisesRegex(pickle.UnpicklingError,
                                    "GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"):
            torch.load(f, weights_only=True)
        try:
            torch.serialization.add_safe_globals([ClassThatUsesBuildInstruction])
            # Test dict update path (no __setstate__ defined)
            f.seek(0)
            loaded_c = torch.load(f, weights_only=True)
            self.assertEqual(loaded_c.num, 2)
            self.assertEqual(loaded_c.foo, 'bar')
            # Test setstate path
            ClassThatUsesBuildInstruction.__setstate__ = fake_set_state
            f.seek(0)
            loaded_c = torch.load(f, weights_only=True)
            self.assertEqual(loaded_c.num, 2)
            self.assertEqual(counter, 1)
            # fake_set_state discarded the state, so 'foo' was never set.
            self.assertFalse(hasattr(loaded_c, 'foo'))
        finally:
            torch.serialization.clear_safe_globals()
            # Delete the attribute instead of assigning None: leaving
            # __setstate__ = None makes hasattr(...) true while the attribute
            # is not callable, which would break any later unpickling of this
            # class in the same process.
            if "__setstate__" in ClassThatUsesBuildInstruction.__dict__:
                del ClassThatUsesBuildInstruction.__setstate__

@parametrize('weights_only', (False, True))
def test_serialization_math_bits(self, weights_only):
Expand Down
44 changes: 44 additions & 0 deletions torch/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,3 +962,47 @@ def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
logger.exception(
"Exception in callback for %s registered with gpu trace", self.name
)


# IMPORT_MAPPING and NAME_MAPPING are adapted from https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py
# for use in the weights_only Unpickler.

# Python 2 -> Python 3 module renames, applied to GLOBAL opcodes when the
# weights_only unpickler reads a protocol-2 pickle (torch.save's default).
# Content and insertion order mirror CPython's Lib/_compat_pickle.py.
IMPORT_MAPPING = {
    "__builtin__": "builtins",
    "copy_reg": "copyreg",
    "Queue": "queue",
    "repr": "reprlib",
    "_abcoll": "collections.abc",
    # Non-mutual mappings.
    "UserDict": "collections",
    "UserList": "collections",
    "UserString": "collections",
    "whichdb": "dbm",
    "StringIO": "io",
    "cStringIO": "io",
}


# This contains rename rules that are easy to handle. We ignore the more
# complex stuff (e.g. mapping the names in the urllib and types modules).
# These rules should be run before import names are fixed.
# Per-name Python 2 -> Python 3 renames. The unpickler consults this mapping
# before IMPORT_MAPPING, so a specific (module, name) pair takes precedence
# over a module-level rename.
NAME_MAPPING = {
    ("__builtin__", "xrange"): ("builtins", "range"),
    ("__builtin__", "reduce"): ("functools", "reduce"),
    ("__builtin__", "intern"): ("sys", "intern"),
    ("__builtin__", "unichr"): ("builtins", "chr"),
    ("__builtin__", "unicode"): ("builtins", "str"),
    ("__builtin__", "long"): ("builtins", "int"),
    # Python 2 itertools helpers that became builtins or were renamed.
    ("itertools", "izip"): ("builtins", "zip"),
    ("itertools", "imap"): ("builtins", "map"),
    ("itertools", "ifilter"): ("builtins", "filter"),
    ("itertools", "ifilterfalse"): ("itertools", "filterfalse"),
    ("itertools", "izip_longest"): ("itertools", "zip_longest"),
    ("UserDict", "IterableUserDict"): ("collections", "UserDict"),
    ("UserList", "UserList"): ("collections", "UserList"),
    ("UserString", "UserString"): ("collections", "UserString"),
    # Non-mutual mappings.
    ("__builtin__", "basestring"): ("builtins", "str"),
    ("exceptions", "StandardError"): ("builtins", "Exception"),
    ("UserDict", "UserDict"): ("collections", "UserDict"),
}
38 changes: 32 additions & 6 deletions torch/_weights_only_unpickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# weights = torch.load(buf, weights_only = True)

import functools as _functools
import warnings
from collections import Counter, OrderedDict
from pickle import (
APPEND,
Expand Down Expand Up @@ -68,6 +69,8 @@
from typing import Any, Dict, List

import torch
from torch._utils import IMPORT_MAPPING, NAME_MAPPING


_marked_safe_globals_list: List[Any] = []

Expand Down Expand Up @@ -97,7 +100,8 @@ def _clear_safe_globals():
def _get_user_allowed_globals():
rc: Dict[str, Any] = {}
for f in _marked_safe_globals_list:
rc[f"{f.__module__}.{f.__name__}"] = f
module, name = f.__module__, f.__name__
rc[f"{module}.{name}"] = f
return rc


Expand Down Expand Up @@ -170,6 +174,7 @@ def __init__(self, file, *, encoding: str = "bytes"):
self.readline = file.readline
self.read = file.read
self.memo: Dict[int, Any] = {}
self.proto: int = -1

def load(self):
"""Read a pickled object representation from the open file.
Expand All @@ -190,6 +195,13 @@ def load(self):
if key[0] == GLOBAL[0]:
module = readline()[:-1].decode("utf-8")
name = readline()[:-1].decode("utf-8")
# Patch since torch.save default protocol is 2
# users will be running this code in python > 3
if self.proto == 2:
if (module, name) in NAME_MAPPING:
module, name = NAME_MAPPING[(module, name)]
elif module in IMPORT_MAPPING:
module = IMPORT_MAPPING[module]
full_path = f"{module}.{name}"
if full_path in _get_allowed_globals():
self.append(_get_allowed_globals()[full_path])
Expand All @@ -204,9 +216,12 @@ def load(self):
elif key[0] == NEWOBJ[0]:
args = self.stack.pop()
cls = self.stack.pop()
if cls is not torch.nn.Parameter:
if cls is torch.nn.Parameter:
self.append(torch.nn.Parameter(*args))
elif cls in _get_user_allowed_globals().values():
self.append(cls.__new__(cls, *args))
else:
raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
self.append(torch.nn.Parameter(*args))
elif key[0] == REDUCE[0]:
args = self.stack.pop()
func = self.stack[-1]
Expand All @@ -228,9 +243,14 @@ def load(self):
inst.__setstate__(state)
elif type(inst) is OrderedDict:
inst.__dict__.update(state)
elif type(inst) in _get_user_allowed_globals().values():
if hasattr(inst, "__setstate__"):
inst.__setstate__(state)
else:
inst.__dict__.update(state)
else:
raise RuntimeError(
f"Can only build Tensor, parameter or dict objects, but got {type(inst)}"
f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}"
)
# Stack manipulation
elif key[0] == APPEND[0]:
Expand Down Expand Up @@ -334,8 +354,14 @@ def load(self):
self.append(decode_long(data))
# First and last deserializer ops
elif key[0] == PROTO[0]:
# Read and ignore proto version
read(1)[0]
self.proto = read(1)[0]
if self.proto != 2:
warnings.warn(
f"Detected pickle protocol {self.proto} in the checkpoint, which was "
"not the default pickle protocol used by `torch.load` (2). The weights_only "
"Unpickler might not support all instructions implemented by this protocol, "
"please file an issue for adding support if you encounter this."
)
elif key[0] == STOP[0]:
rc = self.stack.pop()
return rc
Expand Down
4 changes: 3 additions & 1 deletion torch/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ def get_safe_globals() -> List[Any]:

def add_safe_globals(safe_globals: List[Any]) -> None:
'''
Marks the given globals as safe for ``weights_only`` load.
Marks the given globals as safe for ``weights_only`` load. For example, functions
added to this list can be called during unpickling, classes could be instantiated
and have state set.

Args:
safe_globals (List[Any]): list of globals to mark as safe
Expand Down
Loading
0