[SWDEV-509031] [CP] [64-bit] Int64 casting for UpSampleNearest3D (#14… · ROCm/pytorch@b445b5f · GitHub
[go: up one dir, main page]

Skip to content

Commit b445b5f

Browse files
jataylo authored and jithunnair-amd committed
[SWDEV-509031] [CP] [64-bit] Int64 casting for UpSampleNearest3D (pytorch#144865) (#1869)
Fixes pytorch#144855. Follows the approach in pytorch#141923 of using int64 types to raise the INT_MAX limits. Pull Request resolved: pytorch#144865. Approved by: https://github.com/eqy (cherry picked from commit 082fab0) (cherry picked from commit 5d01868)
1 parent 639eef5 commit b445b5f

File tree

2 files changed

+28
-16
lines changed

2 files changed

+28
-16
lines changed

aten/src/ATen/native/cuda/UpSampleNearest3d.cu

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ __global__ void upsample_nearest3d_out_frame(
5555
float height_scale,
5656
float width_scale) {
5757

58-
int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
58+
int64_t dst_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
5959
if (dst_idx >= dim_c * dst_dim_d * dst_dim_h * dst_dim_w)
6060
return;
6161

62-
int dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
63-
int src_c_stride = src_dim_d * src_dim_h * src_dim_w;
62+
int64_t dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
63+
int64_t src_c_stride = src_dim_d * src_dim_h * src_dim_w;
6464

6565
int c = (dst_idx / (dst_c_stride)) % dim_c;
6666

@@ -72,7 +72,7 @@ __global__ void upsample_nearest3d_out_frame(
7272
int dst_x = dst_idx % dst_dim_w;
7373
int src_x = nn_compute_source_index_fn(width_scale, dst_x, src_dim_w);
7474

75-
int src_idx = c * src_c_stride + src_z * src_dim_h * src_dim_w +
75+
int64_t src_idx = c * src_c_stride + src_z * src_dim_h * src_dim_w +
7676
src_y * src_dim_w + src_x;
7777
for (int b = 0; b < dim_b; b++) {
7878
output[dst_idx] = input[src_idx];
@@ -100,12 +100,12 @@ __global__ void upsample_nearest3d_backward_out_frame(
100100
float height_scale,
101101
float width_scale) {
102102

103-
int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
103+
int64_t dst_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
104104
if (dst_idx >= dim_c * dst_dim_d * dst_dim_h * dst_dim_w)
105105
return;
106106

107-
int dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
108-
int src_c_stride = src_dim_d * src_dim_h * src_dim_w;
107+
int64_t dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
108+
int64_t src_c_stride = src_dim_d * src_dim_h * src_dim_w;
109109

110110
int c = (dst_idx / (dst_c_stride)) % dim_c;
111111

@@ -132,7 +132,7 @@ __global__ void upsample_nearest3d_backward_out_frame(
132132
for (int z = src_z; z < src_z_up; z++) {
133133
for (int y = src_y; y < src_y_up; y++) {
134134
for (int x = src_x; x < src_x_up; x++) {
135-
int src_idx = b * dim_c * src_c_stride + c * src_c_stride +
135+
int64_t src_idx = b * dim_c * src_c_stride + c * src_c_stride +
136136
z * src_dim_h * src_dim_w + y * src_dim_w + x;
137137
grad += grad_o[src_idx];
138138
}
@@ -180,9 +180,9 @@ static void upsample_nearest3d_out_cuda_template(
180180
dim3 bdim{std::min<unsigned int>(
181181
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
182182
dim3 gdim{ceil_div(n, bdim.x)};
183-
// safe check for int32 indexing; implicitly restrict launch config for kernel
184-
TORCH_CHECK(output.numel() <= std::numeric_limits<int32_t>::max(),
185-
"upsample_nearest3d only supports output tensors with less than INT_MAX elements, but got ", output.sizes());
183+
// kernel uses int64 indexing; NOTE(review): numel() is already an int64_t, so the check below is always true and no longer restricts the launch config
184+
TORCH_CHECK(output.numel() <= std::numeric_limits<int64_t>::max(),
185+
"upsample_nearest3d only supports output tensors with less than INT64_MAX elements, but got ", output.sizes());
186186

187187
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
188188
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte,input.scalar_type(), "upsample_nearest3d_out_frame", [&] {
@@ -254,11 +254,11 @@ static void upsample_nearest3d_backward_out_cuda_template(
254254
dim3 bdim{std::min<unsigned int>(
255255
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
256256
dim3 gdim{ceil_div(n, bdim.x)};
257-
// safe check for int32 indexing; implicitly restrict launch config for kernel
258-
TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int32_t>::max(),
259-
"upsample_nearest3d_backward only supports input tensors with less than INT_MAX elements, but got ", grad_input.sizes());
260-
TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int32_t>::max(),
261-
"upsample_nearest3d_backward only supports output tensors with less than INT_MAX elements, but got ", grad_output.sizes());
257+
// kernel uses int64 indexing; NOTE(review): numel() is already an int64_t, so the checks below are always true and no longer restrict the launch config
258+
TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int64_t>::max(),
259+
"upsample_nearest3d_backward only supports input tensors with less than INT64_MAX elements, but got ", grad_input.sizes());
260+
TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int64_t>::max(),
261+
"upsample_nearest3d_backward only supports output tensors with less than INT64_MAX elements, but got ", grad_output.sizes());
262262

263263
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
264264
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest3d_backward_out_frame", [&] {

test/test_torch.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,18 @@ def rand_byte():
175175
scalar = bytes_to_scalar(bytes_list, dtype, device)
176176
self.assertEqual(scalar.storage().untyped().tolist(), bytes_list)
177177

178+
# For testing int64 support in upsample_nearest3d
179+
@onlyCUDA
180+
@largeTensorTest('56GB', device='cuda')
181+
@dtypes(torch.bfloat16)
182+
@unittest.skipIf(IS_JETSON, "Large tensor tests are too large for Jetson.")
183+
def test_int64_upsample3d(self, device, dtype):
184+
x = torch.ones((1, 256, 16, 720, 1280), dtype=dtype, device=device)
185+
try:
186+
torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')
187+
except Exception as e:
188+
self.fail(f"Unexpected exception raised: {e}")
189+
178190
@dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
179191
torch.bool, torch.float32, torch.complex64, torch.float64,
180192
torch.complex128, torch.uint16, torch.uint32, torch.uint64)

0 commit comments

Comments
 (0)
0