8000 Improve error message for weights_only load · pytorch/pytorch@985126e · GitHub
[go: up one dir, main page]

Skip to content

Commit 985126e

Browse files
Improve error message for weights_only load
ghstack-source-id: 51e7424 Pull Request resolved: #129705
1 parent 805f94e commit 985126e

File tree

3 files changed

+47
-9
lines changed

3 files changed

+47
-9
lines changed

test/test_serialization.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,6 +1112,22 @@ def fake_set_state(obj, *args):
11121112
torch.serialization.clear_safe_globals()
11131113
ClassThatUsesBuildInstruction.__setstate__ = None
11141114

1115+
@parametrize("unsafe_global", [True, False])
1116+
def test_weights_only_error(self, unsafe_global):
1117+
sd = {'t': TwoTensor(torch.randn(2), torch.randn(2))}
1118+
pickle_protocol = torch.serialization.DEFAULT_PROTOCOL if unsafe_global else 5
1119+
with BytesIOContext() as f:
1120+
torch.save(sd, f, pickle_protocol=pickle_protocol)
1121+
f.seek(0)
1122+
if unsafe_global:
1123+
with self.assertRaisesRegex(pickle.UnpicklingError,
1124+
r"use `torch.serialization.add_safe_globals\(\[TwoTensor\]\)` to allowlist"):
1125+
torch.load(f, weights_only=True)
1126+
else:
1127+
with self.assertRaisesRegex(pickle.UnpicklingError,
1128+
"file an issue with the following so that we can make `weights_only=True`"):
1129+
torch.load(f, weights_only=True)
1130+
11151131
@parametrize('weights_only', (False, True))
11161132
def test_serialization_math_bits(self, weights_only):
11171133
t = torch.randn(1, dtype=torch.cfloat)

torch/_weights_only_unpickler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ def load(self):
210210
else:
211211
raise RuntimeError(
212212
f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
213-
"Please use `torch.serialization.add_safe_globals` to allowlist this global "
214-
"if you trust this class/function."
213+
f"Please use `torch.serialization.add_safe_globals([{name}])` to allowlist "
214+
"this global if you trust this class/function."
215215
)
216216
elif key[0] == NEWOBJ[0]:
217217
args = self.stack.pop()

torch/serialization.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import io
66
import os
77
import pickle
8+
import re
89
import shutil
910
import struct
1011
import sys
@@ -1107,12 +1108,33 @@ def load(
11071108
"""
11081109
torch._C._log_api_usage_once("torch.load")
11091110
UNSAFE_MESSAGE = (
1110-
"Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`"
1111-
" will likely succeed, but it can result in arbitrary code execution."
1112-
" Do it only if you get the file from a trusted source. Alternatively, to load"
1113-
" with `weights_only` please check the recommended steps in the following error message."
1114-
" WeightsUnpickler error: "
1111+
"Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
1112+
"but it can result in arbitrary code execution. Do it only if you got the file from a "
1113+
"trusted source."
11151114
)
1115+
DOCS_MESSAGE = (
1116+
"\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
1117+
"weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
1118+
)
1119+
1120+
def _get_wo_message(message: str) -> str:
1121+
pattern = r"GLOBAL (\S+) was not an allowed global by default."
1122+
has_unsafe_global = re.search(pattern, message) is not None
1123+
if has_unsafe_global:
1124+
updated_message = (
1125+
"Weights only load failed. This file can still be loaded, to do so you have two options "
1126+
f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check "
1127+
"the recommended steps in the following error message.\n\tWeightsUnpickler error: "
1128+
+ message
1129+
)
1130+
else:
1131+
updated_message = (
1132+
f"Weights only load failed. {UNSAFE_MESSAGE}\n Please file an issue with the following "
1133+
"so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler "
1134+
"error: " + message
1135+
)
1136+
return updated_message + DOCS_MESSAGE
1137+
11161138
if weights_only is None:
11171139
weights_only, warn_weights_only = False, True
11181140
else:
@@ -1200,7 +1222,7 @@ def load(
12001222
**pickle_load_args,
12011223
)
12021224
except RuntimeError as e:
1203-
raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
1225+
raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
12041226
return _load(
12051227
opened_zipfile,
12061228
map_location,
@@ -1224,7 +1246,7 @@ def load(
12241246
**pickle_load_args,
12251247
)
12261248
except RuntimeError as e:
1227-
raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
1249+
raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
12281250
return _legacy_load(
12291251
opened_file, map_location, pickle_module, **pickle_load_args
12301252
)

0 commit comments

Comments
 (0)
0