8000 [export] Make draft_export public (#153219) · pytorch/pytorch@d51bc27 · GitHub
[go: up one dir, main page]

Skip to content

Commit d51bc27

Browse files
angelayipytorchmergebot
authored andcommitted
[export] Make draft_export public (#153219)
Fixes #ISSUE_NUMBER Pull Request resolved: #153219 Approved by: https://github.com/pianpwk
1 parent b15b870 commit d51bc27

File tree

5 files changed

+50
-22
lines changed

5 files changed

+50
-22
lines changed

docs/source/draft_export.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,7 @@ To call ``draft-export``, we can replace the ``torch.export`` line with the foll
105105

106106
::
107107

108-
from torch.export._draft_export import draft_export
109-
ep = draft_export(M(), inp)
108+
ep = torch.export.draft_export(M(), inp)
110109

111110
``ep`` is a valid ExportedProgram which can now be passed through further environments!
112111

docs/source/export.rst

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -790,10 +790,9 @@ API Reference
790790
.. autofunction:: export
791791
.. autofunction:: save
792792
.. autofunction:: load
793+
.. autofunction:: draft_export
793794
.. autofunction:: register_dataclass
794795
.. autoclass:: torch.export.dynamic_shapes.Dim
795-
.. autofunction:: torch.export.exported_program.default_decompositions
796-
.. autofunction:: dims
797796
.. autoclass:: torch.export.dynamic_shapes.ShapesCollection
798797

799798
.. automethod:: dynamic_shapes
@@ -805,22 +804,21 @@ API Reference
805804
.. automethod:: verify
806805

807806
.. autofunction:: torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes
808-
.. autoclass:: Constraint
809807
.. autoclass:: ExportedProgram
810808

809+
.. attribute:: graph
810+
.. attribute:: graph_signature
811+
.. attribute:: state_dict
812+
.. attribute:: constants
813+
.. attribute:: range_constraints
814+
.. attribute:: module_call_graph
815+
.. attribute:: example_inputs
811816
.. automethod:: module
812-
.. automethod:: buffers
813-
.. automethod:: named_buffers
814-
.. automethod:: parameters
815-
.. automethod:: named_parameters
816817
.. automethod:: run_decompositions
817818

818-
.. autoclass:: ExportBackwardSignature
819819
.. autoclass:: ExportGraphSignature
820820
.. autoclass:: ModuleCallSignature
821821
.. autoclass:: ModuleCallEntry
822-
823-
824822
.. automodule:: torch.export.decomp_utils
825823
.. autoclass:: CustomDecompTable
826824

@@ -830,22 +828,25 @@ API Reference
830828
.. automethod:: materialize
831829
.. automethod:: pop
832830
.. automethod:: update
831+
.. autofunction:: torch.export.exported_program.default_decompositions
833832

834833
.. automodule:: torch.export.exported_program
835834
.. automodule:: torch.export.graph_signature
835+
.. autoclass:: ExportGraphSignature
836+
837+
.. automethod:: replace_all_uses
838+
.. automethod:: get_replace_hook
839+
840+
.. autoclass:: ExportBackwardSignature
836841
.. autoclass:: InputKind
837842
.. autoclass:: InputSpec
838843
.. autoclass:: OutputKind
839844
.. autoclass:: OutputSpec
840845
.. autoclass:: SymIntArgument
841846
.. autoclass:: SymBoolArgument
842847
.. autoclass:: SymFloatArgument
843-
.. autoclass:: ExportGraphSignature
844-
845-
.. automethod:: replace_all_uses
846-
.. automethod:: get_replace_hook
847848

848-
.. autoclass:: torch.export.graph_signature.CustomObjArgument
849+
.. autoclass:: CustomObjArgument
849850

850851
.. py:module:: torch.export.dynamic_shapes
851852
.. py:module:: torch.export.custom_ops

test/export/test_draft_export.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
import torch
77
from torch._subclasses.fake_tensor import FakeTensorMode
8-
from torch.export import Dim, export
9-
from torch.export._draft_export import draft_export, FailureType
8+
from torch.export import Dim, draft_export, export
9+
from torch.export._draft_export import FailureType
1010
from torch.fx.experimental.symbolic_shapes import ShapeEnv
1111
from torch.testing import FileCheck
1212
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase

torch/export/__init__.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
"FlatArgsAdapter",
5353
"UnflattenedModule",
5454
"AdditionalInputs",
55+
"draft_export",
5556
]
5657

5758
# To make sure export specific custom ops are loaded
@@ -518,6 +519,32 @@ def load(
518519
return ep
519520

520521

522+
def draft_export(
523+
mod: torch.nn.Module,
524+
args: tuple[Any, ...],
525+
kwargs: Optional[dict[str, Any]] = None,
526+
*,
527+
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
528+
preserve_module_call_signature: tuple[str, ...] = (),
529+
strict: bool = False,
530+
) -> ExportedProgram:
531+
"""
532+
A version of torch.export.export which is designed to consistently produce
533+
an ExportedProgram, even if there are potential soundness issues, and to
534+
generate a report listing the issues found.
535+
"""
536+
from ._draft_export import draft_export
537+
538+
return draft_export(
539+
mod=mod,
540+
args=args,
541+
kwargs=kwargs,
542+
dynamic_shapes=dynamic_shapes,
543+
preserve_module_call_signature=preserve_module_call_signature,
544+
strict=strict,
545+
)
546+
547+
521548
def register_dataclass(
522549
cls: type[Any],
523550
*,

torch/export/_draft_export.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
insert_custom_op_guards,
1818
OpProfile,
1919
)
20-
from torch.export import ExportedProgram
21-
from torch.export._trace import _export
22-
from torch.export.dynamic_shapes import _DimHint, _DimHintType, Dim
20+
21+
from ._trace import _export
22+
from .dynamic_shapes import _DimHint, _DimHintType, Dim
23+
from .exported_program import ExportedProgram
2324

2425

2526
log = logging.getLogger(__name__)

0 commit comments

Comments
 (0)
0