remove act checkpoint tag · pytorch/pytorch@f7dc14e · GitHub

Commit f7dc14e

remove act checkpoint tag
1 parent 7706cd7 commit f7dc14e

1 file changed (+4 −3)

torch/distributed/checkpoint/state_dict.py

Lines changed: 4 additions & 3 deletions
@@ -152,8 +152,11 @@ def _get_fqns(
     Returns:
         The canonical FQNs based on the model traversal.
     """
+
+    # Remove the checkpoint prefix, if it exists.
+    name = name.replace(_CHECKPOINT_PREFIX, "")
     if "." not in name:
-        return {name.replace(_CHECKPOINT_PREFIX, "")}
+        return {name}
 
     obj_names = name.split(".")
     fqn_obj_names = []
@@ -170,8 +173,6 @@ def _get_fqns(
                 flat_param = getattr(curr_obj, FLAT_PARAM)
                 if prefix:
                     prefix = f"{prefix}."
-                # FSDP already handles removal of checkpoint prefix, so we can return
-                # directly
                 return {f"{prefix}{fqn}" for fqn in flat_param._fqns}
             curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
             if curr_obj_name != FSDP_WRAPPED_MODULE:
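In effect, the commit hoists the checkpoint-prefix removal to the top of _get_fqns, so the activation-checkpoint tag is stripped from every incoming name, dotted or not, instead of only on the dot-free early-return path. Below is a minimal, hypothetical sketch of that behavior difference. It assumes _CHECKPOINT_PREFIX is the "_checkpoint_wrapped_module." attribute prefix injected by activation-checkpoint wrappers, reduces _get_fqns to just its prefix-stripping paths (the real function also traverses the module tree, including DDP and FSDP wrappers), and uses made-up parameter names.

# Minimal sketch, not the real _get_fqns: only the prefix-stripping paths.
# Assumption: _CHECKPOINT_PREFIX is the attribute prefix added by activation
# checkpoint wrappers; the example names below are hypothetical.
_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module."

def get_fqns_before(name: str) -> set:
    # Old behavior: the prefix was stripped only when the name had no dot.
    if "." not in name:
        return {name.replace(_CHECKPOINT_PREFIX, "")}
    return {name}  # a dotted name kept the checkpoint tag on this path

def get_fqns_after(name: str) -> set:
    # New behavior: strip the prefix up front, before any branching.
    name = name.replace(_CHECKPOINT_PREFIX, "")
    return {name}

dotted = "layer1._checkpoint_wrapped_module.weight"
assert get_fqns_before(dotted) == {"layer1._checkpoint_wrapped_module.weight"}
assert get_fqns_after(dotted) == {"layer1.weight"}

Stripping the prefix before any branching also leaves later paths (such as the FSDP flat-parameter return) with nothing to clean up, which is presumably why the "FSDP already handles removal of checkpoint prefix" comment could be dropped.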

0 commit comments