8000 Update · pytorch/pytorch@5c9cdb3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5c9cdb3

Browse files
Update
[ghstack-poisoned]
2 parents cd4dbc9 + ad610aa commit 5c9cdb3

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

torch/_weights_only_unpickler.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
# weights = torch.load(buf, weights_only = True)
2424

2525
import functools as _functools
26-
from _compat_pickle import IMPORT_MAPPING, NAME_MAPPING
26+
import warnings
2727
from collections import Counter, OrderedDict
2828
from pickle import (
2929
APPEND,
@@ -68,6 +68,16 @@
6868
from sys import maxsize
6969
from typing import Any, Dict, List
7070

71+
try:
72+
# We rely on this private CPython module, which provides dicts of
73+
# modules/functions that had their names changed from Python 2 to 3
74+
has_compat_pickle = True
75+
from _compat_pickle import IMPORT_MAPPING, NAME_MAPPING
76+
except ImportError:
77+
# To prevent warning on import torch, we warn in the Unpickler.load below
78+
has_compat_pickle = False
79+
IMPORT_MAPPING, NAME_MAPPING = dict(), dict()
80+
7181
import torch
7282

7383
_marked_safe_globals_list: List[Any] = []
@@ -179,6 +189,13 @@ def load(self):
179189
180190
Return the reconstituted object hierarchy specified in the file.
181191
"""
192+
if not has_compat_pickle:
193+
warnings.warn(
194+
"Could not import IMPORT_MAPPING and NAME_MAPPING from _compat_pickle. "
195+
"If the default `pickle_protocol` was used at `torch.save` time, any functions or "
196+
"classes that are in these maps might not behave correctly if allowlisted via "
197+
"`torch.serialization.add_safe_globals()`."
198+
)
182199
self.metastack = []
183200
self.stack: List[Any] = []
184201
self.append = self.stack.append
@@ -195,7 +212,7 @@ def load(self):
195212
name = readline()[:-1].decode("utf-8")
196213
# Patch since torch.save default protocol is 2
197214
# users will be running this code in Python 3 or later
198-
if self.proto == 2:
215+
if self.proto == 2 and has_compat_pickle:
199216
if (module, name) in NAME_MAPPING:
200217
module, name = NAME_MAPPING[(module, name)]
201218
elif module in IMPORT_MAPPING:
@@ -352,8 +369,14 @@ def load(self):
352369
self.append(decode_long(data))
353370
# First and last deserializer ops
354371
elif key[0] == PROTO[0]:
355-
# Read and ignore proto version
356372
self.proto = read(1)[0]
373+
if self.proto != 2:
374+
warnings.warn(
375+
f"Detected pickle protocol {self.proto} in the checkpoint, which was "
376+
"not the default pickle protocol used by `torch.load` (2). The weights_only "
377+
"Unpickler might not support all instructions implemented by this protocol, "
378+
"please file an issue for adding support if you encounter this."
379+
)
357380
elif key[0] == STOP[0]:
358381
rc = self.stack.pop()
359382
return rc

torch/serialization.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,9 @@ def get_safe_globals() -> List[Any]:
203203

204204
def add_safe_globals(safe_globals: List[Any]) -> None:
205205
"""
206-
Marks the given globals as safe for ``weights_only`` load.
206+
Marks the given globals as safe for ``weights_only`` load. For example, functions
207+
added to this list can be called during unpickling, and classes can be instantiated
208+
and have state set.
207209
208210
Args:
209211
safe_globals (List[Any]): list of globals to mark as safe

0 commit comments

Comments
 (0)
0