23
23
# weights = torch.load(buf, weights_only = True)
24
24
25
25
import functools as _functools
26
+ import warnings
26
27
from collections import Counter , OrderedDict
27
28
from pickle import (
28
29
APPEND ,
67
68
from sys import maxsize
68
69
from typing import Any , Dict , List
69
70
71
+ try :
72
+ # We rely on this module in private cPython which provides dicts of
73
+ # modules/functions that had their names changed from Python 2 to 3
74
+ has_compat_pickle = True
75
+ from _compat_pickle import IMPORT_MAPPING , NAME_MAPPING
76
+ except ImportError :
77
+ # To prevent warning on import torch, we warn in the Unpickler.load below
78
+ has_compat_pickle = False
79
+ IMPORT_MAPPING , NAME_MAPPING = dict (), dict ()
80
+
70
81
import torch
71
82
72
83
_marked_safe_globals_list : List [Any ] = []
@@ -97,7 +108,8 @@ def _clear_safe_globals():
97
108
def _get_user_allowed_globals ():
98
109
rc : Dict [str , Any ] = {}
99
110
for f in _marked_safe_globals_list :
100
- rc [f"{ f .__module__ } .{ f .__name__ } " ] = f
111
+ module , name = f .__module__ , f .__name__
112
+ rc [f"{ module } .{ name } " ] = f
101
113
return rc
102
114
103
115
@@ -170,12 +182,20 @@ def __init__(self, file, *, encoding: str = "bytes"):
170
182
self .readline = file .readline
171
183
self .read = file .read
172
184
self .memo : Dict [int , Any ] = {}
185
+ self .proto : int = - 1
173
186
174
187
def load (self ):
175
188
"""Read a pickled object representation from the open file.
176
189
177
190
Return the reconstituted object hierarchy specified in the file.
178
191
"""
192
+ if not has_compat_pickle :
193
+ warnings .warn (
194
+ "Could not import IMPORT_MAPPING and NAME_MAPPING from _compat_pickle. "
195
+ "If the default `pickle_protocol` was used at `torch.save` time, any functions or "
196
+ "classes that are in these maps might not behave correctly if allowlisted via "
197
+ "`torch.serialization.add_safe_globals()`."
198
+ )
179
199
self .metastack = []
180
200
self .stack : List [Any ] = []
181
201
self .append = self .stack .append
@@ -190,6 +210,13 @@ def load(self):
190
210
if key [0 ] == GLOBAL [0 ]:
191
211
module = readline ()[:- 1 ].decode ("utf-8" )
192
212
name = readline ()[:- 1 ].decode ("utf-8" )
213
+ # Patch since torch.save default protocol is 2
214
+ # users will be running this code in python > 3
215
+ if self .proto == 2 and has_compat_pickle :
216
+ if (module , name ) in NAME_MAPPING :
217
+ module , name = NAME_MAPPING [(module , name )]
218
+ elif module in IMPORT_MAPPING :
219
+ module = IMPORT_MAPPING [module ]
193
220
full_path = f"{ module } .{ name } "
194
221
if full_path in _get_allowed_globals ():
195
222
self .append (_get_allowed_globals ()[full_path ])
@@ -334,8 +361,14 @@ def load(self):
334
361
self .append (decode_long (data ))
335
362
# First and last deserializer ops
336
363
elif key [0 ] == PROTO [0 ]:
337
- # Read and ignore proto version
338
- read (1 )[0 ]
364
+ self .proto = read (1 )[0 ]
365
+ if self .proto != 2 :
366
+ warnings .warn (
367
+ f"Detected pickle protocol { self .proto } in the checkpoint, which was "
368
+ "not the default pickle protocol used by `torch.load` (2). The weights_only "
369
+ "Unpickler might not support all instructions implemented by this protocol, "
370
+ "please file an issue for adding support if you encounter this."
371
+ )
339
372
elif key [0 ] == STOP [0 ]:
340
373
rc = self .stack .pop ()
341
374
return rc
0 commit comments