allow full_state_dict option for 2D · pytorch/pytorch@2ab6d61 · GitHub


Commit 2ab6d61

allow full_state_dict option for 2D
1 parent 5a0a964 commit 2ab6d61
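
The deletions below drop the guards that previously blocked or warned on FULL_STATE_DICT when FSDP was initialized with a device_mesh that has a parent mesh (2D FSDP + TP): a RuntimeError in _full_pre_state_dict_hook and the warning branches in _set_use_dtensor and _set_optim_use_dtensor. A minimal, hypothetical sketch of the call pattern this enables follows; the mesh construction, mesh shape, toy model, and optimizer are illustrative assumptions (the commit itself only removes the guards), while FSDP.state_dict_type, StateDictType.FULL_STATE_DICT, and FSDP.optim_state_dict are the APIs exercised by the deleted test code.

# Hypothetical usage sketch, not part of this commit. Run under torchrun with
# 4 ranks; in a real 2-D setup the module would first be tensor-parallelized
# over mesh_2d["tp"] (skipped here for brevity).
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

# 2-D mesh: FSDP shards over the "dp" submesh, so its mesh has a parent mesh.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
model = FSDP(nn.Linear(16, 16).cuda(), device_mesh=mesh_2d["dp"])
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    # With this change, requesting a full (unsharded) state dict no longer
    # raises or warns just because the FSDP mesh has a parent device mesh.
    full_sd = model.state_dict()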

File tree

3 files changed: 0 additions, 39 deletions

test/distributed/fsdp/test_fsdp_dtensor_state_dict.py

Lines changed: 0 additions & 24 deletions
@@ -313,30 +313,6 @@ def test_raises_warning_or_errors(self):
         with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
             optim_state_dict = FSDP.optim_state_dict(model, optim)
 
-        with self.assertLogs(
-            "torch.distributed.fsdp._state_dict_utils", level="WARNING"
-        ) as log:
-            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
-                state_dict = model.state_dict()
-            self.assertEqual(len(log.records), 1)
-            self.assertEqual(len(log.output), 1)
-            self.assertIn(
-                "Found both state_dict_type FULL_STATE_DICT and device_mesh.",
-                log.output[0],
-            )
-
-        with self.assertLogs(
-            "torch.distributed.fsdp._optim_utils", level="WARNING"
-        ) as log:
-            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
-                state_dict = FSDP.optim_state_dict(model, optim)
-            self.assertEqual(len(log.records), 1)
-            self.assertEqual(len(log.output), 1)
-            self.assertIn(
-                "Found both state_dict_type FULL_STATE_DICT and device_mesh.",
-                log.output[0],
-            )
-
 
 instantiate_parametrized_tests(TestFSDPWithDeviceMeshAndDTensor)
 if __name__ == "__main__":

torch/distributed/fsdp/_optim_utils.py

Lines changed: 0 additions & 5 deletions
@@ -2082,10 +2082,5 @@ def _set_optim_use_dtensor(
                 "DeviceMesh is not compatible with LOCAL_STATE_DICT.",
                 "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.",
             )
-        elif state_dict_type == StateDictType.FULL_STATE_DICT:
-            logger.warning(
-                "Found both state_dict_type FULL_STATE_DICT and device_mesh. "  # noqa: G004
-                "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict."
-            )
         else:
             state_dict_settings.optim_state_dict_config._use_dtensor = True
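
With the elif branch removed, FULL_STATE_DICT now falls into the else branch here instead of logging a warning; only LOCAL_STATE_DICT still raises. Continuing the hypothetical model/optim from the sketch above (illustrative only, not code from this commit), the optimizer-state counterpart is:

with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    # Gathers the optimizer state as a full state dict; the
    # "Found both state_dict_type FULL_STATE_DICT and device_mesh." warning
    # from the deleted branch is no longer logged.
    full_osd = FSDP.optim_state_dict(model, optim)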

torch/distributed/fsdp/_state_dict_utils.py

Lines changed: 0 additions & 10 deletions
@@ -292,11 +292,6 @@ def _full_pre_state_dict_hook(
     """
     if getattr(fsdp_state, "_device_mesh", False):
         parent_mesh = _mesh_resources.get_parent_mesh(fsdp_state._device_mesh)
-        if parent_mesh:
-            raise RuntimeError(
-                f"Found FSDP's device_mesh {fsdp_state._device_mesh} has a parent device_mesh {parent_mesh}.",
-                "We do not support FULL_STATE_DICT for 2D FSDP + TP. Please use FSDP SHARDED_STATE_DICT instead.",
-            )
 
     _common_pre_state_dict_hook(module, fsdp_state)
     _common_unshard_pre_state_dict_hook(
@@ -798,11 +793,6 @@ def _set_use_dtensor(fsdp_state: _FSDPState) -> None:
                 "DeviceMesh is not compatible with LOCAL_STATE_DICT.",
                 "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.",
             )
-        elif state_dict_type == StateDictType.FULL_STATE_DICT:
-            logger.warning(
-                "Found both state_dict_type FULL_STATE_DICT and device_mesh. "  # noqa: G004
-                "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict."
-            )
         else:
             fsdp_state._state_dict_config._use_dtensor = True
 

0 commit comments
