Back out "[const_fold] Set requires_grad based on the folded tensor; add device_for_folding option" by singlaiiit · Pull Request #79655 · pytorch/pytorch · GitHub

Back out "[const_fold] Set requires_grad based on the folded tensor; add device_for_folding option" #79655

Closed
wants to merge 1 commit
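For context, the reverted commit had changed how folded constants are materialized; this backout restores the previous behavior. Below is a minimal usage sketch of the entry point the diff touches (the module, tensor shapes, and names are illustrative only, not taken from this PR); the commented-out `device_for_folded_attrs` keyword is the option this backout removes.

```python
import torch
import torch.fx
from torch.fx.experimental import const_fold


class AddConst(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        # `self.w + 1` depends only on constants, so it is foldable.
        return x + (self.w + 1)


gm = torch.fx.symbolic_trace(AddConst())

# After this backout: no device argument; int outputs of the folded
# subgraph are materialized with `.cuda()` inside run_folding().
folded = const_fold.split_const_subgraphs(gm)

# Before this backout (the change being reverted), the folded attributes
# could be placed on an explicit device, defaulting to CPU:
# folded = const_fold.split_const_subgraphs(gm, device_for_folded_attrs="cpu")

out = folded(torch.randn(4, 4))  # folding runs lazily on the first call
```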
54 changes: 0 additions & 54 deletions test/fx/test_fx_const_fold.py
@@ -5,25 +5,10 @@
import torch
import torch.fx
from torch.fx.experimental import const_fold
from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp
from torch.testing._internal.common_utils import TestCase


class TestConstFold(TestCase):
def _get_attr(self, node):
mod = node.graph.owning_module
target = str(node.target)
target_atoms = target.split(".")
curr_obj = mod
for i, atom in enumerate(target_atoms):
if not hasattr(curr_obj, atom):
raise RuntimeError(
f"Node referenced nonexistent target '{'.'.join(target_atoms[:i])}'; "
f" original whole target: '{target}'"
)
curr_obj = getattr(curr_obj, atom)
return curr_obj

def _verify_const_fold_mod(self, mod_folded: const_fold.FoldedGraphModule):
self.assertTrue(mod_folded.const_subgraph_module is not None)

@@ -668,42 +653,3 @@ def forward(self, x):
# Now run both folded and non-folded to check results equal.
inp = torch.randn(4, 4)
self.assertTrue(torch.equal(mod_folded(inp), mod(inp)))

def test_const_fold_tensor_meta(self):
self._test_const_fold_tensor_meta(True)
self._test_const_fold_tensor_meta(False)

def _test_const_fold_tensor_meta(self, requires_grad):
"""
Verify tensor_meta is handled correctly.
"""

class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]]), requires_grad)
self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]]), requires_grad)

def forward(self, x, y):
a = self.attr_1 + self.attr_1
x = x - a
return x * y + self.attr_2

mod = ConstFoldTestModule()
gm = torch.fx.symbolic_trace(mod)
in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9])
ShapeProp(gm).propagate(in_x, in_y)
mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm)
self._verify_const_fold_mod(mod_folded)

mod_folded.run_folding()

for n in mod_folded.graph.nodes:
if n.op == "get_attr":
attr = self._get_attr(n)
self.assertEquals(_extract_tensor_metadata(attr), n.meta["tensor_meta"])

# Now run both folded and non-folded to check results equal.
base_result = mod(in_x, in_y)
fold_result = mod_folded(in_x, in_y)
self.assertTrue(torch.equal(fold_result, base_result))
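The deleted test above relied on FX shape propagation to populate `node.meta["tensor_meta"]` before checking it on the folded module. As a reminder of that mechanism on its own, here is a small standalone sketch (the module and input shapes are made up for illustration):

```python
import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp


class Small(torch.nn.Module):
    def forward(self, x, y):
        return x * y + 1.0


gm = torch.fx.symbolic_trace(Small())
ShapeProp(gm).propagate(torch.randn(2, 3), torch.randn(2, 3))

for node in gm.graph.nodes:
    # After propagation, tensor-producing nodes carry a TensorMetadata
    # record (shape, dtype, requires_grad, stride, ...) in node.meta.
    print(node.op, node.target, node.meta.get("tensor_meta"))
```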
32 changes: 8 additions & 24 deletions torch/fx/experimental/const_fold.py
@@ -1,5 +1,5 @@
import re
-from typing import Callable, Dict, Optional, Set, Union
+from typing import Callable, Dict, Set, Optional, Union

import torch.fx
from torch.fx.node import map_arg
@@ -22,7 +22,6 @@ def __init__(
graph: torch.fx.Graph,
const_subgraph: Optional[torch.fx.Graph] = None,
fx_const_folded_attrs_name: str = None,
-        device_for_folded_attrs: str = "cpu",
):
# In init, we set graph's owning module to root which will make graph's
# owning module be None because graph already have a owning module. We
@@ -37,7 +36,6 @@ def __init__(
)
self.has_folding_been_run = False
self.fx_const_folded_attrs_name = fx_const_folded_attrs_name
-        self.device_for_folded_attrs = device_for_folded_attrs

def __call__(self, *args, **kwargs):
if not self.has_folding_been_run:
@@ -60,19 +58,12 @@ def run_folding(self):
# subgraphs output a single Tensor while multiple outputs are returned as
# Tuple[Tensor,].
folded_attrs = self.const_subgraph_module()

-        def _create_param(i):
-            return torch.nn.Parameter(
-                i
-                if not isinstance(i, int)
-                else torch.Tensor([i]).to(device=self.device_for_folded_attrs),
-                requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False,
-            )
-
        params = (
-            torch.nn.ParameterList([_create_param(i) for i in folded_attrs])
+            torch.nn.ParameterList([torch.nn.Parameter(
+                i if not isinstance(i, int) else torch.Tensor([i]).cuda()) for i in folded_attrs])
            if isinstance(folded_attrs, tuple)
-            else _create_param(folded_attrs)
+            else torch.nn.Parameter(
+                folded_attrs if not isinstance(folded_attrs, int) else torch.Tensor([folded_attrs]).cuda())
        )
setattr(self, self.fx_const_folded_attrs_name, params)

@@ -144,8 +135,7 @@ def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str)

def split_const_subgraphs(
module: Union[torch.nn.Module, torch.fx.GraphModule],
-    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
-    device_for_folded_attrs: str = "cpu",
+    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None
) -> FoldedGraphModule:
"""
Looks through `module` for any nodes that have all constant attribute inputs
@@ -171,9 +161,7 @@ def split_const_subgraphs(

# If the node itself is constant, or all of its inputs are constant,
# then tag it as constant.
-        if node.op != "get_attr" and not set(node.all_input_nodes).issubset(
-            const_nodes
-        ):
+        if node.op != "get_attr" and not set(node.all_input_nodes).issubset(const_nodes):
continue

# If provided skip folding function says to skip, then skip.
@@ -280,9 +268,5 @@ def mod_partition(node: torch.fx.Node):
_inline_module(split, non_const_mod_name)

return FoldedGraphModule(
-        split,
-        split.graph,
-        root_const_gm.graph,
-        fx_const_folded_attrs_name,
-        device_for_folded_attrs,
+        split, split.graph, root_const_gm.graph, fx_const_folded_attrs_name
)
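The functional difference restored here is easiest to see in how folded values become `nn.Parameter`s: the reverted change mirrored the folded tensor's `requires_grad` flag and honored a configurable device, while the restored code uses the `Parameter` default (`requires_grad=True`) and `.cuda()` for int outputs. A simplified sketch of just the `requires_grad` part, using a stand-in tensor rather than a real folded subgraph output:

```python
import torch

# Stand-in for a value returned by the folded constant subgraph.
folded = torch.randn(2, 2)  # requires_grad is False by default

# Behavior restored by this backout: Parameter defaults to requires_grad=True,
# regardless of the folded tensor's flag.
p_restored = torch.nn.Parameter(folded)
assert p_restored.requires_grad

# Behavior being backed out: mirror the folded tensor's requires_grad flag.
p_reverted = torch.nn.Parameter(folded, requires_grad=folded.requires_grad)
assert not p_reverted.requires_grad
```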