8000 [ROCm] Maxpool backward NHWC Perf Improvement targeting Resnet scena… · pytorch/pytorch@4015166 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4015166

Browse files
amd-hhashemipytorchmergebot
authored andcommitted
[ROCm] Maxpool backward NHWC Perf Improvement targeting Resnet scenarios (#152267)
Fixes #ISSUE_NUMBER Pull Request resolved: #152267 Approved by: https://github.com/jeffdaily
1 parent 4c5cf18 commit 4015166

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

aten/src/ATen/native/cuda/DilatedMaxPool2d.cu

+45
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,51 @@ __global__ void max_pool_backward_nhwc(const scalar_t* top_diff,
297297
int pwend = p_end(iw, pad_w, pooled_width, stride_w);
298298
int index_shift = ih * width + iw;
299299
if ((phstart + 1 != phend) || (pwstart + 1 != pwend)) {
300+
301+
#if defined (USE_ROCM)
302+
#define _MAXh 2
303+
#define _MAXw 2
304+
if (phend-phstart<=_MAXh && pwend-pwstart<=_MAXw) {
305+
int msk[_MAXh][_MAXw];
306+
scalar_t tpd[_MAXh][_MAXw];
307+
int cached_index = threadIdx.x;
308+
#pragma unroll
309+
for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) {
310+
#pragma unroll
311+
for(int oh = 0; oh < _MAXh; ++oh) {
< 8000 div aria-hidden="true" style="left:-2px" class="position-absolute top-0 d-flex user-select-none DiffLineTableCellParts-module__in-progress-comment-indicator--hx3m3">
312+
#pragma unroll
313+
for(int ow = 0; ow < _MAXw; ++ow) {
314+
int oh_ = oh+phstart;
315+
int ow_ = ow+pwstart;
316+
const int64_t* ptr_top_mask = top_mask + oh_ * out_stride_h + ow_ * out_stride_w;
317+
if (oh_ >= phend || ow_ >= pwend) {
318+
msk[oh][ow] = ~index_shift;
319+
} else {
320+
msk[oh][ow] = ptr_top_mask[c*out_stride_c];
321+
tpd[oh][ow] = top_diff[oh_ * out_stride_h + ow_ * out_stride_w + c*out_stride_c];
322+
}
323+
}
324+
}
325+
326+
accscalar_t acm = 0;
327+
#pragma unroll
328+
for(int oh = 0; oh < _MAXh; ++oh) {
329+
#pragma unroll
330+
for(int ow = 0; ow < _MAXw; ++ow) {
331+
if (msk[oh][ow] == index_shift) {
332+
acm += static_cast<accscalar_t>(tpd[oh][ow]);
333+
}
334+
}
335+
}
336+
out_cached[cached_index] += acm;
337+
cached_index += blockDim.x;
338+
}
339+
}
340+
else
341+
#undef _MAXh
342+
#undef _MAXw
343+
#endif
344+
300345
for(int oh = phstart; oh < phend; ++oh) {
301346
for(int ow = pwstart; ow < pwend; ++ow) {
302347
int cached_index = threadIdx.x;

0 commit comments

Comments
 (0)
0