8000 Remove dependency on private _compat_pickle in CPython · pytorch/pytorch@805f94e · GitHub
[go: up one dir, main page]

Skip to content

Commit 805f94e

Browse files
Remove dependency on private _compat_pickle in CPython
ghstack-source-id: 7d6ee40 Pull Request resolved: #129509
1 parent eba6f42 commit 805f94e

File tree

2 files changed

+47
-18
lines changed

2 files changed

+47
-18
lines changed

torch/_utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,3 +977,47 @@ def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
977977
logger.exception(
978978
"Exception in callback for %s registered with gpu trace", self.name
979979
)
980+
981+
982+
# IMPORT_MAPPING and NAME_MAPPING are adapted from https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py
983+
# for use in the weights_only Unpickler.
984+
985+
IMPORT_MAPPING = {
986+
"__builtin__": "builtins",
987+
"copy_reg": "copyreg",
988+
"Queue": "queue",
989+
"repr": "reprlib",
990+
"_abcoll": "collections.abc",
991+
# Non-mutual mappings.
992+
"UserDict": "collections",
993+
"UserList": "collections",
994+
"UserString": "collections",
995+
"whichdb": "dbm",
996+
"StringIO": "io",
997+
"cStringIO": "io",
998+
}
999+
1000+
1001+
# This contains rename rules that are easy to handle. We ignore the more
1002+
# complex stuff (e.g. mapping the names in the urllib and types modules).
1003+
# These rules should be run before import names are fixed.
1004+
NAME_MAPPING = {
1005+
("__builtin__", "xrange"): ("builtins", "range"),
1006+
("__builtin__", "reduce"): ("functools", "reduce"),
1007+
("__builtin__", "intern"): ("sys", "intern"),
1008+
("__builtin__", "unichr"): ("builtins", "chr"),
1009+
("__builtin__", "unicode"): ("builtins", "str"),
1010+
("__builtin__", "long"): ("builtins", "int"),
1011+
("itertools", "izip"): ("builtins", "zip"),
1012+
("itertools", "imap"): ("builtins", "map"),
1013+
("itertools", "ifilter"): ("builtins", "filter"),
1014+
("itertools", "ifilterfalse"): ("itertools", "filterfalse"),
1015+
("itertools", "izip_longest"): ("itertools", "zip_longest"),
1016+
("UserDict", "IterableUserDict"): ("collections", "UserDict"),
1017+
("UserList", "UserList"): ("collections", "UserList"),
1018+
("UserString", "UserString"): ("collections", "UserString"),
1019+
# Non-mutual mappings.
1020+
("__builtin__", "basestring"): ("builtins", "str"),
1021+
("exceptions", "StandardError"): ("builtins", "Exception"),
1022+
("UserDict", "UserDict"): ("collections", "UserDict"),
1023+
}

torch/_weights_only_unpickler.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,9 @@
6868
from sys import maxsize
6969
from typing import Any, Dict, List
7070

71-
try:
72-
# We rely on this module in private cPython which provides dicts of
73-
# modules/functions that had their names changed from Python 2 to 3
74-
has_compat_pickle = True
75-
from _compat_pickle import IMPORT_MAPPING, NAME_MAPPING
76-
except ImportError:
77-
# To prevent warning on import torch, we warn in the Unpickler.load below
78-
has_compat_pickle = False
79-
IMPORT_MAPPING, NAME_MAPPING = dict(), dict()
80-
8171
import torch
72+
from torch._utils import IMPORT_MAPPING, NAME_MAPPING
73+
8274

8375
_marked_safe_globals_list: List[Any] = []
8476

@@ -189,13 +181,6 @@ def load(self):
189181
190182
Return the reconstituted object hierarchy specified in the file.
191183
"""
192-
if not has_compat_pickle:
193-
warnings.warn(
194-
"Could not import IMPORT_MAPPING and NAME_MAPPING from _compat_pickle. "
195-
"If the default `pickle_protocol` was used at `torch.save` time, any functions or "
196-
"classes that are in these maps might not behave correctly if allowlisted via "
197-
"`torch.serialization.add_safe_globals()`."
198-
)
199184
self.metastack = []
200185
self.stack: List[Any] = []
201186
self.append = self.stack.append
@@ -212,7 +197,7 @@ def load(self):
212197
name = readline()[:-1].decode("utf-8")
213198
# Patch since torch.save default protocol is 2
214199
# users will be running this code in python > 3
215-
if self.proto == 2 and has_compat_pickle:
200+
if self.proto == 2:
216201
if (module, name) in NAME_MAPPING:
217202
module, name = NAME_MAPPING[(module, name)]
218203
elif module in IMPORT_MAPPING:

0 commit comments

Comments
 (0)
0