8000 Revert "Add warning for weights_only (#129239)" · pytorch/pytorch@b1f486a · GitHub
[go: up one dir, main page]

Skip to content

Commit b1f486a

Browse files
Revert "Add warning for weights_only (#129239)"
This reverts commit 381ce08. Reverted #129239 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I am seeing some test_nn failures from ROCm https://hud.pytorch.org/pytorch/pytorch/commit/381ce0821c3fa2b342f0b8660c76cc27f48543c4, trying to revert this to see if trunk recovers ([comment](#129239 (comment)))
1 parent 7cf454e commit b1f486a

File tree

3 files changed

+9
-37
lines changed

3 files changed

+9
-37
lines changed

test/test_nn.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,35 +1802,26 @@ def test_parameterlistdict_pickle(self):
18021802
m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
18031803
with warnings.catch_warnings(record=True) as w:
18041804
m = pickle.loads(pickle.dumps(m))
1805-
# warning from torch.load call in _load_from_bytes
1806-
num_warnings = 2 if torch._dynamo.is_compiling() else 1
1807-
self.assertTrue(len(w) == num_warnings)
1808-
self.assertEqual(w[0].category, FutureWarning)
1805+
self.assertTrue(len(w) == 0)
18091806

18101807
# Test whether loading from older checkpoints works without triggering warnings
18111808
m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
18121809
del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
18131810
with warnings.catch_warnings(record=True) as w:
18141811
m = pickle.loads(pickle.dumps(m))
1815-
# warning from torch.load call in _load_from_bytes
1816-
self.assertTrue(len(w) == 1)
1817-
self.assertEqual(w[0].category, FutureWarning)
1812+
self.assertTrue(len(w) == 0)
18181813

18191814
m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
18201815
with warnings.catch_warnings(record=True) as w:
18211816
m = pickle.loads(pickle.dumps(m))
1822-
# warning from torch.load call in _load_from_bytes
1823-
self.assertTrue(len(w) == 1)
1824-
self.assertEqual(w[0].category, FutureWarning)
1817+
self.assertTrue(len(w) == 0)
18251818

18261819
# Test whether loading from older checkpoints works without triggering warnings
18271820
m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
18281821
del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
18291822
with warnings.catch_warnings(record=True) as w:
18301823
m = pickle.loads(pickle.dumps(m))
1831-
# warning from torch.load call in _load_from_bytes
1832-
self.assertTrue(len(w) == 1)
1833-
self.assertEqual(w[0].category, FutureWarning)
1824+
self.assertTrue(len(w) == 0)
18341825

18351826
def test_weight_norm_pickle(self):
18361827
m = torch.nn.utils.weight_norm(nn.Linear(5, 7))

test/test_serialization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,7 @@ def test(f_new, f_old):
837837
test(f_new, f_old)
838838
self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")
839839

840+
840841
class TestOldSerialization(TestCase, SerializationMixin):
841842
# unique_key is necessary because on Python 2.7, if a warning passed to
842843
# the warning module is the same, it is not raised again.
@@ -864,8 +865,7 @@ def import_module(name, filename):
864865
loaded = torch.load(checkpoint)
865866
self.assertTrue(isinstance(loaded, module.Net))
866867
if can_retrieve_source:
867-
self.assertEqual(len(w), 1)
868-
self.assertEqual(w[0].category, FutureWarning)
868+
self.assertEqual(len(w), 0)
869869

870870
# Replace the module with different source
871871
fname 10000 = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
@@ -876,8 +876,8 @@ def import_module(name, filename):
876876
loaded = torch.load(checkpoint)
877877
self.assertTrue(isinstance(loaded, module.Net))
878878
if can_retrieve_source:
879-
self.assertEqual(len(w), 2)
880-
self.assertTrue(w[1].category, 'SourceChangeWarning')
879+
self.assertEqual(len(w), 1)
880+
self.assertTrue(w[0].category, 'SourceChangeWarning')
881881

882882
def test_serialization_container(self):
883883
self._test_serialization_container('file', tempfile.NamedTemporaryFile)

torch/serialization.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -987,7 +987,7 @@ def load(
987987
map_location: MAP_LOCATION = None,
988988
pickle_module: Any = None,
989989
*,
990-
weights_only: Optional[bool] = None,
990+
weights_only: bool = False,
991991
mmap: Optional[bool] = None,
992992
**pickle_load_args: Any,
993993
) -> Any:
@@ -1097,11 +1097,6 @@ def load(
10971097
" with `weights_only` please check the recommended steps in the following error message."
10981098
" WeightsUnpickler error: "
10991099
)
1100-
if weights_only is None:
1101-
weights_only, warn_weights_only = False, True
1102-
else:
1103-
warn_weights_only = False
1104-
11051100
# Add ability to force safe only weight loads via environment variable
11061101
if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in [
11071102
"1",
@@ -1118,20 +1113,6 @@ def load(
11181113
)
11191114
else:
11201115
if pickle_module is None:
1121-
if warn_weights_only:
1122-
warnings.warn(
1123-
"You are using `torch.load` with `weights_only=False` (the current default value), which uses "
1124-
"the default pickle module implicitly. It is possible to construct malicious pickle data "
1125-
"which will execute arbitrary code during unpickling (See "
1126-
"https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). "
1127-
"In a future release, the default value for `weights_only` will be flipped to `True`. This "
1128-
"limits the functions that could be executed during unpickling. Arbitrary objects will no "
1129-
"longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the "
1130-
"user via `torch.serialization.add_safe_globals`. We recommend you start setting "
1131-
"`weights_only=True` for any use case where you don't have full control of the loaded file. "
1132-
"Please open an issue on GitHub for any issues related to this experimental feature.",
1133-
FutureWarning,
1134-
)
11351116
pickle_module = pickle
11361117

11371118
# make flipping default BC-compatible

0 commit comments

Comments
 (0)
0