@@ -224,6 +224,9 @@ def _test_serialization(self, weights_only):
224
224
def test_serialization (self ):
225
225
self ._test_serialization (False )
226
226
227
+ def test_serialization_safe (self ):
228
+ self ._test_serialization (True )
229
+
227
230
def test_serialization_filelike (self ):
228
231
# Test serialization (load and save) with a filelike object
229
232
b = self ._test_serialization_data ()
@@ -359,6 +362,9 @@ def _test_serialization(conversion):
359
362
def test_serialization_sparse (self ):
360
363
self ._test_serialization (False )
361
364
365
+ def test_serialization_sparse_safe (self ):
366
+ self ._test_serialization (True )
367
+
362
368
def test_serialization_sparse_invalid (self ):
363
369
x = torch .zeros (3 , 3 )
364
370
x [1 ][1 ] = 1
@@ -504,6 +510,9 @@ def __reduce__(self):
504
510
def test_serialization_backwards_compat (self ):
505
511
self ._test_serialization_backwards_compat (False )
506
512
513
+ def test_serialization_backwards_compat_safe (self ):
514
+ self ._test_serialization_backwards_compat (True )
515
+
507
516
def test_serialization_save_warnings (self ):
508
517
with warnings .catch_warnings (record = True ) as warns :
509
518
with tempfile .NamedTemporaryFile () as checkpoint :
@@ -548,8 +557,7 @@ def load_bytes():
548
557
def check_map_locations (map_locations , dtype , intended_device ):
549
558
for fileobject_lambda in fileobject_lambdas :
550
559
for map_location in map_locations :
551
- # weigts_only=False as the downloaded file path uses the old serialization format
552
- tensor = torch .load (fileobject_lambda (), map_location = map_location , weights_only = False )
560
+ tensor = torch .load (fileobject_lambda (), map_location = map_location )
553
561
554
562
self .assertEqual (tensor .device , intended_device )
555
563
self .assertEqual (tensor .dtype , dtype )
@@ -592,8 +600,7 @@ def test_load_nonexistent_device(self):
592
600
593
601
error_msg = r'Attempting to deserialize object on a CUDA device'
594
602
with self .assertRaisesRegex (RuntimeError , error_msg ):
595
- # weights_only=False as serialized is in legacy format
596
- _ = torch .load (buf , weights_only = False )
603
+ _ = torch .load (buf )
597
604
598
605
@unittest .skipIf ((3 , 8 , 0 ) <= sys .version_info < (3 , 8 , 2 ), "See https://bugs.python.org/issue39681" )
599
606
def test_serialization_filelike_api_requirements (self ):
@@ -713,8 +720,7 @@ def test_serialization_storage_slice(self):
713
720
b'\x00 \x00 \x00 \x00 ' )
714
721
715
722
buf = io .BytesIO (serialized )
716
- # serialized was saved with PyTorch 0.3.1
717
- (s1 , s2 ) = torch .load (buf , weights_only = False )
723
+ (s1 , s2 ) = torch .load (buf )
718
724
self .assertEqual (s1 [0 ], 0 )
719
725
self .assertEqual (s2 [0 ], 0 )
720
726
self .assertEqual (s1 .data_ptr () + 4 , s2 .data_ptr ())
@@ -831,24 +837,6 @@ def wrapper(*args, **kwargs):
831
837
def __exit__ (self , * args , ** kwargs ):
832
838
torch .save = self .torch_save
833
839
834
-
835
- # used to set weights_only=False in _use_new_zipfile_serialization=False tests
836
- class load_method :
837
- def __init__ (self , weights_only ):
838
- self .weights_only = weights_only
839
- self .torch_load = torch .load
840
-
841
- def __enter__ (self , * args , ** kwargs ):
842
- def wrapper (* args , ** kwargs ):
843
- kwargs ['weights_only' ] = self .weights_only
844
- return self .torch_load (* args , ** kwargs )
845
-
846
- torch .load = wrapper
847
-
848
- def __exit__ (self , * args , ** kwargs ):
849
- torch .load = self .torch_load
850
-
851
-
852
840
Point = namedtuple ('Point' , ['x' , 'y' ])
853
841
854
842
class ClassThatUsesBuildInstruction :
@@ -885,25 +873,14 @@ def test(f_new, f_old):
885
873
886
874
torch .save (x , f_old , _use_new_zipfile_serialization = False )
887
875
f_old .seek (0 )
888
- x_old_load = torch .load (f_old , weights_only = False )
876
+ x_old_load = torch .load (f_old , weights_only = weights_only )
889
877
self .assertEqual (x_old_load , x_new_load )
890
878
891
879
with AlwaysWarnTypedStorageRemoval (True ), warnings .catch_warnings (record = True ) as w :
892
880
with tempfile .NamedTemporaryFile () as f_new , tempfile .NamedTemporaryFile () as f_old :
893
881
test (f_new , f_old )
894
882
self .assertTrue (len (w ) == 0 , msg = f"Expected no warnings but got { [str (x ) for x in w ]} " )
895
883
896
- def test_old_serialization_fails_with_weights_only (self ):
897
- a = torch .randn (5 , 5 )
898
- with BytesIOContext () as f :
899
- torch .save (a , f , _use_new_zipfile_serialization = False )
900
- f .seek (0 )
901
- with self .assertRaisesRegex (
902
- RuntimeError ,
903
- "Cannot use ``weights_only=True`` with files saved in the .tar format used before version 1.6."
904
- ):
905
- torch .load (f , weights_only = True )
906
-
907
884
908
885
class TestOldSerialization (TestCase , SerializationMixin ):
909
886
# unique_key is necessary because on Python 2.7, if a warning passed to
@@ -979,7 +956,8 @@ def test_serialization_offset(self):
979
956
self .assertEqual (i , i_loaded )
980
957
self .assertEqual (j , j_loaded )
981
958
982
- def test_serialization_offset_filelike (self ):
959
+ @parametrize ('weights_only' , (True , False ))
960
+ def test_serialization_offset_filelike (self , weights_only ):
983
961
a = torch .randn (5 , 5 )
984
962
b = torch .randn (1024 , 1024 , 512 , dtype = torch .float32 )
985
963
i , j = 41 , 43
@@ -991,16 +969,16 @@ def test_serialization_offset_filelike(self):
991
969
self .assertTrue (f .tell () > 2 * 1024 * 1024 * 1024 )
992
970
f .seek (0 )
993
971
i_loaded = pickle .load (f )
994
- a_loaded = torch .load (f )
972
+ a_loaded = torch .load (f , weights_only = weights_only )
995
973
j_loaded = pickle .load (f )
996
- b_loaded = torch .load (f )
974
+ b_loaded = torch .load (f , weights_only = weights_only )
997
975
self .assertTrue (torch .equal (a , a_loaded ))
998
976
self .assertTrue (torch .equal (b , b_loaded ))
999
977
self .assertEqual (i , i_loaded )
1000
978
self .assertEqual (j , j_loaded )
1001
979
1002
980
def run (self , * args , ** kwargs ):
1003
- with serialization_method (use_zip = False ), load_method ( weights_only = False ) :
981
+ with serialization_method (use_zip = False ):
1004
982
return super ().run (* args , ** kwargs )
1005
983
1006
984
0 commit comments