Revert "Prevent _legacy_load with weights_only=True (#144993)" · pytorch/pytorch@9c34a20 · GitHub

Commit 9c34a20

Revert "Prevent _legacy_load with weights_only=True (#144993)"
This reverts commit cd15d7b.
1 parent cd15d7b commit 9c34a20

3 files changed: +29 −48 lines changed

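For context, a minimal sketch (not part of the commit) of the behavior this revert restores: with the reverted guard removed, torch.load on a file saved in the legacy (pre-1.6, non-zipfile) format falls back to _legacy_load with the weights-only unpickler when weights_only=True, so a plain tensor checkpoint round-trips instead of raising a RuntimeError. This assumes a PyTorch build that includes this revert.

# Sketch only: assumes a build with this revert applied.
import io
import torch

buf = io.BytesIO()
t = torch.randn(5, 5)
# Save in the legacy (non-zipfile) serialization format.
torch.save(t, buf, _use_new_zipfile_serialization=False)
buf.seek(0)

# With the revert, this goes through _legacy_load with the weights-only
# unpickler instead of raising a RuntimeError.
loaded = torch.load(buf, weights_only=True)
assert torch.equal(t, loaded)
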
test/quantization/bc/test_backward_compatibility.py
Lines changed: 2 additions & 4 deletions

@@ -112,14 +112,12 @@ def _test_op(
         torch.jit.save(torch.jit.trace(qmodule, input_tensor), traced_module_file)
         torch.save(qmodule(input_tensor), expected_file)
 
-        # weights_only=False as file was saved in .tar format
-        input_tensor = torch.load(input_file, weights_only=False)
+        input_tensor = torch.load(input_file)
         # weights_only = False as sometimes get ScriptObject here
         qmodule.load_state_dict(torch.load(state_dict_file, weights_only=False))
         qmodule_scripted = torch.jit.load(scripted_module_file)
         qmodule_traced = torch.jit.load(traced_module_file)
-        # weights_only=False as file was saved in .tar format
-        expected = torch.load(expected_file, weights_only=False)
+        expected = torch.load(expected_file)
         self.assertEqual(qmodule(input_tensor), expected, atol=prec)
         self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec)
         self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec)

test/test_serialization.py
Lines changed: 18 additions & 40 deletions

@@ -224,6 +224,9 @@ def _test_serialization(self, weights_only):
     def test_serialization(self):
         self._test_serialization(False)
 
+    def test_serialization_safe(self):
+        self._test_serialization(True)
+
     def test_serialization_filelike(self):
         # Test serialization (load and save) with a filelike object
         b = self._test_serialization_data()
@@ -359,6 +362,9 @@ def _test_serialization(conversion):
     def test_serialization_sparse(self):
         self._test_serialization(False)
 
+    def test_serialization_sparse_safe(self):
+        self._test_serialization(True)
+
     def test_serialization_sparse_invalid(self):
         x = torch.zeros(3, 3)
         x[1][1] = 1
@@ -504,6 +510,9 @@ def __reduce__(self):
     def test_serialization_backwards_compat(self):
         self._test_serialization_backwards_compat(False)
 
+    def test_serialization_backwards_compat_safe(self):
+        self._test_serialization_backwards_compat(True)
+
     def test_serialization_save_warnings(self):
         with warnings.catch_warnings(record=True) as warns:
             with tempfile.NamedTemporaryFile() as checkpoint:
@@ -548,8 +557,7 @@ def load_bytes():
         def check_map_locations(map_locations, dtype, intended_device):
             for fileobject_lambda in fileobject_lambdas:
                 for map_location in map_locations:
-                    # weigts_only=False as the downloaded file path uses the old serialization format
-                    tensor = torch.load(fileobject_lambda(), map_location=map_location, weights_only=False)
+                    tensor = torch.load(fileobject_lambda(), map_location=map_location)
 
                     self.assertEqual(tensor.device, intended_device)
                     self.assertEqual(tensor.dtype, dtype)
@@ -592,8 +600,7 @@ def test_load_nonexistent_device(self):
 
         error_msg = r'Attempting to deserialize object on a CUDA device'
         with self.assertRaisesRegex(RuntimeError, error_msg):
-            # weights_only=False as serialized is in legacy format
-            _ = torch.load(buf, weights_only=False)
+            _ = torch.load(buf)
 
     @unittest.skipIf((3, 8, 0) <= sys.version_info < (3, 8, 2), "See https://bugs.python.org/issue39681")
     def test_serialization_filelike_api_requirements(self):
@@ -713,8 +720,7 @@ def test_serialization_storage_slice(self):
                       b'\x00\x00\x00\x00')
 
         buf = io.BytesIO(serialized)
-        # serialized was saved with PyTorch 0.3.1
-        (s1, s2) = torch.load(buf, weights_only=False)
+        (s1, s2) = torch.load(buf)
         self.assertEqual(s1[0], 0)
         self.assertEqual(s2[0], 0)
         self.assertEqual(s1.data_ptr() + 4, s2.data_ptr())
@@ -831,24 +837,6 @@ def wrapper(*args, **kwargs):
     def __exit__(self, *args, **kwargs):
         torch.save = self.torch_save
 
-
-# used to set weights_only=False in _use_new_zipfile_serialization=False tests
-class load_method:
-    def __init__(self, weights_only):
-        self.weights_only = weights_only
-        self.torch_load = torch.load
-
-    def __enter__(self, *args, **kwargs):
-        def wrapper(*args, **kwargs):
-            kwargs['weights_only'] = self.weights_only
-            return self.torch_load(*args, **kwargs)
-
-        torch.load = wrapper
-
-    def __exit__(self, *args, **kwargs):
-        torch.load = self.torch_load
-
-
 Point = namedtuple('Point', ['x', 'y'])
 
 class ClassThatUsesBuildInstruction:
@@ -885,25 +873,14 @@ def test(f_new, f_old):
 
             torch.save(x, f_old, _use_new_zipfile_serialization=False)
             f_old.seek(0)
-            x_old_load = torch.load(f_old, weights_only=False)
+            x_old_load = torch.load(f_old, weights_only=weights_only)
             self.assertEqual(x_old_load, x_new_load)
 
         with AlwaysWarnTypedStorageRemoval(True), warnings.catch_warnings(record=True) as w:
             with tempfile.NamedTemporaryFile() as f_new, tempfile.NamedTemporaryFile() as f_old:
                 test(f_new, f_old)
             self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")
 
-    def test_old_serialization_fails_with_weights_only(self):
-        a = torch.randn(5, 5)
-        with BytesIOContext() as f:
-            torch.save(a, f, _use_new_zipfile_serialization=False)
-            f.seek(0)
-            with self.assertRaisesRegex(
-                RuntimeError,
-                "Cannot use ``weights_only=True`` with files saved in the .tar format used before version 1.6."
-            ):
-                torch.load(f, weights_only=True)
-
 
 class TestOldSerialization(TestCase, SerializationMixin):
     # unique_key is necessary because on Python 2.7, if a warning passed to
@@ -979,7 +956,8 @@ def test_serialization_offset(self):
             self.assertEqual(i, i_loaded)
             self.assertEqual(j, j_loaded)
 
-    def test_serialization_offset_filelike(self):
+    @parametrize('weights_only', (True, False))
+    def test_serialization_offset_filelike(self, weights_only):
         a = torch.randn(5, 5)
         b = torch.randn(1024, 1024, 512, dtype=torch.float32)
         i, j = 41, 43
@@ -991,16 +969,16 @@ def test_serialization_offset_filelike(self):
             self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024)
             f.seek(0)
             i_loaded = pickle.load(f)
-            a_loaded = torch.load(f)
+            a_loaded = torch.load(f, weights_only=weights_only)
             j_loaded = pickle.load(f)
-            b_loaded = torch.load(f)
+            b_loaded = torch.load(f, weights_only=weights_only)
             self.assertTrue(torch.equal(a, a_loaded))
            self.assertTrue(torch.equal(b, b_loaded))
             self.assertEqual(i, i_loaded)
             self.assertEqual(j, j_loaded)
 
     def run(self, *args, **kwargs):
-        with serialization_method(use_zip=False), load_method(weights_only=False):
+        with serialization_method(use_zip=False):
             return super().run(*args, **kwargs)

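The restored @parametrize('weights_only', (True, False)) decorator comes from torch.testing._internal.common_utils. Below is a self-contained sketch of that pattern; the class and test names are illustrative, not the suite's actual ones.

import io

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class ExampleSerializationTest(TestCase):
    @parametrize('weights_only', (True, False))
    def test_roundtrip(self, weights_only):
        # instantiate_parametrized_tests expands this into one concrete
        # test per parameter value.
        x = torch.arange(4.0)
        buf = io.BytesIO()
        torch.save(x, buf)
        buf.seek(0)
        self.assertEqual(torch.load(buf, weights_only=weights_only), x)


instantiate_parametrized_tests(ExampleSerializationTest)

if __name__ == '__main__':
    run_tests()
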
torch/serialization.py
Lines changed: 9 additions & 4 deletions

@@ -1482,10 +1482,15 @@ def _get_wo_message(message: str) -> str:
                 "please torch.save your checkpoint with this option in order to use mmap."
             )
         if weights_only:
-            raise RuntimeError(
-                "Cannot use ``weights_only=True`` with files saved in the "
-                ".tar format used before version 1.6. " + UNSAFE_MESSAGE
-            )
+            try:
+                return _legacy_load(
+                    opened_file,
+                    map_location,
+                    _weights_only_unpickler,
+                    **pickle_load_args,
+                )
+            except pickle.UnpicklingError as e:
+                raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
         return _legacy_load(
             opened_file, map_location, pickle_module, **pickle_load_args
         )

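The restored path wraps failures from the weights-only unpickler in pickle.UnpicklingError via _get_wo_message rather than raising a blanket RuntimeError up front. A hedged sketch of what that looks like from the caller's side, assuming a build with this revert and using an arbitrary non-allowlisted class:

# Sketch only: demonstrates the error surface of the restored code path.
import io
import pickle

import torch


class NotAllowlisted:
    pass


buf = io.BytesIO()
torch.save(NotAllowlisted(), buf, _use_new_zipfile_serialization=False)
buf.seek(0)

try:
    torch.load(buf, weights_only=True)
except pickle.UnpicklingError as exc:
    # The weights-only unpickler rejects the custom class; the legacy path
    # re-raises it as pickle.UnpicklingError with the weights_only guidance
    # appended by _get_wo_message.
    print(exc)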