8000 [64-bit][CUDA] Upsample2D 64-bit indexing fix attempt 2 by eqy · Pull Request #141923 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[64-bit][CUDA] Upsample2D 64-bit indexing fix attempt 2 #141923

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions aten/src/ATen/native/cuda/UpSampleNearest2d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ __global__ void upsample_nearest2d_out_frame(
float height_scale,
float width_scale) {
size_t nc_iter = threadIdx.z + blockIdx.z * blockDim.z;
int w2 = threadIdx.x + blockIdx.x * blockDim.x;
int h2 = threadIdx.y + blockIdx.y * blockDim.y;
int64_t w2 = ((int64_t) threadIdx.x) + blockIdx.x * blockDim.x;
int64_t h2 = threadIdx.y + blockIdx.y * blockDim.y;

if (w2 >= width2 || h2 >= height2) {
return;
}

int nc_stride = blockDim.z * gridDim.z;
int64_t nc_stride = ((int64_t) blockDim.z) * gridDim.z;

const size_t h1 = height1 == height2
? h2
Expand Down Expand Up @@ -93,9 +93,9 @@ __global__ void upsample_nearest2d_nhwc_out_frame(
float width_scale,
const size_t out_numel) {

const int64_t index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
const int64_t index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;

if (index < out_numel) {
if (index < out_numel) {
const auto c = index % channels;
const auto w2 = (index / channels) % width2;
const auto h2 = (index / channels / width2) % height2;
Expand Down Expand Up @@ -126,8 +126,8 @@ __global__ void upsample_nearest2d_backward_out_frame(
if (dst_idx >= dim_c * dst_dim_h * dst_dim_w)
return;

int dst_c_stride = dst_dim_h * dst_dim_w;
int src_c_stride = src_dim_h * src_dim_w;
int64_t dst_c_stride = dst_dim_h * dst_dim_w;
int64_t src_c_stride = src_dim_h * src_dim_w;

int c = (dst_idx / (dst_c_stride)) % dim_c;

Expand Down Expand Up @@ -178,7 +178,7 @@ __global__ void upsample_nearest2d_backward_nhwc_out_frame(
// 1 is for grad_output (src)
// 2 is for grad_input (dst)

const int index = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;

if (index < gi_numel) {
const int c = index % channels;
Expand Down Expand Up @@ -250,7 +250,6 @@ static void upsample_nearest2d_out_cuda_template(
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_nhwc_out_frame", [&] {
const scalar_t* idata = input.const_data_ptr<scalar_t>();
scalar_t* odata = output.mutable_data_ptr<scalar_t>();

upsample_nearest2d_nhwc_out_frame<scalar_t, nn_compute_source_index_fn>
<<<ceil_div(num_kernels, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
idata,
Expand All @@ -272,7 +271,7 @@ static void upsample_nearest2d_out_cuda_template(
Tensor output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options());
Tensor input = input_.contiguous();

int nc = nbatch * channels;
int64_t nc = nbatch * channels;

const int max_threads = std::min<int>(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS);
Expand All @@ -293,7 +292,7 @@ static void upsample_nearest2d_out_cuda_template(
int grid_x = ceil_div(output_width, block_x);
int grid_y = ceil_div(output_height, block_y);
int grid_z = std::min<int>(
maxGridSize[2], ceil_div(nc, block_z * 4));
maxGridSize[2], ceil_div(nc, (int64_t) block_z * 4));
const dim3 grid(grid_x, grid_y, grid_z);
// Error out on cases where grid_x & grid_y exceeds limit of launch config, as
// the current kernel implementation doesn't loop over the two dimensions.
Expand All @@ -303,7 +302,6 @@ static void upsample_nearest2d_out_cuda_template(
TORCH_CHECK(
grid_x <= maxGridSize[0] && grid_y <= maxGridSize[1],
"input tensor has spatial dimension larger than the kernel capacity");

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_out_frame", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
Expand Down Expand Up @@ -404,10 +402,10 @@ static void upsample_nearest2d_backward_out_cuda_template(
Tensor grad_output = grad_output_.contiguous();

// upsample_nearest2d meta call makes sure `nbatch != 0`
unsigned int n = grad_input.numel() / nbatch;
size_t n = grad_input.numel() / nbatch;
dim3 bdim{std::min<unsigned int>(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
dim3 gdim{ceil_div(n, bdim.x)};
dim3 gdim{(unsigned int) ceil_div(n, (size_t) bdim.x)};
// safe check for int64 indexing; implicitly restrict launch config for kernel
TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int64_t>::max(), "upsample2d grad_input.numel() <= std::numeric_limits<int64_t>::max(), but got ", grad_input.sizes());
TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int64_t>::max(), "upsample2d grad_output.numel() <= std::numeric_limits<int64_t>::max(), but got ", grad_output.sizes());
Expand Down
7 changes: 6 additions & 1 deletion test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9961,7 +9961,8 @@ def test_upsamplingTrilinear3d(self, device, align_corners, memory_format):
gradgradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input])

@onlyCUDA
@dtypes(torch.half)
@skipCUDAIfRocm(msg="launch bounds error out on ROCM")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eqy curious do we know if this was a regression on ROCm caused by this PR, or new failures from the dtypes change? cc: @jeffdaily

@dtypes(torch.half, torch.bfloat16)
@largeTensorTest('40GB')
def test_upsampling_64bit_indexing_channels_last(self, device, dtype):
x = torch.rand((32, 64, 512, 512), dtype=dtype, device=device)
Expand All @@ -9970,6 +9971,10 @@ def test_upsampling_64bit_indexing_channels_last(self, device, dtype):
del x
self.assertTrue(torch.allclose(out, out_ref))

x = torch.ones((17, 256, 512, 512), dtype=dtype).cuda().to(memory_format=torch.channels_last)
out = torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')
self.assertEqual(out[0], out[-1])

@onlyCUDA
@dtypes(torch.half)
@largeTensorTest('40GB')
Expand Down
Loading
0