8000 allow full_state_dict option for 2D · pytorch/pytorch@e70e223 · GitHub
[go: up one dir, main page]

Skip to content

Commit e70e223

Browse files
committed
allow full_state_dict option for 2D
1 parent 5a0a964 commit e70e223

File tree

2 files changed

+0
-15
lines changed

2 files changed

+0
-15
lines changed

torch/distributed/fsdp/_optim_utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2082,10 +2082,5 @@ def _set_optim_use_dtensor(
20822082
"DeviceMesh is not compatible with LOCAL_STATE_DICT.",
20832083
"Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.",
20842084
)
2085-
elif state_dict_type == StateDictType.FULL_STATE_DICT:
2086-
logger.warning(
2087-
"Found both state_dict_type FULL_STATE_DICT and device_mesh. " # noqa: G004
2088-
"Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict."
2089-
)
20902085
else:
20912086
state_dict_settings.optim_state_dict_config._use_dtensor = True

torch/distributed/fsdp/_state_dict_utils.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -292,11 +292,6 @@ def _full_pre_state_dict_hook(
292292
"""
293293
if getattr(fsdp_state, "_device_mesh", False):
294294
parent_mesh = _mesh_resources.get_parent_mesh(fsdp_state._device_mesh)
295-
if parent_mesh:
296-
raise RuntimeError(
297-
f"Found FSDP's device_mesh {fsdp_state._device_mesh} has a parent device_mesh {parent_mesh}.",
298-
"We do not support FULL_STATE_DICT for 2D FSDP + TP. Please use FSDP SHARDED_STATE_DICT instead.",
299-
)
300295

301296
_common_pre_state_dict_hook(module, fsdp_state)
302297
_common_unshard_pre_state_dict_hook(
@@ -798,11 +793,6 @@ def _set_use_dtensor(fsdp_state: _FSDPState) -> None:
798793
"DeviceMesh is not compatible with LOCAL_STATE_DICT.",
799794
"Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.",
800795
)
801-
elif state_dict_type == StateDictType.FULL_STATE_DICT:
802-
logger.warning(
803-
"Found both state_dict_type FULL_STATE_DICT and device_mesh. " # noqa: G004
804-
"Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict."
805-
)
806796
else:
807797
fsdp_state._state_dict_config._use_dtensor = True
808798

0 commit comments

Comments
 (0)
0