[Inductor] Construct subgraph with benchmarking args not example_inputs · pytorch/pytorch@aebd8a6 · GitHub

Commit aebd8a6

[Inductor] Construct subgraph with benchmarking args not example_inputs
If the inputs to a subgraph have FlexibleLayout, the subgraph does not currently freeze those layouts. As a result, the `example_inputs` generated for the subgraph may not be consistent in layout with the `args` passed in for benchmarking.

Differential Revision: [D74900879](https://our.internmc.facebook.com/intern/diff/D74900879/)
ghstack-source-id: f0da12f
Pull Request resolved: #153753
1 parent 9d3b6ee commit aebd8a6
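For context, the mismatch described in the commit message comes down to two tensors that share a shape but not strides. The snippet below is a minimal illustrative sketch (not part of this commit; shapes are arbitrary): if example inputs are materialized with a contiguous layout while Inductor later pads the real benchmarking args, the subgraph would be traced and benchmarked on different memory layouts.

import torch

# Illustrative sketch only (not from this commit; shapes chosen arbitrarily).
# Two tensors with identical shapes can still disagree on strides, e.g. a
# contiguous tensor vs. one whose rows are padded for alignment.
contiguous = torch.empty((4, 14240), dtype=torch.bfloat16)
padded = torch.empty_strided((4, 14240), (14336, 1), dtype=torch.bfloat16)

assert contiguous.shape == padded.shape
assert contiguous.stride() == (14240, 1)
assert padded.stride() == (14336, 1)
assert contiguous.stride() != padded.stride()  # same shape, different layout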

File tree

3 files changed

+122 −44 lines changed

test/inductor/test_subgraph_choice.py

+102-28
@@ -1,23 +1,36 @@
 # Owner(s): ["module: inductor"]
 import functools
 import unittest
+from unittest import mock
+from unittest.mock import MagicMock
 
 import torch
 from torch._dispatch.python import enable_python_dispatcher
 from torch._inductor.codegen.subgraph import SubgraphTemplate
 from torch._inductor.decomposition import select_decomp_table
-from torch._inductor.ir import Buffer, FixedLayout
+from torch._inductor.ir import Buffer, FixedLayout, FlexibleLayout
 from torch._inductor.lowering import register_lowering
-from torch._inductor.select_algorithm import (
-    AlgorithmSelectorCache,
-    autotune_select_algorithm,
-)
+from torch._inductor.select_algorithm import autotune_select_algorithm
 from torch._inductor.test_case import run_tests, TestCase
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing._internal.common_utils import skipIfXpu, TEST_WITH_ROCM
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
 
 
+def decomposeK(a, b, kPartitions):
+    m = a.shape[0]
+    n = b.shape[1]
+    k = a.shape[1]
+
+    B = k // kPartitions
+    a_reshaped = torch.permute(a.reshape(m, B, kPartitions), (1, 0, 2))
+    b_reshaped = b.reshape(B, kPartitions, n)
+    result = torch.bmm(a_reshaped, b_reshaped, out_dtype=torch.float32)
+    result_fp32 = result.to(torch.float32)
+    reduced_buf = torch.sum(result_fp32, 0)
+    return reduced_buf.to(a.dtype)
+
+
 class TestSubgraphChoice(TestCase):
     def setUp(self):
         super().setUp()
@@ -34,6 +47,8 @@ def test_subgraph_decompose_k(self):
         from torch._inductor.kernel.mm import aten_mm
         from torch._inductor.kernel.mm_common import mm_args
 
+        mat1_shape, mat2_shape = (32, 4096), (4096, 32)
+
         @torch.library.custom_op("mylib::matmul_decompose", mutates_args={})
         def matmul_decompose(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
             return a @ b
@@ -42,28 +57,12 @@ def matmul_decompose(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         def _(a, b):
             return a @ b
 
-        def decomposeK(a, b, kPartitions):
-            m = a.shape[0]
-            n = b.shape[1]
-            k = a.shape[1]
-
-            B = k // kPartitions
-            a_reshaped = torch.permute(a.reshape(m, B, kPartitions), (1, 0, 2))
-            b_reshaped = b.reshape(B, kPartitions, n)
-            result = torch.bmm(a_reshaped, b_reshaped, out_dtype=torch.float32)
-            result_fp32 = result.to(torch.float32)
-            reduced_buf = torch.sum(result_fp32, 0)
-            return reduced_buf.to(a.dtype)
-
-        mat1_shape, mat2_shape = (32, 4096), (4096, 32)
-
         @register_lowering(torch.ops.mylib.matmul_decompose)
         def _(a, b):
             _, _, _, layout, mat1, mat2 = mm_args(a, b)
 
             choices = [aten_mm.bind((mat1, mat2), layout)]
 
-            # TODO (PaulZhang12): Once decomposeK lands in Inductor, move this
             kPartitions = 256
             with enable_python_dispatcher():
                 decompositions = select_decomp_table()
@@ -77,15 +76,10 @@ def _(a, b):
                     ),
                 )
 
-            mat1_tensor, mat2_tensor = (
-                AlgorithmSelectorCache.benchmark_example_value(mat1),
-                AlgorithmSelectorCache.benchmark_example_value(mat2),
-            )
             decompose_k_subgraph_template.maybe_append_choice(
                 choices,
                 input_nodes=(mat1, mat2),
                 layout=layout,
-                example_inputs=[mat1_tensor, mat2_tensor],
             )
 
             # Test benchmarking against aten
@@ -112,8 +106,88 @@ def func(mat1, mat2):
         res = compiled_func(a_in, b_in)
 
         # Check same results of compiled result and regular torch.mm
-        # Relax precision as decomposeK does first accumulation in fp16
-        torch.testing.assert_close(res, a_in @ b_in, atol=1e-1, rtol=1e-1)
+        torch.testing.assert_close(res, a_in @ b_in, atol=1e-2, rtol=1e-2)
+
+    @skipIfXpu
+    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
+    def test_subgraph_freeze_layout(self):
+        from torch._inductor.kernel.mm_common import mm_args
+
+        M, N, K = (4, 128, 14240)
+        a_in = torch.randn(
+            (M, K), dtype=torch.bfloat16, device=torch.device(f"{GPU_TYPE}:0")
+        )
+        b_in = torch.randn(
+            (K, N), dtype=torch.bfloat16, device=torch.device(f"{GPU_TYPE}:0")
+        )
+
+        @torch.library.custom_op("mylib::matmul_decompose_padding", mutates_args={})
+        def matmul_decompose(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            return a @ b
+
+        @matmul_decompose.register_fake
+        def _(a, b):
+            return a @ b
+
+        @register_lowering(torch.ops.mylib.matmul_decompose_padding)
+        def _(a, b):
+            _, _, _, layout, mat1, mat2 = mm_args(a, b)
+            mat1_layout = mat1.layout
+            assert isinstance(mat1_layout, FlexibleLayout)
+            mat1_stride = mat1_layout.stride
+
+            choices = []
+
+            kPartitions = 2
+            with enable_python_dispatcher():
+                decompositions = select_decomp_table()
+
+                decompose_k_subgraph_template = SubgraphTemplate(
+                    name="decompose_k_mm",
+                    make_fx_graph=make_fx(
+                        functools.partial(decomposeK, kPartitions=kPartitions),
+                        decompositions,
+                    ),
+                )
+
+            decompose_k_subgraph_template.maybe_append_choice(
+                choices,
+                input_nodes=(mat1, mat2),
+                layout=layout,
+            )
+
+            choice = choices[0]
+            assert isinstance(mat1.layout, FixedLayout)
+
+            # Creating the subgraph choice should have frozen the layout
+            # We ensure padding so the stride should differ
+            assert mat1.layout.stride != mat1_stride
+
+            for example_stride, layout_stride in zip(
+                choice.example_inputs[0].stride(), mat1.layout.stride
+            ):
+                # Example inputs should have same stride as current layout
+                assert example_stride == layout_stride
+
+            return autotune_select_algorithm(
+                "test_subgraph_choice", choices, [a, b], layout
+            )
+
+        def func(mat1, mat2):
+            return torch.ops.mylib.matmul_decompose_padding((mat1 + 1.0), mat2)
+
+        with mock.patch("torch._inductor.ir.V.get_current_node") as get_node_mock:
+            node_mock = MagicMock()
+            node_mock.meta = {"dislike_padding": False}
+            get_node_mock.return_value = node_mock
+
+            compiled_func = torch.compile(func, mode="max-autotune", dynamic=False)
+
+            res = compiled_func(a_in, b_in)
+
+            # Check same results of compiled result and regular torch.mm
+            # Relax precision as decomposeK does first accumulation in fp16
+            torch.testing.assert_close(res, (a_in + 1.0) @ b_in, atol=1e-2, rtol=1e-2)
 
 
 if __name__ == "__main__":
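As a side note, the `decomposeK` helper that the test promotes to module scope implements a standard decompose-K matmul: split the K dimension into batches, run a batched matmul, and reduce over the batch dimension. A standalone sketch of the same idea (plain tensors, no Inductor, shapes matching the test) is:

import torch

# Standalone sketch of the decompose-K idea exercised by the test (no Inductor).
def decompose_k_reference(a, b, k_partitions):
    m, k = a.shape
    n = b.shape[1]
    batches = k // k_partitions
    # (m, k) -> (batches, m, k_partitions); (k, n) -> (batches, k_partitions, n)
    a_batched = a.reshape(m, batches, k_partitions).permute(1, 0, 2)
    b_batched = b.reshape(batches, k_partitions, n)
    partial = torch.bmm(a_batched, b_batched)  # (batches, m, n) partial products
    return partial.sum(dim=0)  # reduce over the split-K batches

a = torch.randn(32, 4096)
b = torch.randn(4096, 32)
torch.testing.assert_close(
    decompose_k_reference(a, b, k_partitions=256), a @ b, atol=1e-4, rtol=1e-4
)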

torch/_inductor/codegen/subgraph.py

+19-10
@@ -5,7 +5,7 @@
 import torch
 from torch._inductor import ir
 from torch._inductor.codegen.common import KernelTemplate
-from torch._inductor.ir import Buffer, Layout
+from torch._inductor.ir import Buffer, ir_node_to_tensor, Layout
 from torch._inductor.runtime.benchmarking import benchmarker
 from torch._inductor.virtualized import V
 
@@ -25,12 +25,17 @@ def __init__(
         input_nodes: list[Buffer],
         layout: Layout,
         description: str,
-        gm: torch.fx.GraphModule,
-        example_inputs: list[Any],
+        make_fx_graph: Callable[..., Any],
     ) -> None:
         super().__init__(name, input_nodes, layout, description)
-        self.gm = gm
-        self.example_inputs = example_inputs
+
+        self.example_inputs = []
+        with V.fake_mode:
+            for inp in self.input_nodes:
+                inp.data.freeze_layout()  # type: ignore[attr-defined]
+                self.example_inputs.append(ir_node_to_tensor(inp))
+
+        self.gm = make_fx_graph(*self.example_inputs)
 
     def __str__(self) -> str:
         return f"SubgraphCaller({self.name})"
@@ -54,14 +59,21 @@ def benchmark(self, *args: list[Any], out: torch.Tensor) -> float:
             name=f"benchmark_{self.name}",
         )
 
+        for ar, example_inp in zip(args, self.example_inputs):
+            # Sanity check that args are same layout as example inputs
+            if isinstance(ar, torch.Tensor):
+                assert isinstance(example_inp, torch.Tensor)
+                assert ar.shape == example_inp.shape
+                assert ar.stride() == example_inp.stride()
+
         with V.set_graph_handler(bm_graph_lowering):
             # Don't bother autotuning on Triton here
             with inductor_config.patch(
                 max_autotune=False,
                 max_autotune_gemm=False,
                 max_autotune_gemm_backends="ATEN",
            ):
-                bm_graph_lowering.run(*self.example_inputs)
+                bm_graph_lowering.run(*args)
             mod = bm_graph_lowering.compile_to_module()
             bm_func = mod.call
 
@@ -139,7 +151,6 @@ def generate(  # type: ignore[override]
         self,
         input_nodes: list[Buffer],
         layout: Layout,
-        example_inputs: list[Any],
         **kwargs: Any,
     ) -> SubgraphChoiceCaller:
         """
@@ -154,13 +165,11 @@ def generate(  # type: ignore[override]
         Returns:
             SubgraphChoiceCaller: A callable object that can be used for autotuning
         """
-        gm = self.make_fx_graph(*example_inputs)
 
         return SubgraphChoiceCaller(
             name=self.name,
             input_nodes=input_nodes,
             layout=layout,
             description="",
-            gm=gm,
-            example_inputs=example_inputs,
+            make_fx_graph=self.make_fx_graph,
         )
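The net effect of the constructor change is that example inputs are derived from the now-frozen layouts of the input nodes, the FX graph is traced from those inputs, and benchmarking runs on the real args, with shape/stride asserts guarding against drift between the two. A rough standalone sketch of that layout check, outside Inductor's IR types (the helper name and tensors below are hypothetical, not Inductor API), is:

import torch

# Hypothetical standalone helper mirroring the shape/stride sanity check the
# commit adds before benchmarking; not part of Inductor's API.
def assert_layouts_match(args, example_inputs):
    for arg, example in zip(args, example_inputs):
        if isinstance(arg, torch.Tensor):
            assert isinstance(example, torch.Tensor)
            assert arg.shape == example.shape, (arg.shape, example.shape)
            assert arg.stride() == example.stride(), (arg.stride(), example.stride())

# Passes: both tensors use the same padded layout.
args = [torch.empty_strided((4, 128), (160, 1))]
example_inputs = [torch.empty_strided((4, 128), (160, 1))]
assert_layouts_match(args, example_inputs)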

torch/_inductor/kernel/mm.py

+1-6
@@ -23,7 +23,7 @@
 from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
 from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
 from ..codegen.subgraph import SubgraphTemplate
-from ..ir import FlexibleLayout, ir_node_to_tensor, is_triton
+from ..ir import FlexibleLayout, is_triton
 from ..lowering import (
     add_layout_constraint,
     constrain_to_fx_strides,
@@ -698,15 +698,10 @@ def tuned_mm(mat1, mat2, *, layout=None):
             ),
         )
 
-        with V.fake_mode:
-            mat1_tensor = ir_node_to_tensor(mat1)
-            mat2_tensor = ir_node_to_tensor(mat2)
-
         decompose_k_subgraph_template.maybe_append_choice(
             choices,
             input_nodes=(mat1, mat2),
             layout=layout,
-            example_inputs=[mat1_tensor, mat2_tensor],
         )
 
     if is_nonzero and use_cutlass_template(layout, m, n, k):

0 commit comments