Commit b8dcd6a

singlaiiit authored and facebook-github-bot committed
Back out "[const_fold] Set requires_grad based on the folded tensor; add device_for_folding option" (#79655)
Summary:
Pull Request resolved: #79655

Reviewed By: yinghai, khabinov

Differential Revision: D37192230

fbshipit-source-id: 4ae1863e28ff30c23795c4a7a79461dd76484143
1 parent 5be938d commit b8dcd6a
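
For context, here is a minimal usage sketch of the entry point touched by this back-out. The `AddConst` module and its tensors are hypothetical illustrations, not part of the commit; the `const_fold.split_const_subgraphs` call and the removed `device_for_folded_attrs` keyword come from the diff below.

import torch
import torch.fx
from torch.fx.experimental import const_fold

class AddConst(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4))

    def forward(self, x):
        # (self.w + 1.0) depends only on an attribute, so it is a foldable subgraph.
        return x + (self.w + 1.0)

gm = torch.fx.symbolic_trace(AddConst())

# Before this back-out, callers could also pass device_for_folded_attrs="cpu",
# and the folded attribute kept the requires_grad of the folded tensor.
folded = const_fold.split_const_subgraphs(gm)
out = folded(torch.randn(4))  # runs folding on first call, then the folded graph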

2 files changed: 8 additions & 78 deletions


test/fx/test_fx_const_fold.py

Lines changed: 0 additions & 54 deletions
@@ -5,25 +5,10 @@
 import torch
 import torch.fx
 from torch.fx.experimental import const_fold
-from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp
 from torch.testing._internal.common_utils import TestCase


 class TestConstFold(TestCase):
-    def _get_attr(self, node):
-        mod = node.graph.owning_module
-        target = str(node.target)
-        target_atoms = target.split(".")
-        curr_obj = mod
-        for i, atom in enumerate(target_atoms):
-            if not hasattr(curr_obj, atom):
-                raise RuntimeError(
-                    f"Node referenced nonexistent target '{'.'.join(target_atoms[:i])}'; "
-                    f" original whole target: '{target}'"
-                )
-            curr_obj = getattr(curr_obj, atom)
-        return curr_obj
-
     def _verify_const_fold_mod(self, mod_folded: const_fold.FoldedGraphModule):
         self.assertTrue(mod_folded.const_subgraph_module is not None)

@@ -668,42 +653,3 @@ def forward(self, x):
         # Now run both folded and non-folded to check results equal.
         inp = torch.randn(4, 4)
         self.assertTrue(torch.equal(mod_folded(inp), mod(inp)))
-
-    def test_const_fold_tensor_meta(self):
-        self._test_const_fold_tensor_meta(True)
-        self._test_const_fold_tensor_meta(False)
-
-    def _test_const_fold_tensor_meta(self, requires_grad):
-        """
-        Verify tensor_meta is handled correctly.
-        """
-
-        class ConstFoldTestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]]), requires_grad)
-                self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]]), requires_grad)
-
-            def forward(self, x, y):
-                a = self.attr_1 + self.attr_1
-                x = x - a
-                return x * y + self.attr_2
-
-        mod = ConstFoldTestModule()
-        gm = torch.fx.symbolic_trace(mod)
-        in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9])
-        ShapeProp(gm).propagate(in_x, in_y)
-        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm)
-        self._verify_const_fold_mod(mod_folded)
-
-        mod_folded.run_folding()
-
-        for n in mod_folded.graph.nodes:
-            if n.op == "get_attr":
-                attr = self._get_attr(n)
-                self.assertEquals(_extract_tensor_metadata(attr), n.meta["tensor_meta"])
-
-        # Now run both folded and non-folded to check results equal.
-        base_result = mod(in_x, in_y)
-        fold_result = mod_folded(in_x, in_y)
-        self.assertTrue(torch.equal(fold_result, base_result))

torch/fx/experimental/const_fold.py

Lines changed: 8 additions & 24 deletions
@@ -1,5 +1,5 @@
 import re
-from typing import Callable, Dict, Optional, Set, Union
+from typing import Callable, Dict, Set, Optional, Union

 import torch.fx
 from torch.fx.node import map_arg
@@ -22,7 +22,6 @@ def __init__(
         graph: torch.fx.Graph,
         const_subgraph: Optional[torch.fx.Graph] = None,
         fx_const_folded_attrs_name: str = None,
-        device_for_folded_attrs: str = "cpu",
     ):
         # In init, we set graph's owning module to root which will make graph's
         # owning module be None because graph already have a owning module. We
@@ -37,7 +36,6 @@ def __init__(
         )
         self.has_folding_been_run = False
         self.fx_const_folded_attrs_name = fx_const_folded_attrs_name
-        self.device_for_folded_attrs = device_for_folded_attrs

     def __call__(self, *args, **kwargs):
         if not self.has_folding_been_run:
@@ -60,19 +58,12 @@ def run_folding(self):
         # subgraphs output a single Tensor while multiple outputs are returned as
         # Tuple[Tensor,].
         folded_attrs = self.const_subgraph_module()
-
-        def _create_param(i):
-            return torch.nn.Parameter(
-                i
-                if not isinstance(i, int)
-                else torch.Tensor([i]).to(device=self.device_for_folded_attrs),
-                requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False,
-            )
-
         params = (
-            torch.nn.ParameterList([_create_param(i) for i in folded_attrs])
+            torch.nn.ParameterList([torch.nn.Parameter(
+                i if not isinstance(i, int) else torch.Tensor([i]).cuda()) for i in folded_attrs])
             if isinstance(folded_attrs, tuple)
-            else _create_param(folded_attrs)
+            else torch.nn.Parameter(
+                folded_attrs if not isinstance(folded_attrs, int) else torch.Tensor([folded_attrs]).cuda())
         )
         setattr(self, self.fx_const_folded_attrs_name, params)

@@ -144,8 +135,7 @@ def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str)

 def split_const_subgraphs(
     module: Union[torch.nn.Module, torch.fx.GraphModule],
-    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
-    device_for_folded_attrs: str = "cpu",
+    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None
 ) -> FoldedGraphModule:
     """
     Looks through `module` for any nodes that have all constant attribute inputs
@@ -171,9 +161,7 @@ def split_const_subgraphs(

         # If the node itself is constant, or all of its inputs are constant,
         # then tag it as constant.
-        if node.op != "get_attr" and not set(node.all_input_nodes).issubset(
-            const_nodes
-        ):
+        if node.op != "get_attr" and not set(node.all_input_nodes).issubset(const_nodes):
             continue

         # If provided skip folding function says to skip, then skip.
@@ -280,9 +268,5 @@ def mod_partition(node: torch.fx.Node):
     _inline_module(split, non_const_mod_name)

     return FoldedGraphModule(
-        split,
-        split.graph,
-        root_const_gm.graph,
-        fx_const_folded_attrs_name,
-        device_for_folded_attrs,
+        split, split.graph, root_const_gm.graph, fx_const_folded_attrs_name
     )
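
For readers comparing the two versions of run_folding, here is a standalone sketch of the wrapping behavior this back-out restores. The helper name wrap_folded_attr is hypothetical, introduced only for illustration.

import torch

def wrap_folded_attr(value):
    # Restored behavior: an int output of the constant subgraph is materialized
    # as a tensor on CUDA, and every folded value becomes a Parameter with the
    # default requires_grad=True. The backed-out version instead honored
    # device_for_folded_attrs and copied requires_grad from the folded tensor.
    if isinstance(value, int):
        value = torch.Tensor([value]).cuda()  # needs a CUDA-capable device
    return torch.nn.Parameter(value)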
