import contextlib
import logging
import math
+ from functools import lru_cache
from typing import Any, Callable, cast, List, Optional, Set, Union
from unittest.mock import patch

import torch
import torch.utils

from ..._dynamo.utils import counters
- from .. import ir, lowering as L
+ from .. import config, ir, lowering as L
from ..kernel.mm_common import mm_args
from ..select_algorithm import DataProcessorTemplateWrapper
from ..utils import cache_on_self, has_free_symbols, parallel_num_threads
from ..virtualized import ops, V
from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType
from .cpp_template import CppTemplate
from .cpp_template_kernel import CppTemplateKernel
- from .cpp_utils import GemmBlocking, get_gemm_template_output_and_compute_dtype
+ from .cpp_utils import (
+     DTYPE_TO_CPP,
+     GemmBlocking,
+     get_gemm_template_output_and_compute_dtype,
+ )

log = logging.getLogger(__name__)

{%- endif %}
const int64_t Mc_blocks = Mt_blocks;
const int64_t Kc_blocks = Kt_blocks;
+ const int64_t num_Mc_blocks = (M0_blocks + Mc_blocks - 1) / Mc_blocks;
+ const int64_t num_Nc_blocks = N0_blocks;
+ const int64_t num_k_slices = (K0_blocks + Kt_blocks - 1) / Kt_blocks;
{%- else %}
constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
constexpr int64_t M0_blocks = (M + M0 - 1) / M0;
constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}};
constexpr int64_t Mc_blocks = {{template.cache_blocking().block_m}};
constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}};
+ constexpr int64_t num_Mc_blocks = (M0_blocks + Mc_blocks - 1) / Mc_blocks;
+ constexpr int64_t num_Nc_blocks = N0_blocks;
+ constexpr int64_t num_k_slices = (K0_blocks + Kt_blocks - 1) / Kt_blocks;
{%- endif %}

- // TODO(jgong5): support k-slicing
- {{kernel.assert_function}}(Kt_blocks == K0_blocks, "Do not support k slicing yet.");
// make sure all partitions are assigned
{{kernel.assert_function}}(
Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= M0_blocks * N0_blocks * K0_blocks,
"Not all partitions are assigned."
);

+ {%- if maybe_k_slicing %}
+ std::unique_ptr<std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[]> local_buf_ptrs;
+ if (num_k_slices > 1) {
+ local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_k_slices]);
+ }
+ {%- endif %}
+
{%- if num_threads > 1 %}
#pragma omp parallel num_threads({{num_threads}})
{
- int tid = omp_get_thread_num();
+ const int tid = omp_get_thread_num();
int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end;
mm_get_thread_blocks(
tid, M0_blocks, N0_blocks, K0_blocks, Mt_blocks, Nt_blocks, Kt_blocks,
m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end);
+ {%- if maybe_k_slicing %}
+ const int64_t k_group_id = tid / num_k_slices;
+ const int64_t k_slice_id = tid % num_k_slices;
+ {%- endif %}
{%- else %}
{
- int64_t m_block_start = 0;
- int64_t m_block_end = M0_blocks;
- int64_t n_block_start = 0;
- int64_t n_block_end = N0_blocks;
- int64_t k_block_start = 0;
- int64_t k_block_end = K0_blocks;
+ const int tid = 0;
+ const int64_t m_block_start = 0;
+ const int64_t m_block_end = M0_blocks;
+ const int64_t n_block_start = 0;
+ const int64_t n_block_end = N0_blocks;
+ const int64_t k_block_start = 0;
+ const int64_t k_block_end = K0_blocks;
{%- endif %}
{{ micro_gemm.codegen_init(kernel) }}
for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
const int64_t m_start = mc * M0;
const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * M0, M);
const int64_t m_size = m_end - m_start;
{%- if use_local_acc %}
+ {%- set acc_buf_name = "local_acc_buf" %}
{{ kernel.define_buffer(acc_buf_name, ["m_end - m_start", "N0"], acc_buf_dtype) }}
{%- endif %}
for (int64_t nc = n_block_start; nc < n_block_end; ++nc) {
const int64_t n_start = nc * N0;
const int64_t n_end = std::min((nc + 1) * N0, N);
+ const int64_t n_size = n_end - n_start;
{%- if use_local_acc %}
{%- set acc = kernel.local_buffers[acc_buf_name] %}
+ {{ kernel.reinit_buffer_if_null(acc_buf_name) }}
{%- else %}
{%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
{%- endif %}
for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
int64_t k_start = kc * K0;
- int64_t k_end = std::min((kc + Kc_blocks) * K0, K);
+ int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * K0, K);
{%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %}
{%- set tile_W_3d = kernel.slice_nd(W, [("nc", "nc + 1"), ("k_start", "k_end"), ()]) %}
{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True)|indent(24, false) }}
}
}
+ {%- if maybe_k_slicing %}
+ if (num_k_slices > 1) {
+ const int64_t mxn_cache_block_id = mc * num_Nc_blocks + nc;
+ local_buf_ptrs[mxn_cache_block_id * num_k_slices + k_slice_id].reset({{ kernel.release_buffer(acc_buf_name) }});
+ } else
+ {%- endif %}
+ {
{%- if N == PADDED_N %}
{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
{%- set tile_acc = acc %}
{%- else %}
{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %}
{%- set tile_acc = kernel.slice_nd(acc, [(), ("0", "n_end - n_start")]) %}
{%- endif %}
- {{ kernel.store_output(
- tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
- )|indent(16, false)
- }}
+ {{ kernel.store_output(
+ tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
+ )|indent(20, false)
+ }}
+ }
+ }
+ }
+ {%- if maybe_k_slicing %}
+ if (num_k_slices > 1) {
+ #pragma omp barrier
+ for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
+ // We slice M-dim and each thread in the k-slicing group works on a slice
+ const int64_t m_start_unsliced = mc * M0;
+ const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * M0, M);
+ const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced;
+ const int64_t m_slice_size = (m_size_unsliced + num_k_slices - 1) / num_k_slices;
+ const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced);
+ const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced);
+ const int64_t m_size = m_end - m_start;
+ const int64_t m_offset = m_start - m_start_unsliced;
+ for (int64_t nc = n_block_start; nc < n_block_end; ++nc) {
+ const int64_t n_start = nc * N0;
+ const int64_t n_end = std::min((nc + 1) * N0, N);
+ const int64_t n_size = n_end - n_start;
+ const int64_t mxn_cache_block_id = mc * num_Nc_blocks + nc;
+ auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_k_slices].get();
+ for (int64_t other_slice = 1; other_slice < num_k_slices; other_slice++) {
+ auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_k_slices + other_slice].get();
+ for (int64_t m = m_offset; m < m_offset + m_size; m++) {
+ #pragma omp simd
+ for (int64_t n = 0; n < n_size; n++) {
+ {{acc_buf_name}}[m*N0 + n] += other_acc[m*N0 + n];
+ }
+ }
+ }
+ {%- set tile_acc_m_slice = kernel.slice_nd(tile_acc, [("m_offset", "m_offset + m_end - m_start"), ()]) %}
+ {{ kernel.store_output(
+ tile_Y, tile_acc_m_slice, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
+ )|indent(20, false)
+ }}
+ }
}
}
+ {%- endif %}
{{ micro_gemm.codegen_finalize(kernel) }}
}
}
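The k-slicing path added above works in two phases: each thread of a k-slicing group first accumulates its own K range into a private buffer (handed off through `local_buf_ptrs`), and after the `#pragma omp barrier` the partial buffers are summed, with the reduction itself split over M so that every slice reduces a disjoint chunk of rows before `store_output`. A minimal NumPy sketch of that dataflow, with made-up shapes and names (it is not the generated C++):

```python
import numpy as np

# Shapes and slice count are made up for illustration.
M, N, K = 8, 6, 32
num_k_slices = 4
rng = np.random.default_rng(0)
X = rng.standard_normal((M, K)).astype(np.float32)
W = rng.standard_normal((K, N)).astype(np.float32)

# Phase 1: each k-slice multiplies only its own K range into a private
# accumulator (the role of local_acc_buf / local_buf_ptrs above).
k_bounds = np.linspace(0, K, num_k_slices + 1, dtype=int)
partials = [
    X[:, k_bounds[s]:k_bounds[s + 1]] @ W[k_bounds[s]:k_bounds[s + 1], :]
    for s in range(num_k_slices)
]

# Phase 2 (after the "#pragma omp barrier"): the reduction is split over M,
# so each slice sums all partials for a disjoint chunk of rows.
Y = np.empty((M, N), dtype=np.float32)
m_slice_size = (M + num_k_slices - 1) // num_k_slices  # ceil-div, as in the template
for k_slice_id in range(num_k_slices):
    m_start = min(k_slice_id * m_slice_size, M)
    m_end = min((k_slice_id + 1) * m_slice_size, M)
    acc = partials[0][m_start:m_end].copy()
    for other_slice in range(1, num_k_slices):
        acc += partials[other_slice][m_start:m_end]
    Y[m_start:m_end] = acc

assert np.allclose(Y, X @ W, atol=1e-4)
```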
@@ -180,12 +249,12 @@ def thread_blocking(self) -> GemmBlocking:
NOTE [Thread blocking in Cpp GEMM]
We use simple heuristics to decide the thread blocking:
1. Make sure all threads are occupied as much as possible.
- 2. Favor more square-sized thread blocks for better data reuse.
- TODO(jgong5): we only do blocking on on M and N now, add blocking on K
- after supporting k-slicing.
+ 2. For (m, n) blocks, favor more square-sized thread blocks for better data reuse.
+ 3. If (m, n) blocks cannot occupy all the threads, we consider k-slicing.
TODO(jgong5): allow tuning various blocking options
"""

+ @lru_cache(maxsize=100)
def get_factors(number):
factors = []
for i in range(int(number**0.5), 0, -1):
@@ -194,19 +263,19 @@ def get_factors(number):
factors.append(i)
return factors

- def get_blocking(num_threads, factor, m_blocks, n_blocks, k_blocks):
- thread_block_n = (n_blocks + factor - 1) // factor
- cofactor = num_threads // factor
- thread_block_m = (m_blocks + cofactor - 1) // cofactor
- return GemmBlocking(thread_block_m, thread_block_n, k_blocks)
+ def get_blocking(m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks):
+ thread_block_k = math.ceil(k_blocks / k_factor)
+ thread_block_n = math.ceil(n_blocks / n_factor)
+ thread_block_m = math.ceil(m_blocks / m_factor)
+ return GemmBlocking(thread_block_m, thread_block_n, thread_block_k)

assert (
not self.is_dynamic_M
), "Unable to determine thread blocking for dynamic M."
register_blocking = self.register_blocking
- m_blocks = (self.m + register_blocking.block_m - 1) // register_blocking.block_m
- n_blocks = (self.n + register_blocking.block_n - 1) // register_blocking.block_n
- k_blocks = (self.k + register_blocking.block_k - 1) // register_blocking.block_k
+ m_blocks = math.ceil(self.m / register_blocking.block_m)
+ n_blocks = math.ceil(self.n / register_blocking.block_n)
+ k_blocks = math.ceil(self.k / register_blocking.block_k)
factors = get_factors(self.num_threads)
assert len(factors) > 0
@@ -219,26 +288,52 @@ def get_better_blocking(blocking, best_blocking):
block_n_size = blocking.block_n * register_blocking.block_n
best_block_m_size = best_blocking.block_m * register_blocking.block_m
best_block_n_size = best_blocking.block_n * register_blocking.block_n
- if block_m_size + block_n_size < best_block_m_size + best_block_n_size:
+ if blocking.block_k > best_blocking.block_k:
+ best_blocking = blocking
+ elif (
+ blocking.block_k == best_blocking.block_k
+ and block_m_size + block_n_size
+ < best_block_m_size + best_block_n_size
+ ):
best_blocking = blocking
return best_blocking

best_blocking = None
- # check if we can have a thread-blocking to occupy all threads
- for factor in factors:
- cofactor = self.num_threads // factor
- if n_blocks >= factor and m_blocks >= cofactor:
+ # check if we can have a thread-blocking to occupy all threads without k-slicing
+ for n_factor in factors:
+ m_factor = self.num_threads // n_factor
+ if n_blocks >= n_factor and m_blocks >= m_factor:
blocking = get_blocking(
- self.num_threads, factor, m_blocks, n_blocks, k_blocks
+ m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
)
best_blocking = get_better_blocking(blocking, best_blocking)

if best_blocking is None:
- for factor in factors:
- cofactor = self.num_threads // factor
- if n_blocks >= factor or m_blocks >= cofactor:
+ for k_factor in factors:
+ if k_blocks >= k_factor and (
+ config.cpp.gemm_max_k_slices == 0
+ or k_factor <= config.cpp.gemm_max_k_slices
+ ):
+ n_factors = get_factors(self.num_threads // k_factor)
+ for n_factor in n_factors:
+ m_factor = (self.num_threads // k_factor) // n_factor
+ if n_blocks >= n_factor and m_blocks >= m_factor:
+ blocking = get_blocking(
+ m_factor,
+ n_factor,
+ k_factor,
+ m_blocks,
+ n_blocks,
+ k_blocks,
+ )
+ best_blocking = get_better_blocking(blocking, best_blocking)
+
+ if best_blocking is None:
+ for n_factor in factors:
+ m_factor = self.num_threads // n_factor
+ if n_blocks >= n_factor or m_blocks >= m_factor:
blocking = get_blocking(
- self.num_threads, factor, m_blocks, n_blocks, k_blocks
+ m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
)
best_blocking = get_better_blocking(blocking, best_blocking)
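To see when the new k-slicing branch of this search fires, consider an illustrative shape; the numbers and helper code below are made up for the example and are not the method itself. With 16 threads but only 2×4 (m, n) register-block tiles, no (m, n) factor pair keeps every thread busy, so a factor is spent on K:

```python
import math

num_threads = 16
m_blocks, n_blocks, k_blocks = 2, 4, 64  # register-block counts, chosen for illustration

# Step 1: try to occupy all threads with (m, n) factors alone (k_factor == 1).
mn_pairs = [(num_threads // n, n) for n in range(1, num_threads + 1) if num_threads % n == 0]
print([(m, n) for m, n in mn_pairs if m_blocks >= m and n_blocks >= n])  # [] -> not possible

# Step 2: spend a factor on K. k_factor=2 leaves 8 threads for (m, n), and
# (m_factor, n_factor) = (2, 4) fits m_blocks=2, n_blocks=4 exactly.
m_factor, n_factor, k_factor = 2, 4, 2
blocking = (
    math.ceil(m_blocks / m_factor),
    math.ceil(n_blocks / n_factor),
    math.ceil(k_blocks / k_factor),
)
print(blocking)  # (1, 1, 32): each thread owns one (m, n) tile and half of the K blocks
```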
@@ -327,6 +422,17 @@ def get_occupancy():
f"Number of threads: {self.num_threads}, occupancy: {get_occupancy()}"  # noqa: G004
)

+ def maybe_k_slicing(self):
+ if self.num_threads == 1:
+ return False
+ if self.is_dynamic_M:
+ # TODO(jgong5): perhaps use size hint to decide?
+ return True
+ register_blocking = self.register_blocking
+ k_blocks = math.ceil(self.k / register_blocking.block_k)
+ thread_blocking = self.thread_blocking()
+ return k_blocks > thread_blocking.block_k
+
@staticmethod
def add_choices(
choices,
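`maybe_k_slicing` boils down to one check: is the chosen thread blocking's `block_k` smaller than the total number of K register blocks, i.e. did the search actually split K? A tiny check with assumed numbers, continuing the illustrative shape from the sketch above (not taken from a real layout):

```python
import math

k = 4096
register_block_k = 64  # assumed register_blocking.block_k
thread_block_k = 32    # assumed thread_blocking().block_k after the search above

k_blocks = math.ceil(k / register_block_k)  # 64 K register blocks in total
print(k_blocks > thread_block_k)            # True -> the template is rendered with k-slicing enabled
```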
@@ -645,9 +751,11 @@ def bias_add_inner(index):
Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y
use_local_acc = (
- self.layout.dtype != torch.float or int8_gemm or self.padded_n != self.n
+ self.layout.dtype != torch.float
+ or int8_gemm
+ or self.padded_n != self.n
+ or self.maybe_k_slicing()
)
- acc_buf_name = "local_acc_buf"
if epilogue_nodes:
epilogues.extend(epilogue_nodes)
assert Y.get_numel() == epilogues[-1].get_numel()
@@ -719,12 +827,13 @@ def bias_add_inner(index):
reindexers=reindexers,
Y_2d=Y_2d,
use_local_acc=use_local_acc,
- acc_buf_name=acc_buf_name,
+ maybe_k_slicing=self.maybe_k_slicing(),
x_scale=x_scale,
x_zp=x_zp,
w_scale=w_scale,
w_zp=w_zp,
acc_buf_dtype=torch.int32 if int8_gemm else torch.float,
+ DTYPE_TO_CPP=DTYPE_TO_CPP,
)
with contextlib.ExitStack() as stack:
for buf in fake_buffers: