Pass `torch.load(weights_only=)` internally to avoid FutureWarning (#130663) · pytorch/pytorch@b639d82

Commit b639d82

awaelchli authored and pytorchbot committed
Pass torch.load(weights_only=) internally to avoid FutureWarning (#130663)
Fixes #130658

Pull Request resolved: #130663
Approved by: https://github.com/malfet, https://github.com/LucasLLC
(cherry picked from commit ad314a2)
1 parent 58ab993 commit b639d82
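For context on the warning this commit silences: on recent PyTorch releases (2.4 at the time of this commit), calling torch.load without an explicit weights_only argument emits a FutureWarning announcing that the default will eventually flip from False to True. A minimal sketch of the behavior, not part of the commit itself; the tensor payload is illustrative:

import io
import torch

# Saving and re-loading a plain tensor. A bare torch.load(buf) emits the
# FutureWarning because the weights_only default is slated to change from
# False to True; passing the argument explicitly at every internal call
# site, as this commit does, keeps the calls quiet and forward-compatible.
buf = io.BytesIO()
torch.save(torch.ones(3), buf)
buf.seek(0)
t = torch.load(buf, weights_only=True)  # tensor-only payload: safe mode works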

File tree

5 files changed: +15 −7 lines changed


torch/distributed/checkpoint/default_planner.py

Lines changed: 4 additions & 2 deletions

@@ -209,10 +209,12 @@ def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
             set_element(
                 self.original_state_dict,
                 self.mappings[read_item.dest_index.fqn],
-                torch.load(value),
+                torch.load(value, weights_only=False),
             )
         else:
-            self.state_dict[read_item.dest_index.fqn] = torch.load(value)
+            self.state_dict[read_item.dest_index.fqn] = torch.load(
+                value, weights_only=False
+            )

     def resolve_tensor(self, read_item: ReadItem):
         tensor = self.lookup_tensor(read_item.dest_index)
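Why weights_only=False rather than True here: load_bytes deserializes arbitrary non-tensor objects that users store in the state dict, and the restricted unpickler used by weights_only=True rejects such payloads. A standalone sketch; the TrainingMeta class is hypothetical and not part of DCP:

import io
import torch

class TrainingMeta:  # hypothetical non-tensor object stored alongside weights
    def __init__(self, step: int):
        self.step = step

buf = io.BytesIO()
torch.save(TrainingMeta(42), buf)

buf.seek(0)
try:
    torch.load(buf, weights_only=True)  # restricted unpickler rejects the class
except Exception as exc:
    print(type(exc).__name__)  # UnpicklingError

buf.seek(0)
meta = torch.load(buf, weights_only=False)  # trusted checkpoint: full unpickling
print(meta.step)  # 42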

torch/distributed/checkpoint/filesystem.py

Lines changed: 5 additions & 1 deletion

@@ -654,7 +654,11 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
                     else:
                         tensor = cast(
                             Tensor,
-                            torch.load(cast(IO[bytes], file_slice), map_location="cpu"),
+                            torch.load(
+                                cast(IO[bytes], file_slice),
+                                map_location="cpu",
+                                weights_only=True,
+                            ),
                         )
                         tensor = narrow_tensor_by_index(
                             tensor, req.storage_offsets, req.lengths
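This is the one call site in the commit that can opt into weights_only=True: the per-tensor payloads the filesystem reader slices out contain only tensor data, which the restricted unpickler handles. A rough standalone sketch of the load-then-narrow pattern; the offset and length are illustrative:

import io
import torch

buf = io.BytesIO()
torch.save(torch.arange(10, dtype=torch.float32), buf)
buf.seek(0)

# Tensor-only payload, so the restricted mode is sufficient and safer.
tensor = torch.load(buf, map_location="cpu", weights_only=True)

# Loosely mirrors what narrow_tensor_by_index does with storage offsets and
# lengths: slice the requested region out of the loaded tensor along dim 0.
chunk = tensor.narrow(0, 2, 4)
print(chunk)  # tensor([2., 3., 4., 5.])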

torch/distributed/checkpoint/format_utils.py

Lines changed: 4 additions & 2 deletions

@@ -84,7 +84,9 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
         # TODO: read on each host, instead of only the coordinator
         if self.is_coordinator:
             assert self.checkpoint_id is not None
-            torch_state_dict = torch.load(self.checkpoint_id, map_location="cpu")
+            torch_state_dict = torch.load(
+                self.checkpoint_id, map_location="cpu", weights_only=False
+            )
             if planner.flatten_state_dict:
                 torch_state_dict, _ = flatten_state_dict(torch_state_dict)
             else:
@@ -230,7 +232,7 @@ def torch_save_to_dcp(
     To avoid OOM, it's recommended to only run this function on a single rank.
     """

-    state_dict = torch.load(torch_save_path)
+    state_dict = torch.load(torch_save_path, weights_only=False)
     # we don't need stateful behavior here because the expectation is anything loaded by
     # torch.load would not contain stateful objects.
     _save_state_dict(
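Both call sites here load entire torch.save checkpoints, which may hold arbitrary user objects, hence weights_only=False. A hedged usage sketch of torch_save_to_dcp; the file names are made up, and running it as a plain single-process script is assumed to be acceptable given the docstring's single-rank advice:

import torch
from torch.distributed.checkpoint.format_utils import torch_save_to_dcp

# A made-up torch.save checkpoint to convert.
torch.save({"model": {"weight": torch.zeros(4)}}, "checkpoint.pt")

# Converts the torch.save file into a DCP checkpoint directory. Run this on
# a single rank only, per the docstring, to avoid loading the state dict
# once per rank and going OOM.
torch_save_to_dcp("checkpoint.pt", "dcp_checkpoint_dir")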

torch/distributed/checkpoint/planner.py

Lines changed: 1 addition & 1 deletion

@@ -331,7 +331,7 @@ class LoadPlanner:
         >>>
         >>> def load_bytes(self, read_item, value):
         >>>     # Remove the "foo_" prefix
-        >>>     self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value)
+        >>>     self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)

    Modifying resolve_tensor and commit_tensor to handle load time transformation.
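The docstring this hunk updates sketches a planner that renames keys at load time. A fuller version of that example as a runnable class; the RenamePlanner name and the "foo_" prefix follow the docstring's hypothetical and are not a shipped API:

import torch
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner

class RenamePlanner(DefaultLoadPlanner):
    """Strips a hypothetical "foo_" prefix from keys while deserializing."""

    def load_bytes(self, read_item, value):
        # Bytes entries can hold arbitrary user objects, so full unpickling
        # is required, matching the docstring's updated call.
        self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(
            value, weights_only=False
        )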

torch/distributed/optim/zero_redundancy_optimizer.py

Lines changed: 1 addition & 1 deletion

@@ -107,7 +107,7 @@ def _broadcast_object(
         )
         dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
         buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
-        obj = torch.load(buffer, map_location=device)
+        obj = torch.load(buffer, map_location=device, weights_only=False)
         return obj
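_broadcast_object round-trips arbitrary Python objects through torch.save/torch.load around a byte-tensor broadcast, so the permissive mode is required on the receiving side. A standalone sketch of that round trip with the dist.broadcast step elided:

import io
import torch

obj = {"lr": 0.1, "step": 7}  # arbitrary Python object, not just tensors

# Sender side: pickle into a uint8 tensor that dist.broadcast can ship.
send_buf = io.BytesIO()
torch.save(obj, send_buf)
payload = torch.frombuffer(bytearray(send_buf.getvalue()), dtype=torch.uint8)

# ... dist.broadcast(payload, src=src_rank, group=group) would run here ...

# Receiver side: rebuild the byte stream and unpickle, as the patched line
# does; arbitrary objects force weights_only=False.
recv_buf = io.BytesIO(payload.numpy().tobytes())
restored = torch.load(recv_buf, map_location="cpu", weights_only=False)
print(restored)  # {'lr': 0.1, 'step': 7}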

0 commit comments