File tree Expand file tree Collapse file tree 3 files changed +0
-39
lines changed Expand file tree Collapse file tree 3 files changed +0
-39
lines changed Original file line number Diff line number Diff line change @@ -313,30 +313,6 @@ def test_raises_warning_or_errors(self):
313
313
with FSDP .state_dict_type (model , StateDictType .LOCAL_STATE_DICT ):
314
314
optim_state_dict = FSDP .optim_state_dict (model , optim )
315
315
316
- with self .assertLogs (
317
- "torch.distributed.fsdp._state_dict_utils" , level = "WARNING"
318
- ) as log :
319
- with FSDP .state_dict_type (model , StateDictType .FULL_STATE_DICT ):
320
- state_dict = model .state_dict ()
321
- self .assertEqual (len (log .records ), 1 )
322
- self .assertEqual (len (log .output ), 1 )
323
- self .assertIn (
324
- "Found both state_dict_type FULL_STATE_DICT and device_mesh." ,
325
- log .output [0 ],
326
- )
327
-
328
- with self .assertLogs (
329
- "torch.distributed.fsdp._optim_utils" , level = "WARNING"
330
- ) as log :
331
- with FSDP .state_dict_type (model , StateDictType .FULL_STATE_DICT ):
332
- state_dict = FSDP .optim_state_dict (model , optim )
333
- self .assertEqual (len (log .records ), 1 )
334
- self .assertEqual (len (log .output ), 1 )
335
- self .assertIn (
336
- "Found both state_dict_type FULL_STATE_DICT and device_mesh." ,
337
- log .output [0 ],
338
- )
339
-
340
316
341
317
instantiate_parametrized_tests (TestFSDPWithDeviceMeshAndDTensor )
342
318
if __name__ == "__main__" :
Original file line number Diff line number Diff line change @@ -2082,10 +2082,5 @@ def _set_optim_use_dtensor(
2082
2082
"DeviceMesh is not compatible with LOCAL_STATE_DICT." ,
2083
2083
"Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict." ,
2084
2084
)
2085
- elif state_dict_type == StateDictType .FULL_STATE_DICT :
2086
- logger .warning (
2087
- "Found both state_dict_type FULL_STATE_DICT and device_mesh. " # noqa: G004
2088
- "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict."
2089
- )
2090
2085
else :
2091
2086
state_dict_settings .optim_state_dict_config ._use_dtensor = True
Original file line number Diff line number Diff line change @@ -292,11 +292,6 @@ def _full_pre_state_dict_hook(
292
292
"""
293
293
if getattr (fsdp_state , "_device_mesh" , False ):
294
294
parent_mesh = _mesh_resources .get_parent_mesh (fsdp_state ._device_mesh )
295
- if parent_mesh :
296
- raise RuntimeError (
297
- f"Found FSDP's device_mesh { fsdp_state ._device_mesh } has a parent device_mesh { parent_mesh } ." ,
298
- "We do not support FULL_STATE_DICT for 2D FSDP + TP. Please use FSDP SHARDED_STATE_DICT instead." ,
299
- )
300
295
301
296
_common_pre_state_dict_hook (module , fsdp_state )
302
297
_common_unshard_pre_state_dict_hook (
@@ -798,11 +793,6 @@ def _set_use_dtensor(fsdp_state: _FSDPState) -> None:
798
793
"DeviceMesh is not compatible with LOCAL_STATE_DICT." ,
799
794
"Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict." ,
800
795
)
801
- elif state_dict_type == StateDictType .FULL_STATE_DICT :
802
- logger .warning (
803
- "Found both state_dict_type FULL_STATE_DICT and device_mesh. " # noqa: G004
804
- "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict."
805
- )
806
796
else :
807
797
fsdp_state ._state_dict_config ._use_dtensor = True
808
798
You can’t perform that action at this time.
0 commit comments