[hop_schema] add HopSchemaGenerator to make it easier to create hop schema by ydwu4 · Pull Request #152974 · pytorch/pytorch · GitHub

Closed · wants to merge 2 commits
Changes from all commits
51 changes: 13 additions & 38 deletions torch/_higher_order_ops/base_hop.py
@@ -153,15 +153,11 @@ def _call_Functionalize(self, ctx, subgraph, *operands, **kwargs):
return ctx.wrap_tensors(out)

def gen_schema(self, subgraph, *operands, **kwargs):
from .schema import CFunctionSchemaGen, HopArgumentInfoGen
from .schema import HopSchemaGenerator

if not isinstance(subgraph, torch.fx.GraphModule):
subgraph = materialize_as_graph(subgraph, operands)

assert isinstance(
subgraph, torch.fx.GraphModule
), f"NYI non GraphModule subgraph got {subgraph}"

fake_args = [
ph.meta["example_value"] if "example_value" in ph.meta else ph.meta["val"]
for ph in subgraph.graph.find_nodes(op="placeholder")
@@ -189,40 +185,19 @@ def gen_schema(self, subgraph, *operands, **kwargs):
f"Alias info: inp-inp alias: {inp_inp_alias}, inp-out alias: {inp_out_alias}, out-out alias{out_out_alias}"
f"This may lead to silent incorrectness."
)
args = [
HopArgumentInfoGen.from_example(
subgraph, name="subgraph", default_value=None, is_mutated=False
)
]
for idx, arg in enumerate((*operands, *kwargs.items())):
if isinstance(arg, tuple):
# kwargs value are treated as default argument
arg_name, example_value = arg
default = example_value
kw_only = True
else:
arg_name = f"arg{idx}"
example_value = arg
default = None
kw_only = False
args.append(
HopArgumentInfoGen.from_example(
example_value=example_value,
name=arg_name,
default_value=default,
is_mutated=idx in mutated_inp_idx,
kw_only=kw_only,
)
)

# The output is represented as a single argument
out = HopArgumentInfoGen.from_example(
example_value=output,
name="out",
default_value=None,
is_mutated=False,
)
return CFunctionSchemaGen.from_hop_argument_info(str(self), args, out)
schema_gen = HopSchemaGenerator(self)
schema_gen.add_arg("subgraph", subgraph)
for idx, arg in enumerate(operands):
schema_gen.add_arg(f"arg{idx}", arg, is_mutated=idx in mutated_inp_idx)

for name, arg in kwargs.items():
schema_gen.add_arg(name, arg, default_value=arg, kw_only=True)

for out in output:
schema_gen.add_output(out)

return schema_gen.gen_schema()


class BaseHOPFunction(torch.autograd.Function):
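For illustration, here is a minimal sketch of driving the new HopSchemaGenerator path by hand. This is not code from the PR: the body function, the tensors, and the use of torch.ops.higher_order.cond as a stand-in HigherOrderOperator instance are assumptions for the example, and the import assumes the class added to torch/_higher_order_ops/schema.py below.

# Illustrative sketch only; not part of this PR.
import torch
from torch._higher_order_ops.schema import HopSchemaGenerator  # class added by this PR

# A tiny subgraph traced into a GraphModule, standing in for the HOP's `subgraph` arg.
def body(x, y):
    return x + y

gm = torch.fx.symbolic_trace(body)
x, y = torch.randn(3), torch.randn(3)

# In base_hop.py the generator is constructed with `self` (the HOP being traced);
# torch.ops.higher_order.cond is only a stand-in HigherOrderOperator instance here.
schema_gen = HopSchemaGenerator(torch.ops.higher_order.cond)
schema_gen.add_arg("subgraph", gm)
schema_gen.add_arg("arg0", x)
schema_gen.add_arg("arg1", y)
schema_gen.add_output(gm(x, y))

schema = schema_gen.gen_schema()  # torch._C.FunctionSchema
print(schema)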
52 changes: 48 additions & 4 deletions torch/_higher_order_ops/schema.py
@@ -2,6 +2,7 @@
from typing import Any, Optional

import torch
import torch.utils._pytree as pytree
from torch.fx.node import Target


@@ -28,12 +29,15 @@ def from_example(
example_value: Any,
*,
name: str = "",
default_value: Optional[Any],
default_value: Optional[Any] = None,
is_mutated: bool = False,
kw_only: bool = False,
) -> HopArgumentInfo:
if default_value is not None:
assert type(example_value) == type(default_value)
assert type(example_value) == type(
default_value
), f"example_value type {type(example_value)} doesn't match default_value type: {type(default_value)}"

return HopArgumentInfo(
name=name,
example_value=example_value,
@@ -84,6 +88,48 @@ def from_hop_argument_info(
)


class HopSchemaGenerator:
def __init__(self, hop: torch._ops.HigherOrderOperator):
self.arg_infos: list[HopArgumentInfo] = []
self.example_outputs: list[Any] = []
self.hop = hop

def add_arg(
self,
name: str,
example_value: Any,
default_value: Optional[Any] = None,
is_mutated: bool = False,
kw_only: bool = False,
) -> None:
if callable(example_value):
assert isinstance(
example_value, (torch.fx.GraphModule, torch._ops.OperatorBase)
), (
"Expect callable to be a GraphModule or an. Please call materialize_as_graph first "
f"to turn callable arguments {example_value} into a GraphModule."
)

arg_info = HopArgumentInfoGen.from_example(
example_value=example_value,
name=name,
default_value=default_value,
is_mutated=is_mutated,
kw_only=kw_only,
)
self.arg_infos.append(arg_info)

def add_output(self, output: Any) -> None:
self.example_outputs.append(output)

def gen_schema(self) -> torch._C.FunctionSchema:
return CFunctionSchemaGen.from_hop_argument_info(
str(self.hop),
self.arg_infos,
HopArgumentInfoGen.from_example(tuple(self.example_outputs), name="out"),
)
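
As a second hedged sketch, the keyword/default and mutation flags on add_arg, mirroring how gen_schema in base_hop.py forwards kwargs. The argument names, values, and the scalar default are hypothetical and not from the PR; the example-value/default-value type check referenced below comes from the assertion added to HopArgumentInfoGen.from_example above.

# Hypothetical flag usage; names and values are illustrative, not from the PR.
import torch
from torch._higher_order_ops.schema import HopSchemaGenerator

def scale_inplace(buf):
    return buf.mul_(2.0)

gm = torch.fx.symbolic_trace(scale_inplace)
buf = torch.zeros(4)

schema_gen = HopSchemaGenerator(torch.ops.higher_order.cond)  # stand-in HOP instance
schema_gen.add_arg("subgraph", gm)
# Positional tensor argument that the subgraph mutates in place.
schema_gen.add_arg("buf", buf, is_mutated=True)
# Keyword-only argument: as in base_hop.py, the example value doubles as the default,
# which satisfies the type(example_value) == type(default_value) check in from_example.
schema_gen.add_arg("scale", 2.0, default_value=2.0, kw_only=True)
schema_gen.add_output(buf * 2.0)
print(schema_gen.gen_schema())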


class CFunctionSchemaGen:
"""
Note: [HigherOrderOperator schema generation]
@@ -168,8 +214,6 @@ def from_hop_argument_info(
def find_hop_schema(
gm: torch.fx.GraphModule, target: Target
) -> list[torch._C.FunctionSchema]:
import torch.utils._pytree as pytree

schemas = []
for node in gm.graph.find_nodes(op="call_function", target=target):
