[FSDP][StateDict] Allow FULL_STATE_DICT option for 2D by wz337 · Pull Request #120837 · pytorch/pytorch · GitHub

[FSDP][StateDict] Allow FULL_STATE_DICT option for 2D #120837

Closed
wants to merge 1 commit
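This PR removes the checks that previously rejected FULL_STATE_DICT when FSDP was initialized with a device mesh that has a parent mesh (the 2D FSDP + TP case). The sketch below illustrates the call pattern this enables; it is not code from this PR, and the toy model, mesh sizes, and submodule names are assumptions.

# Illustrative sketch (not from this PR): full state dict from a 2D model.
# Run under torchrun with 4 ranks (2 data-parallel x 2 tensor-parallel).
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class ToyModel(nn.Module):  # assumed toy model, not from the PR
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(8, 8)
        self.net2 = nn.Linear(8, 8)

    def forward(self, x):
        return self.net2(torch.relu(self.net1(x)))

mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Tensor-parallelize the toy model over the "tp" dim, then wrap with FSDP
# over the "dp" dim.
model = parallelize_module(
    ToyModel().cuda(),
    mesh_2d["tp"],
    {"net1": ColwiseParallel(), "net2": RowwiseParallel()},
)
model = FSDP(model, device_mesh=mesh_2d["dp"], use_orig_params=True)

# Before this PR this combination raised in the pre-state-dict hook or
# warned; now it is allowed.
with FSDP.state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
    full_sd = model.state_dict()

if dist.get_rank() == 0:
    print(sorted(full_sd.keys()))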
24 changes: 0 additions & 24 deletions test/distributed/fsdp/test_fsdp_dtensor_state_dict.py
@@ -313,30 +313,6 @@ def test_raises_warning_or_errors(self):
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
optim_state_dict = FSDP.optim_state_dict(model, optim)

with self.assertLogs(
"torch.distributed.fsdp._state_dict_utils", level="WARNING"
) as log:
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
state_dict = model.state_dict()
self.assertEqual(len(log.records), 1)
self.assertEqual(len(log.output), 1)
self.assertIn(
"Found both state_dict_type FULL_STATE_DICT and device_mesh.",
log.output[0],
)

with self.assertLogs(
"torch.distributed.fsdp._optim_utils", level="WARNING"
) as log:
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
state_dict = FSDP.optim_state_dict(model, optim)
self.assertEqual(len(log.records), 1)
self.assertEqual(len(log.output), 1)
self.assertIn(
"Found both state_dict_type FULL_STATE_DICT and device_mesh.",
log.output[0],
)


instantiate_parametrized_tests(TestFSDPWithDeviceMeshAndDTensor)
if __name__ == "__main__":
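The deleted assertions expected warnings that this PR removes, so they no longer apply. A possible follow-up check, sketched below (hypothetical, not part of this diff; `model` and `optim` are assumed to be built as elsewhere in this test class), would instead confirm that the FULL_STATE_DICT path now runs cleanly with a device mesh:

# Hypothetical check, not part of this diff: with the warnings removed,
# FULL_STATE_DICT should simply succeed for an FSDP instance that was
# initialized with a device_mesh.
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    state_dict = model.state_dict()
    optim_state_dict = FSDP.optim_state_dict(model, optim)
# Sanity check only: the model state dict should not be empty.
self.assertGreater(len(state_dict), 0)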
5 changes: 0 additions & 5 deletions torch/distributed/fsdp/_optim_utils.py
@@ -2082,10 +2082,5 @@ def _set_optim_use_dtensor(
"DeviceMesh is not compatible with LOCAL_STATE_DICT.",
"Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.",
)
elif state_dict_type == StateDictType.FULL_STATE_DICT:
logger.warning(
"Found both state_dict_type FULL_STATE_DICT and device_mesh. " # noqa: G004
"Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict."
)
else:
state_dict_settings.optim_state_dict_config._use_dtensor = True
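With the `elif` branch deleted, FULL_STATE_DICT now falls through to the final `else` and sets `optim_state_dict_config._use_dtensor = True` instead of only warning. A hedged usage sketch for the optimizer side (the explicit configs and the `model`/`optim` names are assumptions carried over from the earlier example):

# Sketch (assumed setup, not from this PR): model and optim are the 2D
# FSDP + TP model and its optimizer from the earlier example.
from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

with FSDP.state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    optim_state_dict_config=FullOptimStateDictConfig(
        offload_to_cpu=True, rank0_only=True
    ),
):
    # No warning is emitted for the FULL_STATE_DICT + device_mesh combination
    # anymore; the optimizer state goes through the DTensor-aware path.
    full_optim_sd = FSDP.optim_state_dict(model, optim)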
10 changes: 0 additions & 10 deletions torch/distributed/fsdp/_state_dict_utils.py
@@ -292,11 +292,6 @@ def _full_pre_state_dict_hook(
"""
if getattr(fsdp_state, "_device_mesh", False):
parent_mesh = _mesh_resources.get_parent_mesh(fsdp_state._device_mesh)
if parent_mesh:
raise RuntimeError(
f"Found FSDP's device_mesh {fsdp_state._device_mesh} has a parent device_mesh {parent_mesh}.",
"We do not support FULL_STATE_DICT for 2D FSDP + TP. Please use FSDP SHARDED_STATE_DICT instead.",
)

_common_pre_state_dict_hook(module, fsdp_state)
_common_unshard_pre_state_dict_hook(
@@ -798,11 +793,6 @@ def _set_use_dtensor(fsdp_state: _FSDPState) -> None:
"DeviceMesh is not compatible with LOCAL_STATE_DICT.",
"Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.",
)
elif state_dict_type == StateDictType.FULL_STATE_DICT:
logger.warning(
"Found both state_dict_type FULL_STATE_DICT and device_mesh. " # noqa: G004
"Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict."
)
else:
fsdp_state._state_dict_config._use_dtensor = True
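With the parent-mesh RuntimeError in `_full_pre_state_dict_hook` and the warning here gone, FULL_STATE_DICT is usable for 2D models end to end. One motivating use, sketched below under the assumptions of the earlier examples (`full_sd`, `full_optim_sd`, `rank0_only=True`, and an arbitrary checkpoint path), is exporting a plain single-file checkpoint from rank 0:

# Sketch, not from this PR: export a plain checkpoint from rank 0 using the
# full state dicts gathered in the earlier examples (rank0_only=True).
import torch
import torch.distributed as dist

if dist.get_rank() == 0:
    torch.save({"model": full_sd, "optim": full_optim_sd}, "checkpoint.pt")

# The full state dict uses clean FQNs, so it is intended to load back into a
# plain, non-distributed copy of the module, e.g.:
#   ckpt = torch.load("checkpoint.pt")
#   ToyModel().load_state_dict(ckpt["model"])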
