8000 Remove dependency on private _compat_pickle in CPython by mikaylagawarecki · Pull Request #129509 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Remove dependency on private _compat_pickle in CPython #129509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions torch/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,3 +977,47 @@ def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
logger.exception(
"Exception in callback for %s registered with gpu trace", self.name
)


# IMPORT_MAPPING and NAME_MAPPING are adapted from https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py
# for use in the weights_only Unpickler.

IMPORT_MAPPING = {
"__builtin__": "builtins",
"copy_reg": "copyreg",
"Queue": "queue",
"repr": "reprlib",
"_abcoll": "collections.abc",
# Non-mutual mappings.
"UserDict": "collections",
"UserList": "collections",
"UserString": "collections",
"whichdb": "dbm",
"StringIO": "io",
"cStringIO": "io",
}


# This contains rename rules that are easy to handle. We ignore the more
# complex stuff (e.g. mapping the names in the urllib and types modules).
# These rules should be run before import names are fixed.
NAME_MAPPING = {
("__builtin__", "xrange"): ("builtins", "range"),
("__builtin__", "reduce"): ("functools", "reduce"),
("__builtin__", "intern"): ("sys", "intern"),
("__builtin__", "unichr"): ("builtins", "chr"),
("__builtin__", "unicode"): ("builtins", "str"),
("__builtin__", "long"): ("builtins", "int"),
("itertools", "izip"): ("builtins", "zip"),
("itertools", "imap"): ("builtins", "map"),
("itertools", "ifilter"): ("builtins", "filter"),
("itertools", "ifilterfalse"): ("itertools", "filterfalse"),
("itertools", "izip_longest"): ("itertools", "zip_longest"),
("UserDict", "IterableUserDict"): ("collections", "UserDict"),
("UserList", "UserList"): ("collections", "UserList"),
("UserString", "UserString"): ("collections", "UserString"),
# Non-mutual mappings.
("__builtin__", "basestring"): ("builtins", "str"),
("exceptions", "StandardError"): ("builtins", "Exception"),
("UserDict", "UserDict"): ("collections", "UserDict"),
}
21 changes: 3 additions & 18 deletions torch/_weights_only_unpickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,9 @@
from sys import maxsize
from typing import Any, Dict, List

try:
# We rely on this module in private cPython which provides dicts of
# modules/functions that had their names changed from Python 2 to 3
has_compat_pickle = True
from _compat_pickle import IMPORT_MAPPING, NAME_MAPPING
except ImportError:
# To prevent warning on import torch, we warn in the Unpickler.load below
has_compat_pickle = False
IMPORT_MAPPING, NAME_MAPPING = dict(), dict()

import torch
from torch._utils import IMPORT_MAPPING, NAME_MAPPING


_marked_safe_globals_list: List[Any] = []

Expand Down Expand Up @@ -189,13 +181,6 @@ def load(self):

Return the reconstituted object hierarchy specified in the file.
"""
if not has_compat_pickle:
warnings.warn(
"Could not import IMPORT_MAPPING and NAME_MAPPING from _compat_pickle. "
"If the default `pickle_protocol` was used at `torch.save` time, any functions or "
"classes that are in these maps might not behave correctly if allowlisted via "
"`torch.serialization.add_safe_globals()`."
)
self.metastack = []
self.stack: List[Any] = []
self.append = self.stack.append
Expand All @@ -212,7 +197,7 @@ def load(self):
name = readline()[:-1].decode("utf-8")
# Patch since torch.save default protocol is 2
# users will be running this code in python > 3
if self.proto == 2 and has_compat_pickle:
if self.proto == 2:
if (module, name) in NAME_MAPPING:
module, name = NAME_MAPPING[(module, name)]
elif module in IMPORT_MAPPING:
Expand Down
Loading
0