[inductor][cpp] GEMM template (infra and fp32) (#124021) · tinglvv/pytorch@a00a6a4 · GitHub

Commit a00a6a4

Jiong Gong authored and tinglvv committed
[inductor][cpp] GEMM template (infra and fp32) (pytorch#124021)
This PR adds the Cpp template infrastructure and the initial FP32 GEMM template. See RFC pytorch#125683 for more background info.

1. Cpp template infrastructure
   Similar template abstractions as the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, `CppTemplateBuffer`, plus the MicroGemm micro-kernel abstraction that can be used by Cpp GEMM templates.

2. Initial FP32 GEMM template
   This involves a GEMM template implementation, `CppPackedGemmTemplate`, that supports GEMM with a constant weight (`B`). It requires `N` to be a multiple of the register-blocking size, while the `M` (batch dim) of `A` can be static or dynamic. The `B` matrix is prepacked. This is a typical setting for inference workloads. The template handles thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`), then invokes `CppMicroGemm`, which handles register blocking, instruction selection, and other CPU-architecture-specific optimizations. A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for fp32 matmuls, implemented with the ATen vec abstraction.

3. Correctness and performance
   The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static and dynamic shapes. Since this is an initial implementation, we are still working on further performance improvements in follow-up PRs, including optimizations in the kernels as well as fusions. Perf gains over the ATen kernels (implemented with MKL) are only observed on a selective number of models. The gains are more obvious with dynamic shapes, since MKL only supports packed GEMM for static shapes. Details below.

Static shapes

| Benchmark | torchbench | huggingface | timm_models |
|---|---|---|---|
| Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x |
| Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x |
| Single-threaded (baseline) | 1.56x | 1.19x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x |

Key models being sped up:
- drq: 1.14x
- soft_act: 1.12x
- cait_m36_384: 1.18x

Dynamic shapes

| Benchmark | torchbench | huggingface | timm_models |
|---|---|---|---|
| Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x |
| Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x |
| Single-threaded (baseline) | 1.55x | 1.20x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x |

Key models being sped up:
- BERT_pytorch: 1.22x
- pyhpc_turbulent: 1.13x
- soft_actor_critic: 1.77x
- BlenderbotForCausalLM: 1.09x
- cait_m36_384: 1.17x

Pull Request resolved: pytorch#124021
Approved by: https://github.com/jansel
1 parent ba6af98 commit a00a6a4
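
As a usage illustration (not part of this commit), the template is exercised through the max-autotune path with the "CPP" backend enabled; the sketch below uses the same configuration knobs the new test file patches (`max_autotune`, `max_autotune_gemm_backends="CPP,ATEN"`, `freezing`). The module and shapes are arbitrary examples.

    import torch
    import torch._inductor.config as inductor_config

    # Same knobs the new test patches; freezing lets the constant weight (B) be prepacked.
    inductor_config.freezing = True
    inductor_config.max_autotune = True
    inductor_config.max_autotune_gemm_backends = "CPP,ATEN"

    linear = torch.nn.Linear(1024, 1024, bias=True).eval()
    compiled = torch.compile(linear)

    with torch.no_grad():
        x = torch.randn(128, 1024)
        # Autotuning chooses between the ATen (MKL) kernel and the CPP GEMM template.
        y = compiled(x)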
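For intuition about the decomposition described above (thread blocking, then cache blocking, then a register-blocked micro-kernel), here is a rough, hypothetical sketch. The tile sizes, loop order, and helper names are illustrative simplifications, not the actual `CppPackedGemmTemplate` or `CppMicroGemm` logic.

    import math

    def split(dim, num_parts):
        # Split `dim` into near-equal contiguous chunks, one per thread.
        chunk = math.ceil(dim / num_parts)
        return [(i, min(i + chunk, dim)) for i in range(0, dim, chunk)]

    def gemm_blocking(M, N, K, num_threads=4, cache_block=(64, 256, 256)):
        Mc, Nc, Kc = cache_block  # hypothetical cache-level tile sizes
        for m_lo, m_hi in split(M, num_threads):   # thread decomposition ("thread_blocking")
            for mc in range(m_lo, m_hi, Mc):        # cache blocking ("cache_blocking")
                for nc in range(0, N, Nc):
                    for kc in range(0, K, Kc):
                        # The innermost tile would be handed to a register-blocked
                        # micro-kernel (CppMicroGemmFP32Vec in this PR), which handles
                        # instruction selection per CPU architecture.
                        pass

    gemm_blocking(M=384, N=1024, K=512)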

File tree

14 files changed: +1586 -14 lines changed

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# Owner(s): ["oncall: cpu inductor"]
import functools
import unittest
from unittest.mock import patch

import torch
import torch._dynamo.config
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
import torch._inductor.select_algorithm as select_algorithm
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import IS_MACOS, parametrize, TEST_MKL

aten = torch.ops.aten


def patches(fn):
    def skip_cache(self, choices, name, key, benchmark):
        if benchmark is None:
            return {}
        return benchmark(choices)

    for patcher in [
        dynamo_config.patch(verbose=True),
        inductor_config.patch(
            debug=True,
            max_autotune=True,
            epilogue_fusion=True,
            max_autotune_gemm_backends="CPP,ATEN",
        ),
        patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)),
        patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache),
    ]:
        fn = patcher(fn)

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        counters.clear()
        torch.manual_seed(12345)
        return fn(*args, **kwargs)

    return wrapped


class TestSelectAlgorithm(TestCase):
    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("batch_size", (1, 2, 1000))
    @parametrize("in_features", (1, 2, 1000))
    @parametrize("out_features", (1, 32, 1024))
    @parametrize("bias", (True, False))
    @parametrize("input_3d", (True, False))
    @dtypes(torch.float)
    def test_linear_static_shapes(
        self, batch_size, in_features, out_features, bias, input_3d, dtype
    ):
        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)

            @torch.compile
            def forward(self, x):
                return self.linear(x)

        counters.clear()
        mod = M(bias=bias).to(dtype=dtype).eval()
        B = (2, batch_size) if input_3d else (batch_size,)
        v = torch.randn(*B, in_features).to(dtype=dtype)
        mod(v)
        self.assertEqual(
            counters["inductor"]["select_algorithm_autotune"],
            1 if out_features != 1 else 0,
        )

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("bias", (True, False))
    @dtypes(torch.float)
    def test_linear_input_transpose(self, bias, dtype):
        batch_size = 384
        in_features = 196
        out_features = 384

        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)

            @torch.compile
            def forward(self, x):
                return self.linear(x)

        counters.clear()
        mod = M(bias=bias).to(dtype=dtype).eval()
        v = torch.randn(in_features, batch_size).to(dtype=dtype)
        mod(v.transpose(0, 1))
        # TODO(jgong5): support transposed input
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)


@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
class _DynamicShapesTestBase(TestCase):
    pass


class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
    test_linear_dynamic_shapes = TestSelectAlgorithm.test_linear_static_shapes


instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu")
instantiate_device_type_tests(
    TestSelectAlgorithmDynamicShapes, globals(), only_for="cpu"
)


if __name__ == "__main__":
    from torch.testing._internal.inductor_utils import HAS_CPU

    if HAS_CPU and not IS_MACOS:
        run_tests()

torch/_inductor/codegen/cpp.py

Lines changed: 46 additions & 3 deletions
@@ -8,7 +8,7 @@
 import sys
 from copy import copy, deepcopy
 from enum import Enum
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union
 
 import sympy
 
@@ -20,6 +20,7 @@
 from torch.utils._sympy.functions import FloorDiv, ModularIndexing
 from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
 from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
+from ..._dynamo.utils import counters
 
 from .. import codecache, config, ir, metrics
 from ..codegen.wrapper import WrapperCodeGen
@@ -3521,6 +3522,8 @@ def _can_fuse_horizontal_impl(self, node1, node2):
         return self._why_fuse_nodes(node1, node2) is not None
 
     def can_fuse_horizontal(self, node1, node2):
+        if node1.is_template() or node2.is_template():
+            return False
         if (
             len(node1.get_nodes()) + len(node2.get_nodes())
             > config.cpp.max_horizontal_fusion_size
@@ -3601,6 +3604,9 @@ def get_fusion_pair_priority(self, node1, node2):
         return 0
 
     def can_fuse_vertical(self, node1, node2):
+        # TODO(jgong5): support vertical fusion for template nodes
+        if node1.is_template() or node2.is_template():
+            return False
         return (
             self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction()
         ) or self.can_fuse_vertical_outer_loop(node1, node2)
@@ -3657,6 +3663,42 @@ def codegen_node(
         if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM:
             self._set_flush_status(True)
 
+    def is_cpp_template(self, node: BaseSchedulerNode) -> bool:
+        return isinstance(node, SchedulerNode) and isinstance(
+            node.node, ir.CppTemplateBuffer
+        )
+
+    def codegen_template(
+        self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode]
+    ):
+        """
+        Codegen a CPP template, possibly with fused epilogues
+        """
+        counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes)
+        assert self.is_cpp_template(
+            template_node
+        ), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
+        template_node = cast(SchedulerNode, template_node)
+        _, (_, rnumel) = template_node.group
+        assert rnumel == ()
+        ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node)
+        epilogue_ir_nodes: List[ir.Buffer] = [n.node for n in epilogue_nodes]
+        assert all(
+            isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes
+        ), "Epilogue nodes must all be instances of ir.ComputedBuffer"
+        kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes)
+        with kernel:
+            for node in [template_node, *epilogue_nodes]:
+                node.mark_run()
+            src_code = render()
+
+        with V.set_kernel_handler(kernel):
+            node_schedule = [template_node, *epilogue_nodes]
+            kernel_name = self.define_kernel(src_code, node_schedule, kernel.args)
+        kernel.call_kernel(kernel_name, ctb)
+        V.graph.removed_buffers |= kernel.removed_buffers
+        self.scheduler.free_buffers()
+
     def _get_scheduled_num_args(self):
         return self.kernel_group.get_num_args()
 
@@ -3666,7 +3708,7 @@ def ready_to_flush(self):
     def codegen_sync(self):
         pass
 
-    def define_kernel(self, src_code, nodes):
+    def define_kernel(self, src_code, nodes, kernel_args=None):
         wrapper = V.graph.wrapper_code
         fused_name = (
             get_fused_kernel_name(nodes, config.cpp.descriptive_names)
@@ -3682,7 +3724,8 @@ def define_kernel(self, src_code, nodes):
         src_code = src_code.replace("#pragma CMT", "//")
 
         compile_wrapper = IndentedBuffer()
-        _, _, arg_types = self.kernel_group.args.cpp_argdefs()
+        args = self.kernel_group.args if kernel_args is None else kernel_args
+        _, _, arg_types = args.cpp_argdefs()
         if not V.graph.cpp_wrapper:
             compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''")
         compile_wrapper.splice(src_code, strip=True)

0 commit comments