[FSDP][StateDict] Allow FULL_STATE_DICT option for 2D (#120837) · pytorch/pytorch@de8af28
wz337 authored and pytorchmergebot committed
[FSDP][StateDict] Allow FULL_STATE_DICT option for 2D (#120837)
Fixes #120722

TL;DR for the issue: as users are expected to use get_model_state_dict to do state_dict retrieval, I think it's fine to remove the warning and RuntimeError. More context in #120722.

Pull Request resolved: #120837
Approved by: https://github.com/Skylion007
1 parent 507611f commit de8af28
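For reference, a minimal sketch (not part of this commit) of the retrieval path the message recommends, via torch.distributed.checkpoint.state_dict.get_model_state_dict; it assumes `model` and `optim` have already been set up elsewhere with 2D FSDP + TP on an initialized device mesh:

    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
        get_optimizer_state_dict,
    )

    # Assumption: `model` is wrapped with 2D FSDP + TP and `optim` is its optimizer.
    # full_state_dict=True gathers unsharded tensors; cpu_offload=True moves them to CPU.
    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    model_sd = get_model_state_dict(model, options=options)
    optim_sd = get_optimizer_state_dict(model, optim, options=options)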

File tree

3 files changed: 0 additions, 39 deletions


test/distributed/fsdp/test_fsdp_dtensor_state_dict.py

Lines changed: 0 additions & 24 deletions
@@ -313,30 +313,6 @@ def test_raises_warning_or_errors(self):
         with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
             optim_state_dict = FSDP.optim_state_dict(model, optim)
 
-        with self.assertLogs(
-            "torch.distributed.fsdp._state_dict_utils", level="WARNING"
-        ) as log:
-            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
-                state_dict = model.state_dict()
-            self.assertEqual(len(log.records), 1)
-            self.assertEqual(len(log.output), 1)
-            self.assertIn(
-                "Found both state_dict_type FULL_STATE_DICT and device_mesh.",
-                log.output[0],
-            )
-
-        with self.assertLogs(
-            "torch.distributed.fsdp._optim_utils", level="WARNING"
-        ) as log:
-            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
-                state_dict = FSDP.optim_state_dict(model, optim)
-            self.assertEqual(len(log.records), 1)
-            self.assertEqual(len(log.output), 1)
-            self.assertIn(
-                "Found both state_dict_type FULL_STATE_DICT and device_mesh.",
-                log.output[0],
-            )
-
 
 instantiate_parametrized_tests(TestFSDPWithDeviceMeshAndDTensor)
 if __name__ == "__main__":

torch/distributed/fsdp/_optim_utils.py

Lines changed: 0 additions & 5 deletions
@@ -2082,10 +2082,5 @@ def _set_optim_use_dtensor(
                 "DeviceMesh is not compatible with LOCAL_STATE_DICT.",
                 "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.",
             )
-        elif state_dict_type == StateDictType.FULL_STATE_DICT:
-            logger.warning(
-                "Found both state_dict_type FULL_STATE_DICT and device_mesh. " # noqa: G004
-                "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict."
-            )
         else:
             state_dict_settings.optim_state_dict_config._use_dtensor = True

torch/distributed/fsdp/_state_dict_utils.py

Lines changed: 0 additions & 10 deletions
@@ -302,11 +302,6 @@ def _full_pre_state_dict_hook(
     """
     if getattr(fsdp_state, "_device_mesh", False):
         parent_mesh = _mesh_resources.get_parent_mesh(fsdp_state._device_mesh)
-        if parent_mesh:
-            raise RuntimeError(
-                f"Found FSDP's device_mesh {fsdp_state._device_mesh} has a parent device_mesh {parent_mesh}.",
-                "We do not support FULL_STATE_DICT for 2D FSDP + TP. Please use FSDP SHARDED_STATE_DICT instead.",
-            )
 
     _common_pre_state_dict_hook(module, fsdp_state)
     _common_unshard_pre_state_dict_hook(

@@ -808,11 +803,6 @@ def _set_use_dtensor(fsdp_state: _FSDPState) -> None:
                 "DeviceMesh is not compatible with LOCAL_STATE_DICT.",
                 "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.",
             )
-        elif state_dict_type == StateDictType.FULL_STATE_DICT:
-            logger.warning(
-                "Found both state_dict_type FULL_STATE_DICT and device_mesh. " # noqa: G004
-                "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict."
-            )
         else:
             fsdp_state._state_dict_config._use_dtensor = True
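For illustration, a hedged sketch of the call pattern these removals unblock, mirroring the deleted test: with `model` wrapped in FSDP on a device mesh that has a parent mesh (2D FSDP + TP) and `optim` as its optimizer (both assumed to exist from prior setup), requesting FULL_STATE_DICT no longer raises or warns:

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType

    # Before this commit, _full_pre_state_dict_hook raised a RuntimeError when the
    # FSDP device_mesh had a parent mesh, and the _use_dtensor setters logged the
    # "Found both state_dict_type FULL_STATE_DICT and device_mesh." warning.
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
        state_dict = model.state_dict()
        optim_state_dict = FSDP.optim_state_dict(model, optim)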