[hop_schema] support gen_schema for invoke_subgraph by ydwu4 · Pull Request #152984 · pytorch/pytorch · GitHub

[hop_schema] support gen_schema for invoke_subgraph #152984


Closed
wants to merge 2 commits
131 changes: 131 additions & 0 deletions test/higher_order_ops/test_invoke_subgraph.py
@@ -3,6 +3,7 @@
# flake8: noqa: E731

import unittest
import unittest.mock as mock

from parameterized import parameterized_class

@@ -14,10 +15,12 @@
from functorch.compile import aot_function, nop
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
EagerAndRecordGraphs,
InductorAndRecordGraphs,
normalize_gm,
)
from torch._higher_order_ops.invoke_subgraph import mark_compile_region
from torch._higher_order_ops.schema import find_hop_schema
from torch.testing._internal.common_utils import (
run_tests,
skipIfTorchDynamo,
@@ -167,6 +170,134 @@ def fn(x, y):
self.assertEqual(x.grad, x_clone.grad)
self.assertEqual(y.grad, y_clone.grad)

def test_gen_schema(self):
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.c = 5

@mark_compile_region
def forward(self, x, y):
return torch.mul(x, y).sin() + self.c

mod = Mod()

def fn(x, y):
return mod(x, y) + mod(x, y)

x = torch.randn(8, requires_grad=True)
y = torch.randn(8, requires_grad=True)

x_clone = x.detach().clone().requires_grad_(True)
y_clone = y.detach().clone().requires_grad_(True)
backend = AotEagerAndRecordGraphs()
res = torch.compile(fn, backend=backend, fullgraph=True)(x_clone, y_clone)
res.sum().backward()

self.assertEqual(len(backend.fw_graphs), 1)
self.assertEqual(len(backend.bw_graphs), 1)
fw_schema = find_hop_schema(
backend.fw_graphs[0], torch.ops.higher_order.invoke_subgraph
)
bw_schema = find_hop_schema(
backend.bw_graphs[0], torch.ops.higher_order.invoke_subgraph
)
self.assertExpectedInline(
str(fw_schema[0]),
"""invoke_subgraph(Any subgraph, str identifier, Tensor arg0, Tensor arg1) -> (Tensor, Tensor, Tensor)""",
)
self.assertExpectedInline(
str(fw_schema[1]),
"""invoke_subgraph(Any subgraph, str identifier, Tensor arg0, Tensor arg1) -> (Tensor, Tensor, Tensor)""",
)
self.assertExpectedInline(
str(bw_schema[0]),
"""invoke_subgraph(Any subgraph, str identifier, Tensor arg0, Tensor arg1, Tensor arg2) -> (Tensor, Tensor)""",
)
self.assertExpectedInline(
str(bw_schema[1]),
"""invoke_subgraph(Any subgraph, str identifier, Tensor arg0, Tensor arg1, Tensor arg2) -> (Tensor, Tensor)""",
)

def test_gen_schema_with_buffer_mutation(self):
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.c = 5
self.register_buffer("buf", torch.ones(8, requires_grad=False))

@mark_compile_region
def forward(self, x, y):
self.buf.add_(1)
return torch.mul(x, y).sin() + self.c + self.buf

mod_ref = Mod()
mod = Mod()

def fn(mod, x, y):
return mod(x, y) + mod(x, y)

x = torch.randn(8, requires_grad=True)
y = torch.randn(8, requires_grad=True)
ref = fn(mod_ref, x, y)

x_clone = x.detach().clone().requires_grad_(True)
y_clone = y.detach().clone().requires_grad_(True)
backend = EagerAndRecordGraphs()
with mock.patch(
"torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation",
True,
):
res = torch.compile(fn, backend=backend, fullgraph=True)(
mod, x_clone, y_clone
)

self.assertEqual(len(backend.graphs), 1)
fw_schema = find_hop_schema(
backend.graphs[0], torch.ops.higher_order.invoke_subgraph
)
if not TEST_WITH_CROSSREF:
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[8]", L_y_: "f32[8]", L_mod_buffers_buf_: "f32[8]"):
l_x_ = L_x_
l_y_ = L_y_
l_mod_buffers_buf_ = L_mod_buffers_buf_

subgraph_0 = self.subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_mod_buffers_buf_, l_x_, l_y_); subgraph_0 = None
getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
subgraph_1 = self.subgraph_0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', l_mod_buffers_buf_, l_x_, l_y_); subgraph_1 = l_mod_buffers_buf_ = l_x_ = l_y_ = None
getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None

add: "f32[8]" = getitem + getitem_1; getitem = getitem_1 = None
return (add,)

class subgraph_0(torch.nn.Module):
def forward(self, l_mod_buffers_buf_: "f32[8]", l_x_: "f32[8]", l_y_: "f32[8]"):
add_: "f32[8]" = l_mod_buffers_buf_.add_(1); add_ = None

mul: "f32[8]" = torch.mul(l_x_, l_y_); l_x_ = l_y_ = None
sin: "f32[8]" = mul.sin(); mul = None
add: "f32[8]" = sin + 5; sin = None
add_1: "f32[8]" = add + l_mod_buffers_buf_; add = l_mod_buffers_buf_ = None
return (add_1,)
""",
)
self.assertExpectedInline(
str(fw_schema[0]),
"""invoke_subgraph(Any subgraph, str identifier, Tensor(a2!) arg0, Tensor arg1, Tensor arg2) -> ((Tensor))""",
)
self.assertExpectedInline(
str(fw_schema[1]),
"""invoke_subgraph(Any subgraph, str identifier, Tensor(a2!) arg0, Tensor arg1, Tensor arg2) -> ((Tensor))""",
)
self.assertEqual(res, ref)
self.assertEqual(mod.buf, mod_ref.buf)

def test_list(self):
@mark_compile_region
def gn(x, y):
23 changes: 23 additions & 0 deletions torch/_higher_order_ops/invoke_subgraph.py
@@ -77,6 +77,29 @@ def __call__(
return super().__call__(subgraph, identifier, *operands)

def gen_schema(self, subgraph, identifier, *operands):
from torch._higher_order_ops.schema import HopSchemaGenerator
from torch._higher_order_ops.utils import (
check_input_alias_and_mutation_return_ouputs,
)

schema_gen = HopSchemaGenerator(self)
schema_gen.add_arg("subgraph", subgraph)
schema_gen.add_arg("identifier", identifier)
example_inputs = [
n.meta["val"] if "val" in n.meta else n.meta["example_value"]
for n in subgraph.graph.find_nodes(op="placeholder")
]
_, _, _, mutated_inputs, outputs = check_input_alias_and_mutation_return_ouputs(
subgraph, example_inputs
)
for idx, arg in enumerate(operands):
schema_gen.add_arg(f"arg{idx}", arg, is_mutated=idx in mutated_inputs)
for out in outputs:
schema_gen.add_output(out)

return schema_gen.gen_schema()


invoke_subgraph = InvokeSubgraphHOP()

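For reference, a minimal sketch of how the new gen_schema hook might be exercised directly. This is an assumption for illustration only: the tests in this PR go through torch.compile, and the subgraph here is built with make_fx in fake tracing mode simply so its placeholder nodes carry the "val" metadata that gen_schema reads; "subgraph_0" is a hypothetical identifier.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._higher_order_ops.invoke_subgraph import invoke_subgraph

def body(x, y):
    # No input is mutated, so no (a!) alias annotations are expected in the schema.
    return (torch.mul(x, y).sin(),)

x = torch.randn(8)
y = torch.randn(8)

# Trace the body so its placeholder nodes carry FakeTensor "val" metadata,
# which gen_schema uses to build example inputs for the mutation analysis.
subgraph = make_fx(body, tracing_mode="fake")(x, y)

schema = invoke_subgraph.gen_schema(subgraph, "subgraph_0", x, y)
print(schema)
# Expected to render along the lines of (exact string is an assumption):
#   invoke_subgraph(Any subgraph, str identifier, Tensor arg0, Tensor arg1) -> ...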
6 changes: 5 additions & 1 deletion torch/_higher_order_ops/schema.py
@@ -222,7 +222,11 @@ def _get_example_value(node: torch.fx.Node) -> Any:
assert isinstance(node.target, str)
return getattr(gm, node.target)
else:
return node.meta["example_value"]
return (
node.meta["example_value"]
if "example_value" in node.meta
else node.meta["val"]
)

fake_args, fake_kwargs = pytree.tree_map_only(
torch.fx.Node,
2 changes: 1 addition & 1 deletion torch/_higher_order_ops/utils.py
@@ -758,7 +758,7 @@ def check_input_alias_and_mutation(

def check_input_alias_and_mutation_return_ouputs(
gm: torch.fx.GraphModule,
fake_args: list[FakeTensor],
fake_args: Union[list[FakeTensor], tuple[FakeTensor, ...]],
) -> tuple[
dict[int, int],
dict[int, int],