8000 [DCP][state_dict] DCP state_dict cannot correctly find FQN when the leaf module is wrapped by FSDP by fegin · Pull Request #115592 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[DCP][state_dict] DCP state_dict cannot correctly find FQN when the leaf module is wrapped by FSDP #115592

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions test/distributed/checkpoint/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import copy
import sys
from itertools import chain
from typing import Callable
from typing import Callable, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import fully_shard, replicate
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor, init_device_mesh
Expand Down Expand Up @@ -133,7 +134,12 @@ def _test_save_load(
self._verify_osd(model, optim, osd, dist_osd)

def _test_fsdp(
self, use_orig_params: bool, use_composable: bool, use_dtensor: bool
self,
*,
use_orig_params: bool,
use_composable: bool,
use_dtensor: bool,
wrapping: Tuple[nn.Module] = (),
) -> None:
if not use_orig_params and use_composable:
return
Expand All @@ -149,23 +155,27 @@ def init_model_optim():
orig_model = CompositeParamModel(device=torch.device("cuda"))
orig_optim = torch.optim.Adam(orig_model.parameters(), lr=1e-3)
copy_optim = torch.optim.Adam(orig_model.parameters(), lr=1e-3)
if wrapping:
strategy = set(wrapping)
else:
strategy = {UnitModule}
if use_composable:
dist_model = fully_shard(
copy.deepcopy(orig_model), policy=ModuleWrapPolicy({UnitModule})
copy.deepcopy(orig_model), policy=ModuleWrapPolicy(strategy)
)
else:
if use_dtensor:
device_mesh = init_device_mesh("cuda", (self.world_size,))
dist_model = FSDP(
copy.deepcopy(orig_model),
auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
auto_wrap_policy=ModuleWrapPolicy(strategy),
use_orig_params=use_orig_params,
device_mesh=device_mesh,
)
else:
dist_model = FSDP(
copy.deepcopy(orig_model),
auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
auto_wrap_policy=ModuleWrapPolicy(strategy),
use_orig_params=use_orig_params,
)

Expand All @@ -182,6 +192,7 @@ def test_fsdp(self) -> None:
"use_orig_params": [True, False],
"use_composable": [True, False],
"use_dtensor": [True, False],
"wrapping": [tuple(), (nn.Linear, UnitModule)],
},
self._test_fsdp,
)
Expand Down
2 changes: 1 addition & 1 deletion torch/distributed/checkpoint/state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def _get_fqns(model: nn.Module, name: str, skip_ddp_prefix: bool = True) -> FQNS
if not skip_ddp_prefix:
fqn_obj_names.append(curr_obj_name)
elif isinstance(curr_obj, FSDP):
if obj_names[i + 1] == FLAT_PARAM:
if i < len(obj_names) - 1 and obj_names[i + 1] == FLAT_PARAM:
prefix = ".".join(fqn_obj_names)
flat_param = getattr(curr_obj, FLAT_PARAM)
if prefix:
Expand Down
0