1+ import logging
12import os
23import warnings
34import zipfile
5253
5354PassType = Callable [[torch .fx .GraphModule ], Optional [PassResult ]]
5455
56+ log : logging .Logger = logging .getLogger (__name__ )
57+
5558
5659@deprecated (
5760 "`torch.export.export_for_training` is deprecated and will be removed in PyTorch 2.10. "
@@ -440,7 +443,8 @@ def load(
440443 f ,
441444 expected_opset_version = expected_opset_version ,
442445 )
443- except RuntimeError :
446+ except RuntimeError as e :
447+ log .warning ("Ran into the following error when deserializing: %s" , e )
444448 pt2_contents = PT2ArchiveContents ({}, {}, {})
445449
446450 if len (pt2_contents .exported_programs ) > 0 or len (pt2_contents .extra_files ) > 0 :
@@ -450,10 +454,18 @@ def load(
450454 return pt2_contents .exported_programs ["model" ]
451455
452456 # TODO: For backward compatibility, we support loading a zip file from 2.7. Delete this path in 2.9(?)
453- warnings .warn (
454- "This version of file is deprecated. Please generate a new pt2 saved file."
455- )
456457 with zipfile .ZipFile (f , "r" ) as zipf :
458+ if "version" not in zipf .namelist ():
459+ raise RuntimeError (
460+ "We ran into an error when deserializing the saved file. "
461+ "Please check the warnings above for possible errors. "
462+ )
463+
464+ log .warning (
465+ "Trying to deserialize for the older format. This version of file is "
466+ "deprecated. Please generate a new pt2 saved file."
467+ )
468+
457469 # Check the version
458470 version = zipf .read ("version" ).decode ().split ("." )
459471 from torch ._export .serde .schema import (
0 commit comments