[DCP][state_dict] DCP state_dict cannot correctly find FQN when the leaf module is wrapped by FSDP · pytorch/pytorch@d954ef2

Commit d954ef2

fegin authored and pytorchmergebot committed
[DCP][state_dict] DCP state_dict cannot correctly find FQN when the leaf module is wrapped by FSDP (#115592)
Summary: The original logic incorrectly assumes that there is at least one object name left after the FSDP module when traversing the module tree. This assumption does not hold when the leaf module is wrapped by FSDP.

Test Plan: CI

Differential Revision: D52049293

Pull Request resolved: #115592
Approved by: https://github.com/wz337
1 parent 0ff155f commit d954ef2
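
The failing assumption is easiest to see in isolation. The sketch below mimics the name-walking loop in torch.distributed.checkpoint.state_dict._get_fqns without constructing any FSDP modules: the dotted names, the fsdp_segments stand-in for the isinstance(curr_obj, FSDP) check, and the printed labels are illustrative assumptions, not taken from the real implementation.

# A minimal sketch of _get_fqns's peek-ahead, with the module tree replaced by
# a set of segment names standing in for isinstance(curr_obj, FSDP).
FLAT_PARAM = "_flat_param"

def old_fqn_walk(name: str, fsdp_segments: set) -> None:
    obj_names = name.split(".")
    for i, curr_obj_name in enumerate(obj_names):
        if curr_obj_name in fsdp_segments:
            # Original logic: always peek at the next segment, assuming at
            # least one object name is left after the FSDP module.
            if obj_names[i + 1] == FLAT_PARAM:
                print(f"{curr_obj_name}: flat-param key")
            else:
                print(f"{curr_obj_name}: skip the FSDP wrapper prefix")

# A parameter key under a non-leaf FSDP wrapper still has segments left,
# so the peek-ahead happens to work:
old_fqn_walk("u1._fsdp_wrapped_module.l.weight", fsdp_segments={"u1"})

# When the leaf module itself is wrapped by FSDP, the FSDP segment can be the
# last one in the dotted name, and the peek-ahead raises:
try:
    old_fqn_walk("u1.l", fsdp_segments={"l"})
except IndexError as exc:
    print("original logic fails:", exc)  # list index out of range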

File tree

2 files changed: +17 -6 lines changed

test/distributed/checkpoint/test_state_dict.py

Lines changed: 16 additions & 5 deletions
@@ -3,10 +3,11 @@
 import copy
 import sys
 from itertools import chain
-from typing import Callable
+from typing import Callable, Tuple
 
 import torch
 import torch.distributed as dist
+import torch.nn as nn
 from torch.distributed._composable import fully_shard, replicate
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._tensor import DTensor, init_device_mesh
@@ -133,7 +134,12 @@ def _test_save_load(
         self._verify_osd(model, optim, osd, dist_osd)
 
     def _test_fsdp(
-        self, use_orig_params: bool, use_composable: bool, use_dtensor: bool
+        self,
+        *,
+        use_orig_params: bool,
+        use_composable: bool,
+        use_dtensor: bool,
+        wrapping: Tuple[nn.Module] = (),
     ) -> None:
         if not use_orig_params and use_composable:
             return
@@ -149,23 +155,27 @@ def init_model_optim():
             orig_model = CompositeParamModel(device=torch.device("cuda"))
             orig_optim = torch.optim.Adam(orig_model.parameters(), lr=1e-3)
            copy_optim = torch.optim.Adam(orig_model.parameters(), lr=1e-3)
+            if wrapping:
+                strategy = set(wrapping)
+            else:
+                strategy = {UnitModule}
             if use_composable:
                 dist_model = fully_shard(
-                    copy.deepcopy(orig_model), policy=ModuleWrapPolicy({UnitModule})
+                    copy.deepcopy(orig_model), policy=ModuleWrapPolicy(strategy)
                 )
             else:
                 if use_dtensor:
                     device_mesh = init_device_mesh("cuda", (self.world_size,))
                     dist_model = FSDP(
                         copy.deepcopy(orig_model),
-                        auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
+                        auto_wrap_policy=ModuleWrapPolicy(strategy),
                         use_orig_params=use_orig_params,
                         device_mesh=device_mesh,
                     )
                 else:
                     dist_model = FSDP(
                         copy.deepcopy(orig_model),
-                        auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
+                        auto_wrap_policy=ModuleWrapPolicy(strategy),
                         use_orig_params=use_orig_params,
                     )
 
@@ -182,6 +192,7 @@ def test_fsdp(self) -> None:
                 "use_orig_params": [True, False],
                 "use_composable": [True, False],
                 "use_dtensor": [True, False],
+                "wrapping": [tuple(), (nn.Linear, UnitModule)],
             },
             self._test_fsdp,
         )

torch/distributed/checkpoint/state_dict.py

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ def _get_fqns(model: nn.Module, name: str, skip_ddp_prefix: bool = True) -> FQNS
             if not skip_ddp_prefix:
                 fqn_obj_names.append(curr_obj_name)
         elif isinstance(curr_obj, FSDP):
-            if obj_names[i + 1] == FLAT_PARAM:
+            if i < len(obj_names) - 1 and obj_names[i + 1] == FLAT_PARAM:
                 prefix = ".".join(fqn_obj_names)
                 flat_param = getattr(curr_obj, FLAT_PARAM)
                 if prefix:
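
For completeness, here is the toy walk from above with the patched bounds check; again a sketch under the same illustrative assumptions, not the real _get_fqns.

FLAT_PARAM = "_flat_param"

def fixed_fqn_walk(name: str, fsdp_segments: set) -> None:
    obj_names = name.split(".")
    for i, curr_obj_name in enumerate(obj_names):
        if curr_obj_name in fsdp_segments:
            # Patched logic: only peek ahead when a next segment exists.
            if i < len(obj_names) - 1 and obj_names[i + 1] == FLAT_PARAM:
                print(f"{curr_obj_name}: flat-param key")
            else:
                print(f"{curr_obj_name}: skip the FSDP wrapper prefix")

fixed_fqn_walk("u1._flat_param", fsdp_segments={"u1"})  # flat-param path is unchanged
fixed_fqn_walk("u1.l", fsdp_segments={"l"})             # a wrapped leaf no longer raises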
