[graph partition] Add way to register custom rule (#163310) · pytorch/pytorch@ee7bdd8 · GitHub
Commit ee7bdd8

zou3519 authored and pytorchmergebot committed
[graph partition] Add way to register custom rule (#163310)
This PR adds an experimental way to register a custom rule for whether Inductor should partition the graph around an operator.

Test Plan:
- new test

Pull Request resolved: #163310
Approved by: https://github.com/ProExpertProg, https://github.com/BoyuanFeng, https://github.com/eellison
ghstack dependencies: #162117, #162307, #162651
1 parent 0098e56 commit ee7bdd8
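For orientation, here is a condensed, hedged sketch of how the new hook is exercised, mirroring the test added below. The `mylib::baz` op and the rule body are illustrative; the API lives in torch._inductor.scheduler, is explicitly private and experimental, and the example assumes a CUDA device (as the test does) since "reduce-overhead" relies on CUDA graphs.

import torch
from torch._inductor import config as inductor_config
from torch._inductor.scheduler import register_should_partition_rule

# A custom op that Inductor treats as a fallback kernel
# (implicit_fallbacks must be enabled, as in the test).
@torch.library.custom_op("mylib::baz", mutates_args=())
def baz(x: torch.Tensor, flag: int) -> torch.Tensor:
    return x.clone()

@baz.register_fake
def _(x, flag):
    return x.clone()

def should_partition(x, flag):
    # Same signature as the op; Inductor calls this with FakeTensors.
    return flag

register_should_partition_rule(torch.ops.mylib.baz.default, should_partition)

def f(x, flag):
    x = x + 1
    x = baz(x, flag)
    return x + 1

f_compiled = torch.compile(f, mode="reduce-overhead", fullgraph=True)

with inductor_config.patch("graph_partition", True), inductor_config.patch(
    "implicit_fallbacks", True
):
    x = torch.randn(2, device="cuda")
    f_compiled(x, True)   # rule returns True: the graph is partitioned around baz
    f_compiled(x, False)  # rule returns False: no extra partition around the op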

File tree

2 files changed: +105 -0 lines changed

test/inductor/test_cudagraph_trees.py

Lines changed: 62 additions & 0 deletions
@@ -5,6 +5,7 @@
 import gc
 import importlib
 import itertools
+import re
 import sys
 import unittest
 import warnings
@@ -912,6 +913,67 @@ def test_unaligned_static_input_non_trees(self):
     def test_unaligned_static_input_no_cudagraphs(self):
         self._test_unaligned_static_input_impl(expected_clones=0)

+    @torch._inductor.config.patch("graph_partition", True)
+    @torch._inductor.config.patch("implicit_fallbacks", True)
+    def test_graph_partition_custom_rule(self):
+        def get_num_partitions(code):
+            code = "".join(code)
+            found = re.search(r"partitions=\[(.*)\]", code)
+            assert found is not None
+            partitions = found.group(1)
+            num_partitions = len([p for p in partitions.split(",") if p])
+            return num_partitions
+
+        @torch.library.custom_op("mylib::bar", mutates_args=())
+        def bar(x: torch.Tensor, flag: int) -> torch.Tensor:
+            return x.clone()
+
+        @bar.register_fake
+        def _(x, flag):
+            return x.clone()
+
+        def f(x, flag):
+            x = x + 1
+            x = bar(x, flag)
+            x = x + 1
+            return x
+
+        x = torch.randn(2, device="cuda")
+        f_compiled = torch.compile(f, mode="reduce-overhead", fullgraph=True)
+        _, code = run_and_get_code(f_compiled, x, True)
+        num_partitions = get_num_partitions(code)
+        self.assertEqual(num_partitions, 1)
+
+        @torch.library.custom_op("mylib::baz", mutates_args=())
+        def baz(x: torch.Tensor, flag: int) -> torch.Tensor:
+            return x.clone()
+
+        @baz.register_fake
+        def _(x, flag):
+            return x.clone()
+
+        def should_partition(x, flag):
+            return flag
+
+        torch._inductor.scheduler.register_should_partition_rule(
+            torch.ops.mylib.baz.default, should_partition
+        )
+
+        def f(x, flag):
+            x = x + 1
+            x = baz(x, flag)
+            x = x + 1
+            return x
+
+        f_compiled = torch.compile(f, mode="reduce-overhead", fullgraph=True)
+        _, code = run_and_get_code(f_compiled, x, True)
+        num_partitions = get_num_partitions(code)
+        self.assertEqual(num_partitions, 2)
+
+        _, code = run_and_get_code(f_compiled, x, False)
+        num_partitions = get_num_partitions(code)
+        self.assertEqual(num_partitions, 1)
+
     @torch._inductor.config.patch("graph_partition", True)
     @torch._inductor.config.patch("triton.cudagraph_trees", False)
     def test_graph_partition_gc(self):

torch/_inductor/scheduler.py

Lines changed: 43 additions & 0 deletions
@@ -23,6 +23,8 @@
 from collections.abc import Iterator, Sequence
 from types import ModuleType

+import weakref
+
 import sympy

 import torch
@@ -92,6 +94,28 @@
 _P = ParamSpec("_P")


+_custom_should_partition_fns: weakref.WeakKeyDictionary[
+    torch._ops.OpOverload, Callable[..., bool]
+] = weakref.WeakKeyDictionary()
+
+
+def register_should_partition_rule(
+    op: torch._ops.OpOverload,
+    func: Callable[..., bool],
+) -> None:
+    """Register a function that says if Inductor should partition the graph on this op.
+
+    The function should have the same signature as the operator.
+    Inductor will invoke the function with FakeTensors when it needs to decide
+    if the graph should be partitioned.
+
+    `register_should_partition_rule` is currently private and experimental.
+    Use at your own risk.
+    """
+    assert isinstance(op, torch._ops.OpOverload)
+    _custom_should_partition_fns[op] = func
+
+
 @dataclasses.dataclass
 class SchedulerBuffer:
     scheduler: Scheduler
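Because the registered rule is called with FakeTensors, it can branch on metadata such as shape, dtype, or device, but not on element values. A hedged sketch of a metadata-based rule, assuming the `mylib::baz` op from the sketch above is already defined, static input shapes, and an arbitrary size threshold:

from torch._inductor.scheduler import register_should_partition_rule

def partition_large_inputs(x, flag):
    # `x` is a FakeTensor here: shape/dtype metadata is available, values are not.
    # The 4096-element threshold is arbitrary and only illustrative.
    return bool(flag) and x.numel() >= 4096

register_should_partition_rule(torch.ops.mylib.baz.default, partition_large_inputs)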
@@ -4632,6 +4656,25 @@ def should_partition(
     ) -> bool:
         """Return True if we should partition the inductor graph on this node"""

+        # Allow users to manually specify if a node should be partitioned
+        # Can only do this for FallbackKernels
+        ir_node = node.node
+        if isinstance(ir_node, torch._inductor.ir.FallbackKernel):
+            operator = ir_node.op_overload
+            if operator is not None and operator in _custom_should_partition_fns:
+                assert isinstance(operator, torch._ops.OpOverload)
+                should_partition_fn = _custom_should_partition_fns[operator]
+                fx_node = ir_node.get_origin_node()
+                assert fx_node is not None
+                success, fake_args, fake_kwargs = (
+                    torch._inductor.fx_utils.get_fake_args_kwargs(fx_node)
+                )
+                assert success, (
+                    "If this op came from a custom inductor pass, make sure to run FakeTensorUpdator"
+                )
+                should_partition = should_partition_fn(*fake_args, **fake_kwargs)
+                return should_partition
+
         # When not using cudagraphs, keep all kernels in the `call` function
         # instead of graph partition functions, since graph partition only brings
         # benefit to cudagraph
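The lookup above keys `ir_node.op_overload` into `_custom_should_partition_fns`, which was declared earlier as a `weakref.WeakKeyDictionary`. A registration therefore does not by itself keep the `OpOverload` key alive; the entry disappears once nothing else references the key. A minimal stdlib sketch of that behavior (the `Key` class is a stand-in for illustration, not an Inductor type):

import gc
import weakref

class Key:
    """Stand-in for an OpOverload; instances are hashable and weak-referenceable."""

registry: "weakref.WeakKeyDictionary[Key, str]" = weakref.WeakKeyDictionary()

k = Key()
registry[k] = "rule"
assert len(registry) == 1

del k          # drop the only strong reference to the key
gc.collect()   # not strictly needed under CPython refcounting, but makes it explicit
assert len(registry) == 0  # the entry was removed automatically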
