86 | 86 | );
87 | 87 | const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
88 | 88 | const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
89 | | - const int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
| 89 | + const int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
| 90 | + const int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
| 91 | + const int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
90 | 92 | {%- else %}
91 | 93 | constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
92 | 94 | constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;

98 | 100 | constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}};
99 | 101 | constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
100 | 102 | constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
101 | | - constexpr int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
| 103 | + constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
| 104 | + constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
| 105 | + constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
102 | 106 | {%- endif %}
103 | 107 |
104 | 108 | // make sure all partitions are assigned

109 | 113 |
110 | 114 | {%- if maybe_k_slicing %}
111 | 115 | std::unique_ptr<std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[]> local_buf_ptrs;
112 | | - if (num_k_slices > 1) {
113 | | - local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_k_slices]);
| 116 | + if (num_Kt_blocks > 1) {
| 117 | + local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_Kt_blocks]);
114 | 118 | }
115 | 119 | {%- endif %}
116 | 120 |
117 | 121 | {%- if num_threads > 1 %}
118 | 122 | #pragma omp parallel num_threads({{num_threads}})
119 | 123 | {
120 | 124 | const int tid = omp_get_thread_num();
121 | | - int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end;
122 | | - mm_get_thread_blocks(
123 | | - tid, Mr_blocks, Nr_blocks, Kr_blocks, Mt_blocks, Nt_blocks, Kt_blocks,
124 | | - m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end);
125 | | - {%- if maybe_k_slicing %}
126 | | - const int64_t k_group_id = tid / num_k_slices;
127 | | - const int64_t k_slice_id = tid % num_k_slices;
128 | | - {%- endif %}
| 125 | + const int64_t k_group_id = tid / num_Kt_blocks;
| 126 | + const int64_t k_slice_id = tid % num_Kt_blocks;
| 127 | + const int64_t n_group_id = k_group_id / num_Nt_blocks;
| 128 | + const int64_t n_slice_id = k_group_id % num_Nt_blocks;
| 129 | + const int64_t k_block_start = k_slice_id * Kt_blocks;
| 130 | + const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
| 131 | + const int64_t n_block_start = n_slice_id * Nt_blocks;
| 132 | + const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
| 133 | + const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
| 134 | + const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
| 135 | + const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
129 | 136 | {%- else %}
130 | 137 | {
131 | | - const int tid = 0;
132 | | - const int64_t m_block_start = 0;
133 | | - const int64_t m_block_end = Mr_blocks;
134 | | - const int64_t n_block_start = 0;
135 | | - const int64_t n_block_end = Nr_blocks;
136 | | - const int64_t k_block_start = 0;
137 | | - const int64_t k_block_end = Kr_blocks;
| 138 | + constexpr int tid = 0;
| 139 | + constexpr int64_t k_group_id = 0;
| 140 | + constexpr int64_t k_slice_id = 0;
| 141 | + constexpr int64_t n_group_id = 0;
| 142 | + constexpr int64_t n_slice_id = 0;
| 143 | + constexpr int64_t m_block_start = 0;
| 144 | + constexpr int64_t m_block_end = Mr_blocks;
| 145 | + constexpr int64_t n_block_start = 0;
| 146 | + constexpr int64_t n_block_end = Nr_blocks;
| 147 | + constexpr int64_t k_block_start = 0;
| 148 | + constexpr int64_t k_block_end = Kr_blocks;
| 149 | + {%- if is_dynamic_M %}
| 150 | + const int64_t num_Mc_blocks_per_thread = num_Mc_blocks;
| 151 | + {%- else %}
| 152 | + constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks;
| 153 | + {%- endif %}
138 | 154 | {%- endif %}
139 | 155 | {{ micro_gemm.codegen_init(kernel) }}
140 | 156 | {%- if use_local_acc %}
141 | 157 | {%- set acc_buf_name = "local_acc_buf" %}
142 | 158 | {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
143 | 159 | {%- endif %}
144 | | - for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
| 160 | + for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
| 161 | + const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
| 162 | + const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
145 | 163 | const int64_t m_start = mc * Mr;
146 | 164 | const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
147 | 165 | const int64_t m_size = m_end - m_start;

173 | 191 | }
174 | 192 | }
175 | 193 | {%- if maybe_k_slicing %}
176 | | - if (num_k_slices > 1) {
| 194 | + if (num_Kt_blocks > 1) {
177 | 195 | const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
178 | | - local_buf_ptrs[mxn_cache_block_id * num_k_slices + k_slice_id].reset({{ kernel.release_buffer(acc_buf_name) }});
| 196 | + local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + k_slice_id].reset(
| 197 | + {{ kernel.release_buffer(acc_buf_name) }});
179 | 198 | } else
180 | 199 | {%- endif %}
181 | 200 | {

189 | 208 | }
190 | 209 | }
191 | 210 | {%- if maybe_k_slicing %}
192 | | - if (num_k_slices > 1) {
| 211 | + if (num_Kt_blocks > 1) {
193 | 212 | #pragma omp barrier
194 | 213 | for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
195 | 214 | // We slice M-dim and each thread in the k-slicing group works on a slice
196 | 215 | const int64_t m_start_unsliced = mc * Mr;
197 | 216 | const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
198 | 217 | const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced;
199 | | - const int64_t m_slice_size = (m_size_unsliced + num_k_slices - 1) / num_k_slices;
| 218 | + const int64_t m_slice_size = (m_size_unsliced + num_Kt_blocks - 1) / num_Kt_blocks;
200 | 219 | const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced);
201 | 220 | const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced);
202 | 221 | const int64_t m_size = m_end - m_start;

206 | 225 | const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
207 | 226 | const int64_t n_size = n_end - n_start;
208 | 227 | const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
209 | | - auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_k_slices].get();
210 | | - for (int64_t other_slice = 1; other_slice < num_k_slices; other_slice++) {
211 | | - auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_k_slices + other_slice].get();
| 228 | + auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks].get();
| 229 | + for (int64_t other_slice = 1; other_slice < num_Kt_blocks; other_slice++) {
| 230 | + auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + other_slice].get();
212 | 231 | for (int64_t m = m_offset; m < m_offset + m_size; m++) {
213 | 232 | #pragma omp simd
214 | 233 | for (int64_t n = 0; n < n_size; n++) {
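Reader's note (not part of the diff): the added lines 125-135 replace the mm_get_thread_blocks() call with inline index math that decodes each OpenMP thread id into a K slice, an N slice, and an M group, with the K slice varying fastest. The standalone C++ paraphrase below mirrors those expressions so they can be read and run outside the template; the names ThreadBlocks and assign_blocks are illustrative only and do not appear in the generated code.

#include <algorithm>
#include <cstdint>
#include <iostream>

// Sketch of the per-thread block assignment emitted by the template above.
// tid is laid out as (M group) x (N slice) x (K slice), K slice fastest, so
// threads sharing a k_group_id cover the same M/N tile of the output.
struct ThreadBlocks {
  int64_t m_start, m_end, n_start, n_end, k_start, k_end;
};

ThreadBlocks assign_blocks(
    int64_t tid,
    int64_t Mr_blocks, int64_t Nr_blocks, int64_t Kr_blocks,
    int64_t Mt_blocks, int64_t Nt_blocks, int64_t Kt_blocks) {
  const int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
  const int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;

  const int64_t k_group_id = tid / num_Kt_blocks;   // which (M, N) tile
  const int64_t k_slice_id = tid % num_Kt_blocks;   // which K slice of it
  const int64_t n_group_id = k_group_id / num_Nt_blocks;
  const int64_t n_slice_id = k_group_id % num_Nt_blocks;

  ThreadBlocks b;
  b.k_start = k_slice_id * Kt_blocks;
  b.k_end   = std::min(b.k_start + Kt_blocks, Kr_blocks);
  b.n_start = n_slice_id * Nt_blocks;
  b.n_end   = std::min(b.n_start + Nt_blocks, Nr_blocks);
  b.m_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
  b.m_end   = std::min(b.m_start + Mt_blocks, Mr_blocks);
  return b;
}

int main() {
  // Example: 8 threads split as 2 M groups x 2 N slices x 2 K slices.
  for (int64_t tid = 0; tid < 8; ++tid) {
    ThreadBlocks b = assign_blocks(tid, /*Mr_blocks=*/16, /*Nr_blocks=*/8,
                                   /*Kr_blocks=*/4, /*Mt_blocks=*/8,
                                   /*Nt_blocks=*/4, /*Kt_blocks=*/2);
    std::cout << "tid " << tid
              << ": m [" << b.m_start << ", " << b.m_end << ")"
              << " n [" << b.n_start << ", " << b.n_end << ")"
              << " k [" << b.k_start << ", " << b.k_end << ")\n";
  }
  return 0;
}

Two details of the diff follow from this layout: the rotation (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread in the new M-block loop makes threads that share the same M range but hold different N slices start on different Mc blocks, and when K slicing is active the partial results written into the per-thread local buffers are summed across the num_Kt_blocks slices after the barrier, each thread reducing its own M slice of the tile.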