5
5
import io
6
6
import os
7
7
import pickle
8
+ import re
8
9
import shutil
9
10
import struct
10
11
import sys
@@ -1107,12 +1108,33 @@ def load(
1107
1108
"""
1108
1109
torch ._C ._log_api_usage_once ("torch.load" )
1109
1110
UNSAFE_MESSAGE = (
1110
- "Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`"
1111
- " will likely succeed, but it can result in arbitrary code execution."
1112
- " Do it only if you get the file from a trusted source. Alternatively, to load"
1113
- " with `weights_only` please check the recommended steps in the following error message."
1114
- " WeightsUnpickler error: "
1111
+ "Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
1112
+ "but it can result in arbitrary code execution. Do it only if you got the file from a "
1113
+ "trusted source."
1115
1114
)
1115
+ DOCS_MESSAGE = (
1116
+ "\n \n Check the documentation of torch.load to learn more about types accepted by default with "
1117
+ "weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
1118
+ )
1119
+
1120
+ def _get_wo_message (message : str ) -> str :
1121
+ pattern = r"GLOBAL (\S+) was not an allowed global by default."
1122
+ has_unsafe_global = re .search (pattern , message ) is not None
1123
+ if has_unsafe_global :
1124
+ updated_message = (
1125
+ "Weights only load failed. This file can still be loaded, to do so you have two options "
1126
+ f"\n \t (1) { UNSAFE_MESSAGE } \n \t (2) Alternatively, to load with `weights_only=True` please check "
1127
+ "the recommended steps in the following error message.\n \t WeightsUnpickler error: "
1128
+ + message
1129
+ )
1130
+ else :
1131
+ updated_message = (
1132
+ f"Weights only load failed. { UNSAFE_MESSAGE } \n Please file an issue with the following "
1133
+ "so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler "
1134
+ "error: " + message
1135
+ )
1136
+ return updated_message + DOCS_MESSAGE
1137
+
1116
1138
if weights_only is None :
1117
1139
weights_only , warn_weights_only = False , True
1118
1140
else :
@@ -1200,7 +1222,7 @@ def load(
1200
1222
** pickle_load_args ,
1201
1223
)
1202
1224
except RuntimeError as e :
1203
- raise pickle .UnpicklingError (UNSAFE_MESSAGE + str (e )) from None
1225
+ raise pickle .UnpicklingError (_get_wo_message ( str (e ) )) from None
1204
1226
return _load (
1205
1227
opened_zipfile ,
1206
1228
map_location ,
@@ -1224,7 +1246,7 @@ def load(
1224
1246
** pickle_load_args ,
1225
1247
)
1226
1248
except RuntimeError as e :
1227
- raise pickle .UnpicklingError (UNSAFE_MESSAGE + str (e )) from None
1249
+ raise pickle .UnpicklingError (_get_wo_message ( str (e ) )) from None
1228
1250
return _legacy_load (
1229
1251
opened_file , map_location , pickle_module , ** pickle_load_args
1230
1252
)
0 commit comments