@@ -297,6 +297,51 @@ __global__ void max_pool_backward_nhwc(const scalar_t* top_diff,
297
297
int pwend = p_end (iw, pad_w, pooled_width, stride_w);
298
298
int index_shift = ih * width + iw;
299
299
if ((phstart + 1 != phend) || (pwstart + 1 != pwend)) {
300
+
301
+ #if defined (USE_ROCM)
302
+ #define _MAXh 2
303
+ #define _MAXw 2
304
+ if (phend-phstart<=_MAXh && pwend-pwstart<=_MAXw) {
305
+ int msk[_MAXh][_MAXw];
306
+ scalar_t tpd[_MAXh][_MAXw];
307
+ int cached_index = threadIdx .x ;
308
+ #pragma unroll
309
+ for (int c = channel_offset; c < channels; c += blockDim .x *kernel_stride_C) {
310
+ #pragma unroll
311
+ for (int oh = 0 ; oh < _MAXh; ++oh) {
<
8000
div aria-hidden="true" style="left:-2px" class="position-absolute top-0 d-flex user-select-none DiffLineTableCellParts-module__in-progress-comment-indicator--hx3m3">
312
+ #pragma unroll
313
+ for (int ow = 0 ; ow < _MAXw; ++ow) {
314
+ int oh_ = oh+phstart;
315
+ int ow_ = ow+pwstart;
316
+ const int64_t * ptr_top_mask = top_mask + oh_ * out_stride_h + ow_ * out_stride_w;
317
+ if (oh_ >= phend || ow_ >= pwend) {
318
+ msk[oh][ow] = ~index_shift;
319
+ } else {
320
+ msk[oh][ow] = ptr_top_mask[c*out_stride_c];
321
+ tpd[oh][ow] = top_diff[oh_ * out_stride_h + ow_ * out_stride_w + c*out_stride_c];
322
+ }
323
+ }
324
+ }
325
+
326
+ accscalar_t acm = 0 ;
327
+ #pragma unroll
328
+ for (int oh = 0 ; oh < _MAXh; ++oh) {
329
+ #pragma unroll
330
+ for (int ow = 0 ; ow < _MAXw; ++ow) {
331
+ if (msk[oh][ow] == index_shift) {
332
+ acm += static_cast <accscalar_t >(tpd[oh][ow]);
333
+ }
334
+ }
335
+ }
336
+ out_cached[cached_index] += acm;
337
+ cached_index += blockDim .x ;
338
+ }
339
+ }
340
+ else
341
+ #undef _MAXh
342
+ #undef _MAXw
343
+ #endif
344
+
300
345
for (int oh = phstart; oh < phend; ++oh) {
301
346
for (int ow = pwstart; ow < pwend; ++ow) {
302
347
int cached_index = threadIdx .x ;
0 commit comments