8000 Prevent legacy_load when weights_only=True (correctly) (#145020) · pytorch/pytorch@0eda02a · GitHub
[go: up one dir, main page]

Skip to content

Commit 0eda02a

Browse files
mikaylagawareckipytorchmergebot
authored andcommitted
Prevent legacy_load when weights_only=True (correctly) (#145020)
Only prevent `legacy_load` (.tar format removed in #713), not the whole of `_legacy_load` (.tar format + _use_new_zipfile_serialization=False) Differential Revision: [D68301405](https://our.internmc.facebook.com/intern/diff/D68301405) Pull Request resolved: #145020 Approved by: https://github.com/kit1980, https://github.com/albanD
1 parent 2ef7b68 commit 0eda02a

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

test/test_serialization.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,11 @@ def _test_serialization_backwards_compat(self, weights_only):
466466
b += [a[0].storage()]
467467
b += [a[0].reshape(-1)[1:4].clone().storage()]
468468
path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt')
469-
c = torch.load(path, weights_only=weights_only)
469+
if weights_only:
470+
with self.assertRaisesRegex(RuntimeError,
471+
"Cannot use ``weights_only=True`` with files saved in the legacy .tar format."):
472+
c = torch.load(path, weights_only=weights_only)
473+
c = torch.load(path, weights_only=False)
470474
self.assertEqual(b, c, atol=0, rtol=0)
471475
self.assertTrue(isinstance(c[0], torch.FloatTensor))
472476
self.assertTrue(isinstance(c[1], torch.FloatTensor))

torch/serialization.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@
8585

8686
IS_WINDOWS = sys.platform == "win32"
8787

88+
UNSAFE_MESSAGE = (
89+
"In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` "
90+
"from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
91+
"but it can result in arbitrary code execution. Do it only if you got the file from a "
92+
"trusted source."
93+
)
94+
8895
if not IS_WINDOWS:
8996
from mmap import MAP_PRIVATE, MAP_SHARED
9097
else:
@@ -1341,12 +1348,6 @@ def load(
13411348
>>> torch.load("module.pt", encoding="ascii", weights_only=False)
13421349
"""
13431350
torch._C._log_api_usage_once("torch.load")
1344-
UNSAFE_MESSAGE = (
1345-
"In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` "
1346-
"from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
1347-
"but it can result in arbitrary code execution. Do it only if you got the file from a "
1348-
"trusted source."
1349-
)
13501351
DOCS_MESSAGE = (
13511352
"\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
13521353
"weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
@@ -1611,6 +1612,11 @@ def persistent_load(saved_id):
16111612
with closing(
16121613
tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
16131614
) as tar, mkdtemp() as tmpdir:
1615+
if pickle_module is _weights_only_unpickler:
1616+
raise RuntimeError(
1617+
"Cannot use ``weights_only=True`` with files saved in the "
1618+
"legacy .tar format. " + UNSAFE_MESSAGE
1619+
)
16141620
tar.extract("storages", path=tmpdir)
16151621
with open(os.path.join(tmpdir, "storages"), "rb", 0) as f:
16161622
num_storages = pickle_module.load(f, **pickle_load_args)

0 commit comments

Comments
 (0)
0