@@ -1,5 +1,5 @@
 import re
-from typing import Callable, Dict, Optional, Set, Union
+from typing import Callable, Dict, Set, Optional, Union
 
 import torch.fx
 from torch.fx.node import map_arg
@@ -22,7 +22,6 @@ def __init__(
         graph: torch.fx.Graph,
         const_subgraph: Optional[torch.fx.Graph] = None,
         fx_const_folded_attrs_name: str = None,
-        device_for_folded_attrs: str = "cpu",
     ):
         # In init, we set graph's owning module to root which will make graph's
         # owning module be None because graph already have a owning module. We
@@ -37,7 +36,6 @@ def __init__(
         )
         self.has_folding_been_run = False
         self.fx_const_folded_attrs_name = fx_const_folded_attrs_name
-        self.device_for_folded_attrs = device_for_folded_attrs
 
     def __call__(self, *args, **kwargs):
         if not self.has_folding_been_run:
@@ -60,19 +58,12 @@ def run_folding(self):
         # subgraphs output a single Tensor while multiple outputs are returned as
         # Tuple[Tensor,].
         folded_attrs = self.const_subgraph_module()
-
-        def _create_param(i):
-            return torch.nn.Parameter(
-                i
-                if not isinstance(i, int)
-                else torch.Tensor([i]).to(device=self.device_for_folded_attrs),
-                requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False,
-            )
-
         params = (
-            torch.nn.ParameterList([_create_param(i) for i in folded_attrs])
+            torch.nn.ParameterList([torch.nn.Parameter(
+                i if not isinstance(i, int) else torch.Tensor([i]).cuda()) for i in folded_attrs])
             if isinstance(folded_attrs, tuple)
-            else _create_param(folded_attrs)
+            else torch.nn.Parameter(
+                folded_attrs if not isinstance(folded_attrs, int) else torch.Tensor([folded_attrs]).cuda())
         )
         setattr(self, self.fx_const_folded_attrs_name, params)
 
@@ -144,8 +135,7 @@ def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str)
 
 def split_const_subgraphs(
     module: Union[torch.nn.Module, torch.fx.GraphModule],
-    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
-    device_for_folded_attrs: str = "cpu",
+    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None
 ) -> FoldedGraphModule:
     """
     Looks through `module` for any nodes that have all constant attribute inputs
@@ -171,9 +161,7 @@ def split_const_subgraphs(
 
         # If the node itself is constant, or all of its inputs are constant,
         # then tag it as constant.
-        if node.op != "get_attr" and not set(node.all_input_nodes).issubset(
-            const_nodes
-        ):
+        if node.op != "get_attr" and not set(node.all_input_nodes).issubset(const_nodes):
             continue
 
         # If provided skip folding function says to skip, then skip.
@@ -280,9 +268,5 @@ def mod_partition(node: torch.fx.Node):
     _inline_module(split, non_const_mod_name)
 
     return FoldedGraphModule(
-        split,
-        split.graph,
-        root_const_gm.graph,
-        fx_const_folded_attrs_name,
-        device_for_folded_attrs,
+        split, split.graph, root_const_gm.graph, fx_const_folded_attrs_name
     )
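For context, a minimal usage sketch of the API touched by this diff, assuming the file is torch/fx/experimental/const_fold.py; the AddFoldedConst module and its names are illustrative, not part of the change. It shows how split_const_subgraphs folds a constant-only subexpression into attributes on the returned FoldedGraphModule; after this change, int-valued folded constants are materialized with .cuda() instead of an explicit device_for_folded_attrs.

# Usage sketch (assumptions: import path below; example module is hypothetical).
import torch
from torch.fx.experimental.const_fold import split_const_subgraphs


class AddFoldedConst(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4))

    def forward(self, x):
        # `self.w + 1` depends only on an attribute, so it is a foldable
        # constant subexpression; `x + ...` stays in the non-constant graph.
        return x + (self.w + 1)


folded = split_const_subgraphs(AddFoldedConst())  # returns a FoldedGraphModule
out = folded(torch.randn(4))  # first call runs run_folding(), then the main graph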