File tree 3 files changed +2
-14
lines changed 3 files changed +2
-14
lines changed Original file line number Diff line number Diff line change 73
73
from torch .fx .graph import _PyTreeCodeGen , _PyTreeInfo
74
74
from torch .utils ._sympy .value_ranges import ValueRangeError , ValueRanges
75
75
76
- from .exported_program import (
77
- CallSpec ,
78
- )
79
76
from .passes .add_runtime_assertions_for_constraints_pass import (
80
77
_AddRuntimeAssertionsForInlineConstraintsPass ,
81
78
)
Original file line number Diff line number Diff line change 1
- import dataclasses
2
- from typing import Optional
3
1
import warnings
4
2
5
3
6
4
import torch
7
5
import torch .fx
8
- import torch .utils ._pytree as pytree
9
6
10
7
11
8
# TODO(ycao): This is added to avoid breaking existing code temporarily.
32
29
]
33
30
34
31
35
- # Information to maintain user calling/returning specs
36
- @dataclasses .dataclass
37
- class CallSpec :
38
- in_spec : Optional [pytree .TreeSpec ]
39
- out_spec : Optional [pytree .TreeSpec ]
40
-
41
-
42
32
def _create_graph_module_for_export (root , graph ):
43
33
try :
44
34
gm = torch .fx .GraphModule (root , graph )
Original file line number Diff line number Diff line change 2
2
import dataclasses
3
3
import functools
4
4
import types
5
+ from collections import namedtuple
5
6
from typing import (
6
7
Any ,
7
8
Callable ,
@@ -213,7 +214,7 @@ def example_inputs(self):
213
214
@property
214
215
@compatibility (is_backward_compatible = False )
215
216
def call_spec (self ):
216
- from torch . _export . exported_program import CallSpec
217
+ CallSpec = namedtuple ( " CallSpec" , [ "in_spec" , "out_spec" ])
217
218
218
219
if len (self .module_call_graph ) == 0 :
219
220
return CallSpec (in_spec = None , out_spec = None )
You can’t perform that action at this time.
0 commit comments