23
23
# weights = torch.load(buf, weights_only = True)
24
24
25
25
import functools as _functools
26
- from _compat_pickle import IMPORT_MAPPING , NAME_MAPPING
26
+ import warnings
27
27
from collections import Counter , OrderedDict
28
28
from pickle import (
29
29
APPEND ,
68
68
from sys import maxsize
69
69
from typing import Any , Dict , List
70
70
71
+ try :
72
+ # We rely on this module in private cPython which provides dicts of
73
+ # modules/functions that had their names changed from Python 2 to 3
74
+ has_compat_pickle = True
75
+ from _compat_pickle import IMPORT_MAPPING , NAME_MAPPING
76
+ except ImportError :
77
+ # To prevent warning on import torch, we warn in the Unpickler.load below
78
+ has_compat_pickle = False
79
+ IMPORT_MAPPING , NAME_MAPPING = dict (), dict ()
80
+
71
81
import torch
72
82
73
83
_marked_safe_globals_list : List [Any ] = []
@@ -179,6 +189,13 @@ def load(self):
179
189
180
190
Return the reconstituted object hierarchy specified in the file.
181
191
"""
192
+ if not has_compat_pickle :
193
+ warnings .warn (
194
+ "Could not import IMPORT_MAPPING and NAME_MAPPING from _compat_pickle. "
195
+ "If the default `pickle_protocol` was used at `torch.save` time, any functions or "
196
+ "classes that are in these maps might not behave correctly if allowlisted via "
197
+ "`torch.serialization.add_safe_globals()`."
198
+ )
182
199
self .metastack = []
183
200
self .stack : List [Any ] = []
184
201
self .append = self .stack .append
@@ -195,7 +212,7 @@ def load(self):
195
212
name = readline ()[:- 1 ].decode ("utf-8" )
196
213
# Patch since torch.save default protocol is 2
197
214
# users will be running this code in python > 3
198
- if self .proto == 2 :
215
+ if self .proto == 2 and has_compat_pickle :
199
216
if (module , name ) in NAME_MAPPING :
200
217
module , name = NAME_MAPPING [(module , name )]
201
218
elif module in IMPORT_MAPPING :
@@ -352,8 +369,14 @@ def load(self):
352
369
self .append (decode_long (data ))
353
370
# First and last deserializer ops
354
371
elif key [0 ] == PROTO [0 ]:
355
- # Read and ignore proto version
356
372
self .proto = read (1 )[0 ]
373
+ if self .proto != 2 :
374
+ warnings .warn (
375
+ f"Detected pickle protocol { self .proto } in the checkpoint, which was "
376
+ "not the default pickle protocol used by `torch.load` (2). The weights_only "
377
+ "Unpickler might not support all instructions implemented by this protocol, "
378
+ "please file an issue for adding support if you encounter this."
379
+ )
357
380
elif key [0 ] == STOP [0 ]:
358
381
rc = self .stack .pop ()
359
382
return rc
0 commit comments