[FSDP2+TP] Disable 2D state_dict #129519

Closed · wants to merge 1 commit
Branch: disallow_state_dict_call_on_2d
wz337 committed Jun 28, 2024
commit 790e7862fe7c79e7397f9fbfd09c2de3d4b010e3
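The core mechanism of this change is a module-level state_dict pre-hook that raises before any tensors are gathered. Below is a minimal, standalone sketch of that mechanism, not the PR's actual code path; the plain `nn.Linear` and the error message are illustrative only.

```python
import torch.nn as nn

# nn.Module.register_state_dict_pre_hook runs a callback before state_dict()
# collects any tensors, so a hook that raises is enough to block the call.
def _raise_not_implemented(*args, **kwargs) -> None:
    raise NotImplementedError("state_dict is temporarily disabled for this module")

linear = nn.Linear(4, 4)
linear.register_state_dict_pre_hook(_raise_not_implemented)

try:
    linear.state_dict()
except NotImplementedError as err:
    print(err)  # -> state_dict is temporarily disabled for this module
```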
@@ -104,8 +104,10 @@ def test_1d_state_dict_cpu_offload(self):
         for name, dtensor in state_dict.items():
             self.assertEqual(dtensor.device.type, "cpu")
 
+    # Temporarily disable the 2D state dict test while strided sharding is being developed.
+    # TODO: re-enable this test once 2d state_dict is ready.
     @skip_if_lt_x_gpu(2)
-    def test_2d_state_dict_save_load(self):
+    def _temp_disable_test_2d_state_dict_save_load(self):
         dp_size = 2
         global_mesh = init_device_mesh(
             "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
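For context, the now-disabled test builds its 2D mesh with `init_device_mesh`, with the first dim used for data parallel ("dp") and the second for tensor parallel ("tp"). A hedged sketch of that setup, assuming a multi-GPU job launched with torchrun (the script name is illustrative):

```python
import os

from torch.distributed.device_mesh import init_device_mesh

# e.g. `torchrun --nproc_per_node=4 this_script.py`; WORLD_SIZE is set by torchrun.
dp_size = 2
world_size = int(os.environ["WORLD_SIZE"])
global_mesh = init_device_mesh(
    "cuda", (dp_size, world_size // dp_size), mesh_dim_names=("dp", "tp")
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]  # 1D sub-meshes per dim
```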
24 changes: 23 additions & 1 deletion test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -1069,9 +1069,31 @@ def test_tp_with_fsdp_offloading(self):
             optim.step()
             ref_optim.step()
 
+    # TODO: remove this test when 2d state_dict is ready.
+    @skip_if_lt_x_gpu(2)
+    @skipIfRocm
+    def test_raise_not_implemented_state_dict_if_2d(self):
+        def parallelize(_model: Transformer, mesh: DeviceMesh, use_seq_parallel: bool):
+            _model = Transformer.parallelize(_model, mesh["tp"], use_seq_parallel)
+            for layer in _model.layers:
+                fully_shard(layer, mesh=mesh["dp"])
+            fully_shard(_model, mesh=mesh["dp"])
+            return _model
+
+        global_mesh = self.init_global_mesh()
+        seed = 42
+        torch.manual_seed(seed)
+        model_args = ModelArgs(dropout_p=0.0)
+        model = parallelize(Transformer(model_args), global_mesh, True)
+
+        with self.assertRaisesRegex(NotImplementedError, "2D"):
+            get_model_state_dict(model)
+
+    # Temporarily disable the 2D state dict test while strided sharding is being developed.
+    # TODO: re-enable this test once 2d state_dict is ready.
     @skip_if_lt_x_gpu(2)
     @with_temp_dir
-    def test_train_parity_2d_transformer_checkpoint_resume(self):
+    def _temp_disable_test_train_parity_2d_transformer_checkpoint_resume(self):
         """
         Tests train parity of a 2D transformer without checkpointing against a
         2D transformer with a checkpoint save/load.
16 changes: 16 additions & 0 deletions torch/distributed/_composable/fsdp/_fsdp_param_group.py
@@ -153,6 +153,22 @@ def __init__(
         # partial reduce output (only reduce-scattered but not all-reduced)
         self._partial_reduce_output: Optional[torch.Tensor] = None
 
+        # TODO: remove this hook and hook registration once 2D state dict is supported.
+        def _raise_not_implemented_if_2d(*args: Any, **kwargs: Any) -> None:
+            raise NotImplementedError(
+                "2D state_dict is under development. Please check "
+                "https://github.com/pytorch/pytorch/issues/129627 for more details."
+            )
+
+        modules_with_2d_params: Set[nn.Module] = set()
+        for fsdp_param in self.fsdp_params:
+            module = fsdp_param._module_info.module
+            if len(fsdp_param._spmd_placements) > 1:
+                modules_with_2d_params.add(module)
+        for module in modules_with_2d_params:
+            module.register_state_dict_pre_hook(_raise_not_implemented_if_2d)
+            module._register_load_state_dict_pre_hook(_raise_not_implemented_if_2d)
+
     # Initialization #
     def _init_mp_dtypes(self) -> None:
         for fsdp_param in self.fsdp_params:
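Taken together, the hooks above mean that calling `state_dict()` (or loading one) on a module whose parameters are sharded on both mesh dims now fails fast instead of producing an unsupported 2D state dict. A rough end-to-end sketch of how a user would hit the error, assuming a 4-GPU torchrun launch; `ToyMLP` and the TP plan are illustrative, not taken from the PR:

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

class ToyMLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)

mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
model = ToyMLP()
parallelize_module(model, mesh["tp"], {"proj": ColwiseParallel()})  # TP on the "tp" dim
fully_shard(model, mesh=mesh["dp"])  # FSDP2 on the "dp" dim -> 2D-sharded parameters

try:
    model.state_dict()  # the pre-hook registered by this PR fires here
except NotImplementedError as err:
    print(err)  # points to https://github.com/pytorch/pytorch/issues/129627
```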