File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed
torch/distributed/checkpoint Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -205,7 +205,6 @@ def dcp_to_torch_save(
205
205
To avoid OOM, it's recommended to only run this function on a single rank.
206
206
"""
207
207
sd : STATE_DICT_TYPE = {}
208
-
209
208
_load_state_dict (
210
209
sd ,
211
210
storage_reader = FileSystemReader (dcp_checkpoint_dir ),
@@ -263,13 +262,15 @@ class FormatMode(Enum):
263
262
checkpoint_missing_warning = (
264
263
f"No checkpoint found at { args .src } . Skipping conversion."
265
264
)
266
- if args .mode == FormatMode .TORCH_TO_DCP :
265
+ if args .mode == FormatMode .TORCH_TO_DCP . value :
267
266
if os .path .isfile (args .src ):
268
267
torch_save_to_dcp (args .src , args .dst )
269
268
else :
270
269
print (checkpoint_missing_warning )
271
- elif args .mode == FormatMode .DCP_TO_TORCH :
270
+ elif args .mode == FormatMode .DCP_TO_TORCH . value :
272
271
if os .path .isdir (args .src ):
273
272
dcp_to_torch_save (args .src , args .dst )
274
273
else :
275
274
print (checkpoint_missing_warning )
275
+ else :
276
+ raise ValueError (f"Unknown conversion mode: { args .mode } " )
You can’t perform that action at this time.
0 commit comments