File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed
torch/distributed/checkpoint Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -240,7 +240,6 @@ def dcp_to_torch_save(
240
240
To avoid OOM, it's recommended to only run this function on a single rank.
241
241
"""
242
242
sd : STATE_DICT_TYPE = {}
243
-
244
243
_load_state_dict (
245
244
sd ,
246
245
storage_reader = FileSystemReader (dcp_checkpoint_dir ),
@@ -298,13 +297,15 @@ class FormatMode(Enum):
298
297
checkpoint_missing_warning = (
299
298
f"No checkpoint found at { args .src } . Skipping conversion."
300
299
)
301
- if args .mode == FormatMode .TORCH_TO_DCP :
300
+ if args .mode == FormatMode .TORCH_TO_DCP . value :
302
301
if os .path .isfile (args .src ):
303
302
torch_save_to_dcp (args .src , args .dst )
304
303
else :
305
304
print (checkpoint_missing_warning )
306
- elif args .mode == FormatMode .DCP_TO_TORCH :
305
+ elif args .mode == FormatMode .DCP_TO_TORCH . value :
307
306
if os .path .isdir (args .src ):
308
307
dcp_to_torch_save (args .src , args .dst )
309
308
else :
310
309
print (checkpoint_missing_warning )
310
+ else :
311
+ raise ValueError (f"Unknown conversion mode: { args .mode } " )
You can’t perform that action at this time.
0 commit comments