[inductor][cpp][gemm] support k slicing for static shapes (#130821) · pytorch/pytorch@316c0d3

Commit 316c0d3

Jiong Gong authored and pytorchmergebot committed
[inductor][cpp][gemm] support k slicing for static shapes (#130821)
This PR provides initial support for k-slicing (i.e., parallel reduction along the k-dim) in the CPP GEMM template. Only static shapes are supported for now.

When k-slicing is enabled, extra temporary buffers are allocated to hold intermediate results, and an extra barrier follows the initial GEMM compute of each thread: each thread first stores its GEMM result to a temporary accumulation buffer (pointed to by `local_buf_ptrs`, an array of pointers to accumulation buffers), followed by a reduction along the k-slices, the epilogue computes, and the store to the final output `Y`. Within each k-slicing thread group, the reduction along k-slices and the epilogue computes are parallelized along the M-dim. The algorithm is designed to keep synchronization overhead as low as possible.

K-slicing is enabled only when blocking on M and N cannot occupy all threads. Since k-slicing doesn't always bring a benefit, an extra configuration option gates it (disabled by default). We still need to identify a good heuristic to enable k-slicing by default.

Performance numbers for 64x4096x64, 64x10000x64, and 64x20000x64 on a 60-core SPR. As the numbers show, k-slicing only beats non-k-slicing when K is large enough.

Without k-slicing:

AUTOTUNE linear_unary(64x4096, 64x4096, 64)
  cpp_packed_gemm_0  0.0108 ms  100.0%
  _linear_pointwise  0.0431 ms   25.1%
AUTOTUNE linear_unary(64x10000, 64x10000, 64)
  cpp_packed_gemm_0  0.0272 ms  100.0%
  _linear_pointwise  0.0892 ms   30.5%
AUTOTUNE linear_unary(64x20000, 64x20000, 64)
  cpp_packed_gemm_0  0.0781 ms  100.0%
  _linear_pointwise  0.1693 ms   46.1%

With k-slicing:

AUTOTUNE linear_unary(64x4096, 64x4096, 64)
  cpp_packed_gemm_0  0.0260 ms  100.0%
  _linear_pointwise  0.0444 ms   58.5%
AUTOTUNE linear_unary(64x10000, 64x10000, 64)
  cpp_packed_gemm_0  0.0275 ms  100.0%
  _linear_pointwise  0.0893 ms   30.8%
AUTOTUNE linear_unary(64x20000, 64x20000, 64)
  cpp_packed_gemm_0  0.0284 ms  100.0%
  _linear_pointwise  0.1686 ms   16.8%

Pull Request resolved: #130821
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
ghstack dependencies: #131024
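To make the two phases above concrete, here is a minimal Python/NumPy sketch of the scheme the commit message describes: partial GEMMs into per-slice temporary buffers, a barrier, then a reduction over slices that is itself parallelized along M. All names and sizes are illustrative (a single k-slicing group with two slices); this is not the template code itself.

# Minimal sketch of the k-slicing scheme (illustrative names and sizes;
# one k-slicing group with two slices, not the actual template code).
import numpy as np

num_k_slices = 2
M, N, K = 8, 8, 16
X = np.random.rand(M, K).astype(np.float32)
W = np.random.rand(K, N).astype(np.float32)

# Phase 1: each thread computes a partial GEMM over its own k-slice into
# a temporary accumulation buffer (cf. local_buf_ptrs in the template).
k_step = K // num_k_slices
local_bufs = [
    X[:, s * k_step : (s + 1) * k_step] @ W[s * k_step : (s + 1) * k_step, :]
    for s in range(num_k_slices)
]

# (barrier: all partial results of the group must be ready)

# Phase 2: the reduction over k-slices is parallelized along M; the thread
# with a given k_slice_id reduces, applies the epilogue to, and stores its
# own M slice.
Y = np.empty((M, N), dtype=np.float32)
m_slice_size = (M + num_k_slices - 1) // num_k_slices
for k_slice_id in range(num_k_slices):
    m0 = min(m_slice_size * k_slice_id, M)
    m1 = min(m_slice_size * (k_slice_id + 1), M)
    acc = local_bufs[0][m0:m1].copy()
    for other_slice in range(1, num_k_slices):
        acc += local_bufs[other_slice][m0:m1]
    Y[m0:m1] = acc  # epilogue + store to the final output Y

assert np.allclose(Y, X @ W, atol=1e-4)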
1 parent d962dba commit 316c0d3

File tree

4 files changed: +207 -41 lines changed

test/inductor/test_cpu_select_algorithm.py

Lines changed: 37 additions & 0 deletions

@@ -585,6 +585,43 @@ def forward(self, x):
         vec_amx = VecAMX()
         self._check_amx_counter(vec_amx)

+    @inductor_config.patch({"freezing": True})
+    @inductor_config.patch({"cpp.gemm_max_k_slices": 0})
+    @patches
+    @torch.no_grad
+    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
+    @parametrize("batch_size", (2,))
+    @parametrize("in_features", (1000,))
+    @parametrize("out_features", (2,))
+    @parametrize("bias", (True, False))
+    @parametrize(
+        "epilogue",
+        (
+            "none",
+            "relu",
+        ),
+    )
+    @dtypes(torch.float, torch.bfloat16, torch.half)
+    def test_linear_k_slicing(
+        self, batch_size, in_features, out_features, bias, epilogue, dtype
+    ):
+        class M(torch.nn.Module):
+            def __init__(self, bias, epilogue, other):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_features, out_features, bias)
+                self.epilogue = _get_epilogue(epilogue, other)
+
+            def forward(self, x):
+                return self.epilogue(self.linear(x))
+
+        counters.clear()
+        v = torch.randn(batch_size, in_features).to(dtype=dtype)
+        u = torch.randn(batch_size, out_features).to(dtype=dtype)
+        mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval()
+        with verify(dtype) as (atol, rtol):
+            self.common(mod, (v,), atol=atol, rtol=rtol)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
+

 @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
 class _DynamicShapesTestBase(BaseTestSelectAlgorithm):
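The test above exercises the new knob via `inductor_config.patch`. Outside the test harness, the same configuration could plausibly be set as in the sketch below; the model, shapes, and the `max_autotune` setting are illustrative assumptions, and only `cpp.gemm_max_k_slices` (where 0 means no cap on the number of k-slices) comes from this PR.

# Hypothetical end-to-end usage; only cpp.gemm_max_k_slices comes from
# this PR (0 = no cap on k-slices). Model, shapes, and other settings
# are illustrative.
import torch
import torch._inductor.config as inductor_config

inductor_config.freezing = True
inductor_config.max_autotune = True
inductor_config.cpp.gemm_max_k_slices = 0

mod = torch.nn.Linear(20000, 64).eval()
x = torch.randn(64, 20000)

with torch.no_grad():
    y = torch.compile(mod)(x)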

torch/_inductor/codegen/cpp_gemm_template.py

Lines changed: 149 additions & 40 deletions

@@ -2,22 +2,27 @@
 import contextlib
 import logging
 import math
+from functools import lru_cache
 from typing import Any, Callable, cast, List, Optional, Set, Union
 from unittest.mock import patch

 import torch
 import torch.utils

 from ..._dynamo.utils import counters
-from .. import ir, lowering as L
+from .. import config, ir, lowering as L
 from ..kernel.mm_common import mm_args
 from ..select_algorithm import DataProcessorTemplateWrapper
 from ..utils import cache_on_self, has_free_symbols, parallel_num_threads
 from ..virtualized import ops, V
 from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType
 from .cpp_template import CppTemplate
 from .cpp_template_kernel import CppTemplateKernel
-from .cpp_utils import GemmBlocking, get_gemm_template_output_and_compute_dtype
+from .cpp_utils import (
+    DTYPE_TO_CPP,
+    GemmBlocking,
+    get_gemm_template_output_and_compute_dtype,
+)

 log = logging.getLogger(__name__)

@@ -58,6 +63,9 @@
 {%- endif %}
     const int64_t Mc_blocks = Mt_blocks;
     const int64_t Kc_blocks = Kt_blocks;
+    const int64_t num_Mc_blocks = (M0_blocks + Mc_blocks - 1) / Mc_blocks;
+    const int64_t num_Nc_blocks = N0_blocks;
+    const int64_t num_k_slices = (K0_blocks + Kt_blocks - 1) / Kt_blocks;
 {%- else %}
     constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
     constexpr int64_t M0_blocks = (M + M0 - 1) / M0;
@@ -66,52 +74,68 @@
     constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}};
     constexpr int64_t Mc_blocks = {{template.cache_blocking().block_m}};
    constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}};
+    constexpr int64_t num_Mc_blocks = (M0_blocks + Mc_blocks - 1) / Mc_blocks;
+    constexpr int64_t num_Nc_blocks = N0_blocks;
+    constexpr int64_t num_k_slices = (K0_blocks + Kt_blocks - 1) / Kt_blocks;
 {%- endif %}

-    // TODO(jgong5): support k-slicing
-    {{kernel.assert_function}}(Kt_blocks == K0_blocks, "Do not support k slicing yet.");
     // make sure all partitions are assigned
     {{kernel.assert_function}}(
         Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= M0_blocks * N0_blocks * K0_blocks,
         "Not all partitions are assigned."
     );

+{%- if maybe_k_slicing %}
+    std::unique_ptr<std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[]> local_buf_ptrs;
+    if (num_k_slices > 1) {
+        local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_k_slices]);
+    }
+{%- endif %}
+
 {%- if num_threads > 1 %}
     #pragma omp parallel num_threads({{num_threads}})
     {
-        int tid = omp_get_thread_num();
+        const int tid = omp_get_thread_num();
         int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end;
         mm_get_thread_blocks(
             tid, M0_blocks, N0_blocks, K0_blocks, Mt_blocks, Nt_blocks, Kt_blocks,
             m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end);
+    {%- if maybe_k_slicing %}
+        const int64_t k_group_id = tid / num_k_slices;
+        const int64_t k_slice_id = tid % num_k_slices;
+    {%- endif %}
 {%- else %}
     {
-        int64_t m_block_start = 0;
-        int64_t m_block_end = M0_blocks;
-        int64_t n_block_start = 0;
-        int64_t n_block_end = N0_blocks;
-        int64_t k_block_start = 0;
-        int64_t k_block_end = K0_blocks;
+        const int tid = 0;
+        const int64_t m_block_start = 0;
+        const int64_t m_block_end = M0_blocks;
+        const int64_t n_block_start = 0;
+        const int64_t n_block_end = N0_blocks;
+        const int64_t k_block_start = 0;
+        const int64_t k_block_end = K0_blocks;
 {%- endif %}
         {{ micro_gemm.codegen_init(kernel) }}
         for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
             const int64_t m_start = mc * M0;
             const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * M0, M);
             const int64_t m_size = m_end - m_start;
 {%- if use_local_acc %}
+    {%- set acc_buf_name = "local_acc_buf" %}
             {{ kernel.define_buffer(acc_buf_name, ["m_end - m_start", "N0"], acc_buf_dtype) }}
 {%- endif %}
             for (int64_t nc = n_block_start; nc < n_block_end; ++nc) {
                 const int64_t n_start = nc * N0;
                 const int64_t n_end = std::min((nc + 1) * N0, N);
+                const int64_t n_size = n_end - n_start;
 {%- if use_local_acc %}
     {%- set acc = kernel.local_buffers[acc_buf_name] %}
+                {{ kernel.reinit_buffer_if_null(acc_buf_name) }}
 {%- else %}
     {%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
 {%- endif %}
                 for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                     int64_t k_start = kc * K0;
-                    int64_t k_end = std::min((kc + Kc_blocks) * K0, K);
+                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * K0, K);
     {%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %}
     {%- set tile_W_3d = kernel.slice_nd(W, [("nc", "nc + 1"), ("k_start", "k_end"), ()]) %}
     {%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
@@ -121,19 +145,64 @@
                     {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True)|indent(24, false) }}
                 }
             }
+{%- if maybe_k_slicing %}
+                if (num_k_slices > 1) {
+                    const int64_t mxn_cache_block_id = mc * num_Nc_blocks + nc;
+                    local_buf_ptrs[mxn_cache_block_id * num_k_slices + k_slice_id].reset({{ kernel.release_buffer(acc_buf_name) }});
+                } else
+{%- endif %}
+                {
 {%- if N == PADDED_N %}
     {%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
     {%- set tile_acc = acc %}
 {%- else %}
     {%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %}
     {%- set tile_acc = kernel.slice_nd(acc, [(), ("0", "n_end - n_start")]) %}
 {%- endif %}
-                {{ kernel.store_output(
-                    tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
-                )|indent(16, false)
-                }}
+                    {{ kernel.store_output(
+                        tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
+                    )|indent(20, false)
+                    }}
+                }
+            }
+        }
+{%- if maybe_k_slicing %}
+        if (num_k_slices > 1) {
+            #pragma omp barrier
+            for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
+                // We slice M-dim and each thread in the k-slicing group works on a slice
+                const int64_t m_start_unsliced = mc * M0;
+                const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * M0, M);
+                const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced;
+                const int64_t m_slice_size = (m_size_unsliced + num_k_slices - 1) / num_k_slices;
+                const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced);
+                const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced);
+                const int64_t m_size = m_end - m_start;
+                const int64_t m_offset = m_start - m_start_unsliced;
+                for (int64_t nc = n_block_start; nc < n_block_end; ++nc) {
+                    const int64_t n_start = nc * N0;
+                    const int64_t n_end = std::min((nc + 1) * N0, N);
+                    const int64_t n_size = n_end - n_start;
+                    const int64_t mxn_cache_block_id = mc * num_Nc_blocks + nc;
+                    auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_k_slices].get();
+                    for (int64_t other_slice = 1; other_slice < num_k_slices; other_slice++) {
+                        auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_k_slices + other_slice].get();
+                        for (int64_t m = m_offset; m < m_offset + m_size; m++) {
+                            #pragma omp simd
+                            for (int64_t n = 0; n < n_size; n++) {
+                                {{acc_buf_name}}[m*N0 + n] += other_acc[m*N0 + n];
+                            }
+                        }
+                    }
+    {%- set tile_acc_m_slice = kernel.slice_nd(tile_acc, [("m_offset", "m_offset + m_end - m_start"), ()]) %}
+                    {{ kernel.store_output(
+                        tile_Y, tile_acc_m_slice, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
+                    )|indent(20, false)
+                    }}
+                }
             }
         }
+{%- endif %}
         {{ micro_gemm.codegen_finalize(kernel) }}
     }
 }
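The M-slicing arithmetic in the reduction loop above (m_slice_size, m_start, m_end with min-clamping) must tile [m_start_unsliced, m_end_unsliced) exactly, leaving trailing slices empty when the range is smaller than num_k_slices * m_slice_size. A standalone Python check of the same formulas (names mirror the template; the test ranges are arbitrary):

# Standalone check of the M-slicing formulas used in the reduction loop.
def m_slices(m_start_unsliced, m_end_unsliced, num_k_slices):
    m_size_unsliced = m_end_unsliced - m_start_unsliced
    m_slice_size = (m_size_unsliced + num_k_slices - 1) // num_k_slices
    for k_slice_id in range(num_k_slices):
        m_start = min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced)
        m_end = min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced)
        yield m_start, m_end  # empty (m_start == m_end) for overhanging slices

# The slices tile [64, 117) exactly: no gaps, no overlap, in order.
for num_k_slices in (2, 3, 7):
    spans = list(m_slices(64, 117, num_k_slices))
    assert spans[0][0] == 64 and spans[-1][1] == 117
    assert all(a[1] == b[0] for a, b in zip(spans, spans[1:]))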
@@ -180,12 +249,12 @@ def thread_blocking(self) -> GemmBlocking:
         NOTE [Thread blocking in Cpp GEMM]
         We use simple heuristics to decide the thread blocking:
         1. Make sure all threads are occupied as much as possible.
-        2. Favor more square-sized thread blocks for better data reuse.
-        TODO(jgong5): we only do blocking on on M and N now, add blocking on K
-        after supporting k-slicing.
+        2. For (m, n) blocks, favor more square-sized thread blocks for better data reuse.
+        3. If (m, n) blocks cannot occupy all the threads, we consider k-slicing.
         TODO(jgong5): allow tuning various blocking options
         """

+        @lru_cache(maxsize=100)
         def get_factors(number):
             factors = []
             for i in range(int(number**0.5), 0, -1):
@@ -194,19 +263,19 @@ def get_factors(number):
                     factors.append(i)
             return factors

-        def get_blocking(num_threads, factor, m_blocks, n_blocks, k_blocks):
-            thread_block_n = (n_blocks + factor - 1) // factor
-            cofactor = num_threads // factor
-            thread_block_m = (m_blocks + cofactor - 1) // cofactor
-            return GemmBlocking(thread_block_m, thread_block_n, k_blocks)
+        def get_blocking(m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks):
+            thread_block_k = math.ceil(k_blocks / k_factor)
+            thread_block_n = math.ceil(n_blocks / n_factor)
+            thread_block_m = math.ceil(m_blocks / m_factor)
+            return GemmBlocking(thread_block_m, thread_block_n, thread_block_k)

         assert (
             not self.is_dynamic_M
         ), "Unable to determine thread blocking for dynamic M."
         register_blocking = self.register_blocking
-        m_blocks = (self.m + register_blocking.block_m - 1) // register_blocking.block_m
-        n_blocks = (self.n + register_blocking.block_n - 1) // register_blocking.block_n
-        k_blocks = (self.k + register_blocking.block_k - 1) // register_blocking.block_k
+        m_blocks = math.ceil(self.m / register_blocking.block_m)
+        n_blocks = math.ceil(self.n / register_blocking.block_n)
+        k_blocks = math.ceil(self.k / register_blocking.block_k)
         factors = get_factors(self.num_threads)
         assert len(factors) > 0

@@ -219,26 +288,52 @@ def get_better_blocking(blocking, best_blocking):
                 block_n_size = blocking.block_n * register_blocking.block_n
                 best_block_m_size = best_blocking.block_m * register_blocking.block_m
                 best_block_n_size = best_blocking.block_n * register_blocking.block_n
-                if block_m_size + block_n_size < best_block_m_size + best_block_n_size:
+                if blocking.block_k > best_blocking.block_k:
+                    best_blocking = blocking
+                elif (
+                    blocking.block_k == best_blocking.block_k
+                    and block_m_size + block_n_size
+                    < best_block_m_size + best_block_n_size
+                ):
                     best_blocking = blocking
             return best_blocking

         best_blocking = None
-        # check if we can have a thread-blocking to occupy all threads
-        for factor in factors:
-            cofactor = self.num_threads // factor
-            if n_blocks >= factor and m_blocks >= cofactor:
+        # check if we can have a thread-blocking to occupy all threads without k-slicing
+        for n_factor in factors:
+            m_factor = self.num_threads // n_factor
+            if n_blocks >= n_factor and m_blocks >= m_factor:
                 blocking = get_blocking(
-                    self.num_threads, factor, m_blocks, n_blocks, k_blocks
+                    m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
                 )
                 best_blocking = get_better_blocking(blocking, best_blocking)

         if best_blocking is None:
-            for factor in factors:
-                cofactor = self.num_threads // factor
-                if n_blocks >= factor or m_blocks >= cofactor:
+            for k_factor in factors:
+                if k_blocks >= k_factor and (
+                    config.cpp.gemm_max_k_slices == 0
+                    or k_factor <= config.cpp.gemm_max_k_slices
+                ):
+                    n_factors = get_factors(self.num_threads // k_factor)
+                    for n_factor in n_factors:
+                        m_factor = (self.num_threads // k_factor) // n_factor
+                        if n_blocks >= n_factor and m_blocks >= m_factor:
+                            blocking = get_blocking(
+                                m_factor,
+                                n_factor,
+                                k_factor,
+                                m_blocks,
+                                n_blocks,
+                                k_blocks,
+                            )
+                            best_blocking = get_better_blocking(blocking, best_blocking)
+
+        if best_blocking is None:
+            for n_factor in factors:
+                m_factor = self.num_threads // n_factor
+                if n_blocks >= n_factor or m_blocks >= m_factor:
                     blocking = get_blocking(
-                        self.num_threads, factor, m_blocks, n_blocks, k_blocks
+                        m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
                     )
                     best_blocking = get_better_blocking(blocking, best_blocking)
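As a hand-run of the new search order (a standalone sketch mirroring, but not calling, the code above): take num_threads = 60 and assume a 2 x 2 x 625 grid of register blocks, roughly the 64x20000x64 case from the commit message under an assumed 32x32x32 register blocking. The (m, n)-only pass finds nothing, and the k-slicing pass settles on 15 k-slices times a 2 x 2 (m, n) partition.

# Standalone hand-run of the search order above (mirrors the logic, does
# not call it). Assumed problem: num_threads = 60 with a 2 x 2 x 625 grid
# of register blocks (~64x20000x64 GEMM under 32x32x32 register blocking).
num_threads, m_blocks, n_blocks, k_blocks = 60, 2, 2, 625
factors = [f for f in range(num_threads, 0, -1) if num_threads % f == 0]

# Pass 1 (no k-slicing): needs n_factor * m_factor == 60 with
# n_blocks >= n_factor and m_blocks >= m_factor -- impossible for 2 x 2.
assert not [n for n in factors if n_blocks >= n and m_blocks >= num_threads // n]

# Pass 2 (k-slicing, gemm_max_k_slices == 0 so no cap): split the threads
# into k_factor slices times an (m_factor, n_factor) partition.
feasible = []
for k_factor in factors:
    if k_blocks >= k_factor:
        rest = num_threads // k_factor
        for n_factor in [f for f in range(1, rest + 1) if rest % f == 0]:
            m_factor = rest // n_factor
            if n_blocks >= n_factor and m_blocks >= m_factor:
                feasible.append((k_factor, m_factor, n_factor))

# get_better_blocking prefers the largest block_k, i.e. the fewest
# k-slices that still occupy all threads: 15 slices x (2 x 2) partitions.
assert min(k for k, _, _ in feasible) == 15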

@@ -327,6 +422,17 @@ def get_occupancy():
                 f"Number of threads: {self.num_threads}, occupancy: {get_occupancy()}"  # noqa: G004
             )

+    def maybe_k_slicing(self):
+        if self.num_threads == 1:
+            return False
+        if self.is_dynamic_M:
+            # TODO(jgong5): perhaps use size hint to decide?
+            return True
+        register_blocking = self.register_blocking
+        k_blocks = math.ceil(self.k / register_blocking.block_k)
+        thread_blocking = self.thread_blocking()
+        return k_blocks > thread_blocking.block_k
+
     @staticmethod
     def add_choices(
         choices,
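Reading `maybe_k_slicing` above together with `thread_blocking()`: for static shapes it is true exactly when the chosen thread blocking splits K into more than one chunk per k-slicing group. A toy restatement with assumed numbers continuing the 60-thread example:

# Toy restatement of maybe_k_slicing() for the static-shape case, using
# assumed numbers that continue the 60-thread example above.
import math

k, block_k = 20000, 32                     # K dim, register block_k (assumed)
k_blocks = math.ceil(k / block_k)          # 625 k-blocks in total
thread_block_k = math.ceil(k_blocks / 15)  # 42 k-blocks per k-slice
assert k_blocks > thread_block_k           # -> maybe_k_slicing() is True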
@@ -645,9 +751,11 @@ def bias_add_inner(index):

         Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y
         use_local_acc = (
-            self.layout.dtype != torch.float or int8_gemm or self.padded_n != self.n
+            self.layout.dtype != torch.float
+            or int8_gemm
+            or self.padded_n != self.n
+            or self.maybe_k_slicing()
         )
-        acc_buf_name = "local_acc_buf"
         if epilogue_nodes:
             epilogues.extend(epilogue_nodes)
             assert Y.get_numel() == epilogues[-1].get_numel()
@@ -719,12 +827,13 @@ def bias_add_inner(index):
             reindexers=reindexers,
             Y_2d=Y_2d,
             use_local_acc=use_local_acc,
-            acc_buf_name=acc_buf_name,
+            maybe_k_slicing=self.maybe_k_slicing(),
             x_scale=x_scale,
             x_zp=x_zp,
             w_scale=w_scale,
             w_zp=w_zp,
             acc_buf_dtype=torch.int32 if int8_gemm else torch.float,
+            DTYPE_TO_CPP=DTYPE_TO_CPP,
         )
         with contextlib.ExitStack() as stack:
             for buf in fake_buffers:
