8000 fix `distributed.checkpoint.state_dict.set_model_state_dict` returned _IncompatibleKeys when `full_state_dict=True` by YassineYousfi · Pull Request #153351 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

fix distributed.checkpoint.state_dict.set_model_state_dict returned _IncompatibleKeys when full_state_dict=True #153351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions test/distributed/checkpoint/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def test_single_gpu(self) -> None:
self._test_single_gpu(torch.optim.Adam)
self._test_single_gpu(torch.optim.AdamW)

def _test_strict(self, parallelism: str) -> None:
def _test_strict(self, parallelism: str, full_state_dict: bool) -> None:
model = CompositeParamModel(device=torch.device("cuda"))
if parallelism == "DDP":
model = DDP(model)
Expand A 8000 ll @@ -403,7 +403,7 @@ def _test_strict(self, parallelism: str) -> None:
incompatible_keys = set_model_state_dict(
model,
model_state_dict=model_state_dict,
options=StateDictOptions(strict=False),
options=StateDictOptions(strict=False, full_state_dict=full_state_dict),
)
self.assertEqual(incompatible_keys.missing_keys, [key])
self.assertEqual(incompatible_keys.unexpected_keys, ["abc"])
Expand All @@ -415,7 +415,7 @@ def _test_strict(self, parallelism: str) -> None:
@skip_if_lt_x_gpu(1)
def test_strict(self) -> None:
self.run_subtests(
{"parallelism": ["DDP", "fully_shard"]},
{"parallelism": ["DDP", "fully_shard"], "full_state_dict": [True, False]},
self._test_strict,
)

Expand Down
4 changes: 3 additions & 1 deletion torch/distributed/checkpoint/state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,9 @@ def _load_model_state_dict(
)
elif info.full_state_dict:
_distribute_state_dict(state_dict, local_state_dict, device=devices.pop())
state_dict.update(local_state_dict)
for fqn, local_state in local_state_dict.items():
if fqn in state_dict:
state_dict[fqn] = local_state

with info.fsdp_context():
return cast(
Expand Down
0