From 2ab6d61d6fdd7137ec80f58992b9e154affce48c Mon Sep 17 00:00:00 2001
From: wz337
Date: Wed, 28 Feb 2024 13:56:15 -0800
Subject: [PATCH] allow full_state_dict option for 2D

---
 .../fsdp/test_fsdp_dtensor_state_dict.py     | 24 -------------------
 torch/distributed/fsdp/_optim_utils.py       |  5 ----
 torch/distributed/fsdp/_state_dict_utils.py  | 10 --------
 3 files changed, 39 deletions(-)

diff --git a/test/distributed/fsdp/test_fsdp_dtensor_state_dict.py b/test/distributed/fsdp/test_fsdp_dtensor_state_dict.py
index 4ae7c3c797f8..f3eab4642c84 100644
--- a/test/distributed/fsdp/test_fsdp_dtensor_state_dict.py
+++ b/test/distributed/fsdp/test_fsdp_dtensor_state_dict.py
@@ -313,30 +313,6 @@ def test_raises_warning_or_errors(self):
         with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
             optim_state_dict = FSDP.optim_state_dict(model, optim)
 
-        with self.assertLogs(
-            "torch.distributed.fsdp._state_dict_utils", level="WARNING"
-        ) as log:
-            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
-                state_dict = model.state_dict()
-            self.assertEqual(len(log.records), 1)
-            self.assertEqual(len(log.output), 1)
-            self.assertIn(
-                "Found both state_dict_type FULL_STATE_DICT and device_mesh.",
-                log.output[0],
-            )
-
-        with self.assertLogs(
-            "torch.distributed.fsdp._optim_utils", level="WARNING"
-        ) as log:
-            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
-                state_dict = FSDP.optim_state_dict(model, optim)
-            self.assertEqual(len(log.records), 1)
-            self.assertEqual(len(log.output), 1)
-            self.assertIn(
-                "Found both state_dict_type FULL_STATE_DICT and device_mesh.",
-                log.output[0],
-            )
-
 instantiate_parametrized_tests(TestFSDPWithDeviceMeshAndDTensor)
 
 if __name__ == "__main__":
diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py
index 6e2525ce2af0..682a7f2b299a 100644
--- a/torch/distributed/fsdp/_optim_utils.py
+++ b/torch/distributed/fsdp/_optim_utils.py
@@ -2082,10 +2082,5 @@ def _set_optim_use_dtensor(
                 "DeviceMesh is not compatible with LOCAL_STATE_DICT.",
                 "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.",
             )
-        elif state_dict_type == StateDictType.FULL_STATE_DICT:
-            logger.warning(
-                "Found both state_dict_type FULL_STATE_DICT and device_mesh. "  # noqa: G004
-                "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict."
-            )
         else:
             state_dict_settings.optim_state_dict_config._use_dtensor = True
diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py
index 66840851d17e..09419e0ad27c 100644
--- a/torch/distributed/fsdp/_state_dict_utils.py
+++ b/torch/distributed/fsdp/_state_dict_utils.py
@@ -292,11 +292,6 @@ def _full_pre_state_dict_hook(
     """
     if getattr(fsdp_state, "_device_mesh", False):
         parent_mesh = _mesh_resources.get_parent_mesh(fsdp_state._device_mesh)
-        if parent_mesh:
-            raise RuntimeError(
-                f"Found FSDP's device_mesh {fsdp_state._device_mesh} has a parent device_mesh {parent_mesh}.",
-                "We do not support FULL_STATE_DICT for 2D FSDP + TP. Please use FSDP SHARDED_STATE_DICT instead.",
-            )
 
     _common_pre_state_dict_hook(module, fsdp_state)
     _common_unshard_pre_state_dict_hook(
@@ -798,11 +793,6 @@ def _set_use_dtensor(fsdp_state: _FSDPState) -> None:
                 "DeviceMesh is not compatible with LOCAL_STATE_DICT.",
                 "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.",
             )
-        elif state_dict_type == StateDictType.FULL_STATE_DICT:
-            logger.warning(
-                "Found both state_dict_type FULL_STATE_DICT and device_mesh. "  # noqa: G004
-                "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict."
-            )
         else:
             fsdp_state._state_dict_config._use_dtensor = True
 
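For context, below is a minimal sketch of the workflow this patch unblocks:
requesting FULL_STATE_DICT for both model and optimizer from a module
parallelized over a 2D device mesh (tensor parallelism on the inner mesh
dimension, FSDP sharding on the outer). The patch itself only deletes the
guard in _full_pre_state_dict_hook and the two warnings; everything in the
sketch (the toy nn.Linear, the 2x2 mesh and its dim names, the variable
names) is illustrative, not part of the patch, and assumes a 4-GPU launch
on a PyTorch build that includes this change.

# Hypothetical 2D (FSDP + TP) setup; shapes and names are illustrative.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# 2D mesh: outer "dp" dim shards parameters via FSDP, inner "tp" dim
# applies tensor parallelism. The FSDP sub-mesh therefore has a parent
# mesh, which _full_pre_state_dict_hook rejected before this patch.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

model = nn.Linear(16, 16, device="cuda")
model = parallelize_module(model, mesh_2d["tp"], ColwiseParallel())
model = FSDP(model, device_mesh=mesh_2d["dp"], use_orig_params=True)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

# ... run some training steps here ...

# After this patch, FULL_STATE_DICT on a 2D model no longer raises in
# _full_pre_state_dict_hook and no warning is logged; FSDP returns the
# full, unsharded tensors.
with FSDP.state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
    full_sd = model.state_dict()
    full_optim_sd = FSDP.optim_state_dict(model, optim)

Run with e.g. `torchrun --nproc_per_node=4 <script>.py`. Note the mechanism
visible in the diff: with the elif branches deleted, FULL_STATE_DICT now
falls through to the else branch in _set_use_dtensor and
_set_optim_use_dtensor, which sets the state-dict config's _use_dtensor
flag just as SHARDED_STATE_DICT does.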