8000 Update on "Prevent legacy_load when weights_only=True (correctly)" · pytorch/pytorch@9ec6b44 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9ec6b44

Browse files
Update on "Prevent legacy_load when weights_only=True (correctly)"
Only prevent `legacy_load` (.tar format removed in #713), not the whole of `_legacy_load` (.tar format + _use_new_zipfile_serialization=False) Differential Revision: [D68301405](https://our.internmc.facebook.com/intern/diff/D68301405) [ghstack-poisoned]
1 parent b829da8 commit 9ec6b44

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

torch/serialization.py

Lines changed: 5 additions & 7 deletions
< 8000 /div>
Original file line numberDiff line numberDiff line change
@@ -1612,6 +1612,11 @@ def persistent_load(saved_id):
16121612
with closing(
16131613
tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
16141614
) as tar, mkdtemp() as tmpdir:
1615+
if pickle_module is _weights_only_unpickler:
1616+
raise RuntimeError(
1617+
"Cannot use ``weights_only=True`` with files saved in the "
1618+
"legacy .tar format. " + UNSAFE_MESSAGE
1619+
)
16151620
tar.extract("storages", path=tmpdir)
16161621
with open(os.path.join(tmpdir, "storages"), "rb", 0) as f:
16171622
num_storages = pickle_module.load(f, **pickle_load_args)
@@ -1738,13 +1743,6 @@ def persistent_load(saved_id):
17381743
# legacy_load requires that f has fileno()
17391744
# only if offset is zero we can attempt the legacy tar file loader
17401745
try:
1741-
with closing(tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)):
1742-
if pickle_module is _weights_only_unpickler:
1743-
raise RuntimeError(
1744-
"Cannot use ``weights_only=True`` with files saved in the "
1745-
"legacy .tar format. " + UNSAFE_MESSAGE
1746-
)
1747-
f.seek(0)
17481746
return legacy_load(f)
17491747
except tarfile.TarError:
17501748
if _is_zipfile(f):

0 commit comments

Comments
 (0)
0