|
5 | 5 | import gc |
6 | 6 | import importlib |
7 | 7 | import itertools |
| 8 | +import re |
8 | 9 | import sys |
9 | 10 | import unittest |
10 | 11 | import warnings |
@@ -912,6 +913,67 @@ def test_unaligned_static_input_non_trees(self): |
    def test_unaligned_static_input_no_cudagraphs(self):
        # With cudagraphs disabled, an unaligned static input should need no
        # alignment clones — delegate to the shared impl with expected_clones=0.
        self._test_unaligned_static_input_impl(expected_clones=0)
914 | 915 |
|
    @torch._inductor.config.patch("graph_partition", True)
    @torch._inductor.config.patch("implicit_fallbacks", True)
    def test_graph_partition_custom_rule(self):
        """A registered should-partition rule can force a graph split at a custom op.

        First compiles a function using a custom op with NO rule registered and
        checks the generated code has a single partition.  Then registers a
        rule for a second custom op and checks the partition count follows the
        rule's return value (2 partitions when it returns True, 1 when False).
        """

        def get_num_partitions(code):
            # `code` is the list of generated source strings from
            # run_and_get_code; join and grep for the `partitions=[...]`
            # list emitted by the inductor graph partitioner, then count
            # its non-empty comma-separated entries.
            code = "".join(code)
            found = re.search(r"partitions=\[(.*)\]", code)
            assert found is not None
            partitions = found.group(1)
            num_partitions = len([p for p in partitions.split(",") if p])
            return num_partitions

        # Custom op with no partition rule registered: the fallback should
        # stay inside a single partition.
        @torch.library.custom_op("mylib::bar", mutates_args=())
        def bar(x: torch.Tensor, flag: int) -> torch.Tensor:
            return x.clone()

        @bar.register_fake
        def _(x, flag):
            return x.clone()

        def f(x, flag):
            x = x + 1
            x = bar(x, flag)
            x = x + 1
            return x

        x = torch.randn(2, device="cuda")
        f_compiled = torch.compile(f, mode="reduce-overhead", fullgraph=True)
        _, code = run_and_get_code(f_compiled, x, True)
        num_partitions = get_num_partitions(code)
        self.assertEqual(num_partitions, 1)

        # Second custom op, this time with a should-partition rule attached.
        @torch.library.custom_op("mylib::baz", mutates_args=())
        def baz(x: torch.Tensor, flag: int) -> torch.Tensor:
            return x.clone()

        @baz.register_fake
        def _(x, flag):
            return x.clone()

        def should_partition(x, flag):
            # Rule simply echoes the op's `flag` argument, so the test can
            # steer partitioning from the call site.
            return flag

        # NOTE(review): assumes the rule is invoked with the op's own
        # arguments (x, flag) — confirm against the scheduler API.
        torch._inductor.scheduler.register_should_partition_rule(
            torch.ops.mylib.baz.default, should_partition
        )

        def f(x, flag):
            x = x + 1
            x = baz(x, flag)
            x = x + 1
            return x

        # flag=True -> rule asks for a split around baz -> 2 partitions.
        f_compiled = torch.compile(f, mode="reduce-overhead", fullgraph=True)
        _, code = run_and_get_code(f_compiled, x, True)
        num_partitions = get_num_partitions(code)
        self.assertEqual(num_partitions, 2)

        # flag=False -> no split requested -> back to a single partition
        # (a recompile is triggered since `flag` changed).
        _, code = run_and_get_code(f_compiled, x, False)
        num_partitions = get_num_partitions(code)
        self.assertEqual(num_partitions, 1)
| 976 | + |
915 | 977 | @torch._inductor.config.patch("graph_partition", True) |
916 | 978 | @torch._inductor.config.patch("triton.cudagraph_trees", False) |
917 | 979 | def test_graph_partition_gc(self): |
|
0 commit comments