File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed
torch/distributed/checkpoint Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -152,8 +152,11 @@ def _get_fqns(
152
152
Returns:
153
153
The canonical FQNs based on the model traversal.
154
154
"""
155
+
156
+ # Remove the checkpoint prefix, if it exists.
157
+ name = name .replace (_CHECKPOINT_PREFIX , "" )
155
158
if "." not in name :
156
- return {name . replace ( _CHECKPOINT_PREFIX , "" ) }
159
+ return {name }
157
160
158
161
obj_names = name .split ("." )
159
162
fqn_obj_names = []
@@ -170,8 +173,6 @@ def _get_fqns(
170
173
flat_param = getattr (curr_obj , FLAT_PARAM )
171
174
if prefix :
172
175
prefix = f"{ prefix } ."
173
- # FSDP already handles removal of checkpoint prefix, so we can return
174
- # directly
175
176
return {f"{ prefix } { fqn } " for fqn in flat_param ._fqns }
176
177
curr_obj = getattr (curr_obj , FSDP_WRAPPED_MODULE )
177
178
if curr_obj_name != FSDP_WRAPPED_MODULE :
You can’t perform that action at this time.
0 commit comments