[inductor][cpp][gemm] fix perf regression xcit_large_24_p8_224 (#1346… · pytorch/pytorch@c0436c5 · GitHub

Commit c0436c5

Jiong Gong authored and pytorchmergebot committed
[inductor][cpp][gemm] fix perf regression xcit_large_24_p8_224 (#134686) (#135438)
Fix #134686.

PR #132729 makes the GEMM template faster for one of the GEMMs in xcit_large_24_p8_224:

SingleProcess AUTOTUNE benchmarking takes 1.7088 seconds and 1.9207 seconds precompiling
AUTOTUNE linear_unary(12544x3072, 768x3072, 768)
  cpp_packed_gemm_2 2.9371 ms 100.0%
  _linear_pointwise 3.1584 ms 93.0%

But it is slower than ATen in the end-to-end run due to different cache behavior. The access to the input data (12544x3072) is LLC-latency bound, and bottlenecks are seen due to memory synchronization (data transfers and coherence updates across processors). This PR tries to mitigate the problem by cooperatively loading different chunks of the input data from the different processors that share it.

Pull Request resolved: #135438
Approved by: https://github.com/leslie-fang-intel
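A rough illustration of the mitigation, separate from the generated kernel (the helper name stagger_mc_blocks is made up for this sketch; the block-size variables mirror the template): threads that share the same rows of the input start their Mc-block loop at different offsets, so at any instant they stream different chunks of the shared data rather than all contending for the same one.

    // Sketch only: each thread in a group that shares the same M range rotates its
    // starting Mc block by its N-slice id, mirroring the
    // `(mc_block_id + n_slice_id) % num_Mc_blocks_per_thread` line added in the template.
    #include <algorithm>
    #include <cstdint>

    void stagger_mc_blocks(int64_t m_block_start, int64_t m_block_end,
                           int64_t Mc_blocks, int64_t n_slice_id) {
      const int64_t num_Mc_blocks_per_thread =
          (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
      for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
        // Rotate so threads with different n_slice_id touch different row blocks at a time.
        const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
        const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
        const int64_t mc_end = std::min(mc + Mc_blocks, m_block_end);
        // ... run the micro-kernel over row blocks [mc, mc_end) ...
        (void)mc_end;
      }
    }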
1 parent 60e8dc4 commit c0436c5

File tree

torch/_inductor/codegen/cpp_gemm_template.py
torch/_inductor/codegen/cpp_prefix.h

2 files changed: +46 −53 lines changed

torch/_inductor/codegen/cpp_gemm_template.py

Lines changed: 46 additions & 27 deletions
@@ -86,7 +86,9 @@
 );
 const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
 const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
-const int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
+const int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
+const int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
+const int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
 {%- else %}
 constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
 constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
@@ -98,7 +100,9 @@
 constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}};
 constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
 constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
-constexpr int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
+constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
+constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
+constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
 {%- endif %}

 // make sure all partitions are assigned
@@ -109,39 +113,53 @@

 {%- if maybe_k_slicing %}
 std::unique_ptr<std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[]> local_buf_ptrs;
-if (num_k_slices > 1) {
-local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_k_slices]);
+if (num_Kt_blocks > 1) {
+local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_Kt_blocks]);
 }
 {%- endif %}

 {%- if num_threads > 1 %}
 #pragma omp parallel num_threads({{num_threads}})
 {
 const int tid = omp_get_thread_num();
-int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end;
-mm_get_thread_blocks(
-tid, Mr_blocks, Nr_blocks, Kr_blocks, Mt_blocks, Nt_blocks, Kt_blocks,
-m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end);
-{%- if maybe_k_slicing %}
-const int64_t k_group_id = tid / num_k_slices;
-const int64_t k_slice_id = tid % num_k_slices;
-{%- endif %}
+const int64_t k_group_id = tid / num_Kt_blocks;
+const int64_t k_slice_id = tid % num_Kt_blocks;
+const int64_t n_group_id = k_group_id / num_Nt_blocks;
+const int64_t n_slice_id = k_group_id % num_Nt_blocks;
+const int64_t k_block_start = k_slice_id * Kt_blocks;
+const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
+const int64_t n_block_start = n_slice_id * Nt_blocks;
+const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
+const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
+const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
+const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
 {%- else %}
 {
-const int tid = 0;
-const int64_t m_block_start = 0;
-const int64_t m_block_end = Mr_blocks;
-const int64_t n_block_start = 0;
-const int64_t n_block_end = Nr_blocks;
-const int64_t k_block_start = 0;
-const int64_t k_block_end = Kr_blocks;
+constexpr int tid = 0;
+constexpr int64_t k_group_id = 0;
+constexpr int64_t k_slice_id = 0;
+constexpr int64_t n_group_id = 0;
+constexpr int64_t n_slice_id = 0;
+constexpr int64_t m_block_start = 0;
+constexpr int64_t m_block_end = Mr_blocks;
+constexpr int64_t n_block_start = 0;
+constexpr int64_t n_block_end = Nr_blocks;
+constexpr int64_t k_block_start = 0;
+constexpr int64_t k_block_end = Kr_blocks;
+{%- if is_dynamic_M %}
+const int64_t num_Mc_blocks_per_thread = num_Mc_blocks;
+{%- else %}
+constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks;
+{%- endif %}
 {%- endif %}
 {{ micro_gemm.codegen_init(kernel) }}
 {%- if use_local_acc %}
 {%- set acc_buf_name = "local_acc_buf" %}
 {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
 {%- endif %}
-for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
+for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
+const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
+const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
 const int64_t m_start = mc * Mr;
 const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
 const int64_t m_size = m_end - m_start;
@@ -173,9 +191,10 @@
 }
 }
 {%- if maybe_k_slicing %}
-if (num_k_slices > 1) {
+if (num_Kt_blocks > 1) {
 const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
-local_buf_ptrs[mxn_cache_block_id * num_k_slices + k_slice_id].reset({{ kernel.release_buffer(acc_buf_name) }});
+local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + k_slice_id].reset(
+{{ kernel.release_buffer(acc_buf_name) }});
 } else
 {%- endif %}
 {
@@ -189,14 +208,14 @@
 }
 }
 {%- if maybe_k_slicing %}
-if (num_k_slices > 1) {
+if (num_Kt_blocks > 1) {
 #pragma omp barrier
 for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
 // We slice M-dim and each thread in the k-slicing group works on a slice
 const int64_t m_start_unsliced = mc * Mr;
 const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
 const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced;
-const int64_t m_slice_size = (m_size_unsliced + num_k_slices - 1) / num_k_slices;
+const int64_t m_slice_size = (m_size_unsliced + num_Kt_blocks - 1) / num_Kt_blocks;
 const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced);
 const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced);
 const int64_t m_size = m_end - m_start;
@@ -206,9 +225,9 @@
 const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
 const int64_t n_size = n_end - n_start;
 const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
-auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_k_slices].get();
-for (int64_t other_slice = 1; other_slice < num_k_slices; other_slice++) {
-auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_k_slices + other_slice].get();
+auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks].get();
+for (int64_t other_slice = 1; other_slice < num_Kt_blocks; other_slice++) {
+auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + other_slice].get();
 for (int64_t m = m_offset; m < m_offset + m_size; m++) {
 #pragma omp simd
 for (int64_t n = 0; n < n_size; n++) {
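For reference, a standalone rendering of the new per-thread partitioning with the Jinja stripped out (the function name partition and the struct ThreadBlocks are made up for this sketch; the other names mirror the template variables). It shows the same tid decomposition, split first along K, then N, then M, that replaces the mm_get_thread_blocks helper deleted from cpp_prefix.h below.

    // Sketch only: map a thread id to its (M, N, K) block ranges as the new
    // inline code above does.
    #include <algorithm>
    #include <cstdint>

    struct ThreadBlocks {
      int64_t m_start, m_end, n_start, n_end, k_start, k_end;
    };

    ThreadBlocks partition(int tid,
                           int64_t Mr_blocks, int64_t Nr_blocks, int64_t Kr_blocks,
                           int64_t Mt_blocks, int64_t Nt_blocks, int64_t Kt_blocks) {
      const int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
      const int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
      const int64_t k_group_id = tid / num_Kt_blocks;  // threads sharing one K range
      const int64_t k_slice_id = tid % num_Kt_blocks;
      const int64_t n_group_id = k_group_id / num_Nt_blocks;
      const int64_t n_slice_id = k_group_id % num_Nt_blocks;
      ThreadBlocks b;
      b.k_start = k_slice_id * Kt_blocks;
      b.k_end   = std::min(b.k_start + Kt_blocks, Kr_blocks);
      b.n_start = n_slice_id * Nt_blocks;
      b.n_end   = std::min(b.n_start + Nt_blocks, Nr_blocks);
      // Threads with the same n_group_id share the same input rows; their Mc-block
      // iteration is then staggered by n_slice_id for the cooperative loading.
      b.m_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
      b.m_end   = std::min(b.m_start + Mt_blocks, Mr_blocks);
      return b;
    }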

torch/_inductor/codegen/cpp_prefix.h

Lines changed: 0 additions & 26 deletions
@@ -884,32 +884,6 @@ void mm_get_cache_blocking(
 }
 }

-inline void mm_get_thread_blocks(
-int thread_id,
-int64_t M_blocks,
-int64_t N_blocks,
-int64_t K_blocks,
-int64_t Mt_blocks,
-int64_t Nt_blocks,
-int64_t Kt_blocks,
-int64_t& m_block_start,
-int64_t& m_block_end,
-int64_t& n_block_start,
-int64_t& n_block_end,
-int64_t& k_block_start,
-int64_t& k_block_end) {
-int64_t num_Kt = (K_blocks + Kt_blocks - 1) / Kt_blocks;
-k_block_start = (thread_id % num_Kt) * Kt_blocks;
-k_block_end = std::min(k_block_start + Kt_blocks, K_blocks);
-thread_id /= num_Kt;
-int64_t num_Nt = (N_blocks + Nt_blocks - 1) / Nt_blocks;
-n_block_start = (thread_id % num_Nt) * Nt_blocks;
-n_block_end = std::min(n_block_start + Nt_blocks, N_blocks);
-thread_id /= num_Nt;
-m_block_start = std::min(thread_id * Mt_blocks, M_blocks);
-m_block_end = std::min(m_block_start + Mt_blocks, M_blocks);
-}
-
 struct amx_tilecfg {
 uint8_t palette_id;
 uint8_t start_row;

0 commit comments