Pass `torch.load(weights_only=)` internally to avoid FutureWarning (#130663) · pytorch/pytorch@914b9b9 · GitHub
Commit 914b9b9

awaelchli authored and mlazos committed
Pass torch.load(weights_only=) internally to avoid FutureWarning (#130663)
Fixes #130658
Pull Request resolved: #130663
Approved by: https://github.com/malfet, https://github.com/LucasLLC
1 parent 3b405e4 commit 914b9b9
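
Background: recent PyTorch releases emit a FutureWarning when torch.load is called without an explicit weights_only argument, because the default is scheduled to flip from False to True. This commit passes the argument explicitly at the internal call sites below. A minimal sketch (not part of the commit) of the before/after behavior, using a throwaway in-memory buffer:

import io
import torch

buf = io.BytesIO()
torch.save({"step": 3, "lr": 0.01}, buf)
buf.seek(0)

# Implicit form: newer PyTorch builds emit a FutureWarning about the upcoming
# change of the weights_only default.
# state = torch.load(buf)

# Explicit form used throughout this commit: no warning, and the intent is documented.
state = torch.load(buf, weights_only=False)
print(state)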

5 files changed: +15 −7 lines


torch/distributed/checkpoint/default_planner.py

Lines changed: 4 additions & 2 deletions
@@ -210,10 +210,12 @@ def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
             set_element(
                 self.original_state_dict,
                 self.mappings[read_item.dest_index.fqn],
-                torch.load(value),
+                torch.load(value, weights_only=False),
             )
         else:
-            self.state_dict[read_item.dest_index.fqn] = torch.load(value)
+            self.state_dict[read_item.dest_index.fqn] = torch.load(
+                value, weights_only=False
+            )
 
     def resolve_tensor(self, read_item: ReadItem):
         tensor = self.lookup_tensor(read_item.dest_index)
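
Context for the change above: load_bytes deserializes arbitrary non-tensor objects stored in the checkpoint's byte streams, so the permissive weights_only=False is required here. A hedged sketch (not DCP code; TrainingMeta is a made-up stand-in class) of why the strict mode would not do:

import io
import torch

class TrainingMeta:  # stand-in for non-tensor data a checkpoint byte stream may hold
    def __init__(self, epoch):
        self.epoch = epoch

buf = io.BytesIO()
torch.save(TrainingMeta(epoch=7), buf)
buf.seek(0)

# weights_only=True restricts unpickling to tensors and a small set of safe types,
# so a custom class like TrainingMeta would be rejected; False restores it as before.
meta = torch.load(buf, weights_only=False)
print(meta.epoch)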

torch/distributed/checkpoint/filesystem.py

Lines changed: 5 additions & 1 deletion
@@ -654,7 +654,11 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
                 else:
                     tensor = cast(
                         Tensor,
-                        torch.load(cast(IO[bytes], file_slice), map_location="cpu"),
+                        torch.load(
+                            cast(IO[bytes], file_slice),
+                            map_location="cpu",
+                            weights_only=True,
+                        ),
                     )
                     tensor = narrow_tensor_by_index(
                         tensor, req.storage_offsets, req.lengths
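
Context: this call site only ever loads raw tensor data from a file slice, so it can opt into the stricter weights_only=True. A minimal sketch, independent of the reader, showing that a plain tensor round-trips under that setting:

import io
import torch

buf = io.BytesIO()
torch.save(torch.arange(6).reshape(2, 3), buf)
buf.seek(0)

# Plain tensor payloads are exactly what the strict mode allows.
tensor = torch.load(buf, map_location="cpu", weights_only=True)
print(tensor.shape)  # torch.Size([2, 3])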

torch/distributed/checkpoint/format_utils.py

Lines changed: 4 additions & 2 deletions
@@ -84,7 +84,9 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
         # TODO: read on each host, instead of only the coordinator
         if self.is_coordinator:
             assert self.checkpoint_id is not None
-            torch_state_dict = torch.load(self.checkpoint_id, map_location="cpu")
+            torch_state_dict = torch.load(
+                self.checkpoint_id, map_location="cpu", weights_only=False
+            )
             if planner.flatten_state_dict:
                 torch_state_dict, _ = flatten_state_dict(torch_state_dict)
             else:
@@ -231,7 +233,7 @@ def torch_save_to_dcp(
     To avoid OOM, it's recommended to only run this function on a single rank.
     """
 
-    state_dict = torch.load(torch_save_path)
+    state_dict = torch.load(torch_save_path, weights_only=False)
     # we don't need stateful behavior here because the expectation is anything loaded by
     # torch.load would not contain stateful objects.
     _save_state_dict(
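
Context: torch_save_to_dcp converts a torch.save checkpoint, which may contain arbitrary pickled objects, into a DCP checkpoint, hence weights_only=False. A hedged usage sketch (paths are hypothetical, and the docstring above recommends running this on a single rank):

import torch
from torch.distributed.checkpoint.format_utils import torch_save_to_dcp

# A throwaway torch.save checkpoint (hypothetical paths).
torch.save({"model": {"weight": torch.zeros(4)}}, "ckpt.pth")

# Internally this runs torch.load("ckpt.pth", weights_only=False), as patched above,
# then writes the result out as a DCP checkpoint directory.
torch_save_to_dcp("ckpt.pth", "ckpt_dcp/")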

torch/distributed/checkpoint/planner.py

Lines changed: 1 addition & 1 deletion
@@ -330,7 +330,7 @@ class LoadPlanner:
         >>>
         >>> def load_bytes(self, read_item, value):
         >>>     # Remove the "foo_" prefix
-        >>>     self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value)
+        >>>     self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)
 
 
     Modifying resolve_tensor and commit_tensor to handle load time transformation.
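
The docstring example above can be expanded into a small planner. A hedged sketch (the class name and the "foo_" prefix are illustrative, not from the commit) that mirrors the patched doctest line:

import torch
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner

class StripPrefixLoadPlanner(DefaultLoadPlanner):
    def load_bytes(self, read_item, value):
        # fqn[4:] drops a leading "foo_"; non-tensor payloads still need
        # weights_only=False, matching the docstring change above.
        self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(
            value, weights_only=False
        )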

torch/distributed/optim/zero_redundancy_optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ def _broadcast_object(
         )
         dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
         buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
-        obj = torch.load(buffer, map_location=device)
+        obj = torch.load(buffer, map_location=device, weights_only=False)
         return obj
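
Context: _broadcast_object pickles an arbitrary Python object into a byte tensor, broadcasts it, and unpickles it on the receiving ranks, so the permissive weights_only=False is needed. A hedged single-process sketch of that round trip, with the collective call omitted:

import io
import torch

obj = {"param_groups": [{"lr": 0.1}]}  # arbitrary picklable payload

# Sender side: pickle the object into a uint8 tensor that can be broadcast.
send_buf = io.BytesIO()
torch.save(obj, send_buf)
data_tensor = torch.frombuffer(bytearray(send_buf.getvalue()), dtype=torch.uint8)

# Receiver side: rebuild the byte stream and unpickle. Arbitrary objects
# require weights_only=False, matching the patched call above.
recv_buf = io.BytesIO(data_tensor.cpu().numpy())
obj_back = torch.load(recv_buf, map_location="cpu", weights_only=False)
print(obj_back)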

0 commit comments

Comments
 (0)
0