[64-bit] Int64 casting for UpSampleNearest3D (#144865) · pytorch/pytorch@082fab0 · GitHub

Commit 082fab0

jataylo authored and pytorchmergebot committed

[64-bit] Int64 casting for UpSampleNearest3D (#144865)

Fixes #144855. Follows the approach in #141923, using int64 index types to raise the INT_MAX element limit.

Pull Request resolved: #144865
Approved by: https://github.com/eqy
1 parent 1c9014a commit 082fab0
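The core of the change is the flat-index computation inside the CUDA kernels. blockIdx.x, blockDim.x, and threadIdx.x are 32-bit unsigned values, so their product and sum are evaluated in 32-bit arithmetic and wrap once a tensor has more than INT_MAX elements; casting the first operand to int64_t forces the whole expression into 64-bit arithmetic. A minimal standalone sketch of the pattern (the kernel name and the copy it performs are illustrative, not part of the patch):

// Sketch of the 64-bit flat-indexing pattern applied by this commit.
__global__ void copy_with_int64_index(const float* in, float* out, int64_t numel) {
  // Cast one operand first so the multiply and add happen in 64-bit
  // arithmetic instead of wrapping at 2^32.
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < numel) {
    out[idx] = in[idx];
  }
}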

File tree: 2 files changed (+28 −16 lines)


aten/src/ATen/native/cuda/UpSampleNearest3d.cu

Lines changed: 16 additions & 16 deletions

@@ -55,12 +55,12 @@ __global__ void upsample_nearest3d_out_frame(
     float height_scale,
     float width_scale) {
 
-  int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t dst_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (dst_idx >= dim_c * dst_dim_d * dst_dim_h * dst_dim_w)
     return;
 
-  int dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
-  int src_c_stride = src_dim_d * src_dim_h * src_dim_w;
+  int64_t dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
+  int64_t src_c_stride = src_dim_d * src_dim_h * src_dim_w;
 
   int c = (dst_idx / (dst_c_stride)) % dim_c;
 
@@ -72,7 +72,7 @@ __global__ void upsample_nearest3d_out_frame(
   int dst_x = dst_idx % dst_dim_w;
   int src_x = nn_compute_source_index_fn(width_scale, dst_x, src_dim_w);
 
-  int src_idx = c * src_c_stride + src_z * src_dim_h * src_dim_w +
+  int64_t src_idx = c * src_c_stride + src_z * src_dim_h * src_dim_w +
      src_y * src_dim_w + src_x;
   for (int b = 0; b < dim_b; b++) {
     output[dst_idx] = input[src_idx];
@@ -100,12 +100,12 @@ __global__ void upsample_nearest3d_backward_out_frame(
     float height_scale,
     float width_scale) {
 
-  int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t dst_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (dst_idx >= dim_c * dst_dim_d * dst_dim_h * dst_dim_w)
     return;
 
-  int dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
-  int src_c_stride = src_dim_d * src_dim_h * src_dim_w;
+  int64_t dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
+  int64_t src_c_stride = src_dim_d * src_dim_h * src_dim_w;
 
   int c = (dst_idx / (dst_c_stride)) % dim_c;
 
@@ -132,7 +132,7 @@ __global__ void upsample_nearest3d_backward_out_frame(
     for (int z = src_z; z < src_z_up; z++) {
       for (int y = src_y; y < src_y_up; y++) {
        for (int x = src_x; x < src_x_up; x++) {
-          int src_idx = b * dim_c * src_c_stride + c * src_c_stride +
+          int64_t src_idx = b * dim_c * src_c_stride + c * src_c_stride +
              z * src_dim_h * src_dim_w + y * src_dim_w + x;
          grad += grad_o[src_idx];
        }
@@ -180,9 +180,9 @@ static void upsample_nearest3d_out_cuda_template(
   dim3 bdim{std::min<unsigned int>(
       at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
   dim3 gdim{ceil_div(n, bdim.x)};
-  // safe check for int32 indexing; implicitly restrict launch config for kernel
-  TORCH_CHECK(output.numel() <= std::numeric_limits<int32_t>::max(),
-    "upsample_nearest3d only supports output tensors with less than INT_MAX elements, but got ", output.sizes());
+  // safe check for int64 indexing; implicitly restrict launch config for kernel
+  TORCH_CHECK(output.numel() <= std::numeric_limits<int64_t>::max(),
+    "upsample_nearest3d only supports output tensors with less than INT64_MAX elements, but got ", output.sizes());
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte,input.scalar_type(), "upsample_nearest3d_out_frame", [&] {
@@ -254,11 +254,11 @@ static void upsample_nearest3d_backward_out_cuda_template(
   dim3 bdim{std::min<unsigned int>(
       at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
   dim3 gdim{ceil_div(n, bdim.x)};
-  // safe check for int32 indexing; implicitly restrict launch config for kernel
-  TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int32_t>::max(),
-    "upsample_nearest3d_backward only supports input tensors with less than INT_MAX elements, but got ", grad_input.sizes());
-  TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int32_t>::max(),
-    "upsample_nearest3d_backward only supports output tensors with less than INT_MAX elements, but got ", grad_output.sizes());
+  // safe check for int64 indexing; implicitly restrict launch config for kernel
+  TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int64_t>::max(),
+    "upsample_nearest3d_backward only supports input tensors with less than INT64_MAX elements, but got ", grad_input.sizes());
+  TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int64_t>::max(),
+    "upsample_nearest3d_backward only supports output tensors with less than INT64_MAX elements, but got ", grad_output.sizes());
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest3d_backward_out_frame", [&] {
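Note on why widening only the declared stride and index variables is enough: in an expression such as c * src_c_stride, the usual arithmetic conversions promote the int operand to int64_t once the other operand is int64_t, so the product and the following additions are evaluated in 64 bits. A small host-side sketch with illustrative values (not from the patch):

#include <cstdint>
#include <cstdio>

int main() {
  int     c            = 3;                           // narrow 32-bit index
  int64_t src_c_stride = int64_t{4000} * 4000 * 200;  // 3.2e9, already past INT32_MAX
  // c is promoted to int64_t because src_c_stride is int64_t, so the multiply
  // cannot overflow 32-bit arithmetic.
  int64_t src_idx = c * src_c_stride;
  std::printf("src_idx = %lld\n", static_cast<long long>(src_idx));
  return 0;
}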

test/test_torch.py

Lines changed: 12 additions & 0 deletions

@@ -174,6 +174,18 @@ def rand_byte():
             scalar = bytes_to_scalar(bytes_list, dtype, device)
             self.assertEqual(scalar.storage().untyped().tolist(), bytes_list)
 
+    # For testing in64 support in upsample_nearest3d
+    @onlyCUDA
+    @largeTensorTest('56GB', device='cuda')
+    @dtypes(torch.bfloat16)
+    @unittest.skipIf(IS_JETSON, "Large tensor tests are too large for Jetson.")
+    def test_int64_upsample3d(self, device, dtype):
+        x = torch.ones((1, 256, 16, 720, 1280), dtype=dtype, device=device)
+        try:
+            torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')
+        except Exception as e:
+            self.fail(f"Unexpected exception raised: {e}")
+
     @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
             torch.bool, torch.float32, torch.complex64, torch.float64,
             torch.complex128, torch.uint16, torch.uint32, torch.uint64)
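For scale: the test tensor has 1 × 256 × 16 × 720 × 1280 ≈ 3.77 billion input elements, and nearest-neighbor interpolation with scale_factor=2 produces roughly 30.2 billion output elements. Both counts exceed INT32_MAX (2,147,483,647), which is what made the old 32-bit flat indices wrap. A quick host-side check of the arithmetic (illustrative, not part of the test):

#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  // Shapes taken from test_int64_upsample3d: input (1, 256, 16, 720, 1280),
  // output after scale_factor=2 on D/H/W: (1, 256, 32, 1440, 2560).
  int64_t in_numel  = int64_t{1} * 256 * 16 * 720 * 1280;
  int64_t out_numel = int64_t{1} * 256 * 32 * 1440 * 2560;
  std::printf("input elements:  %lld\n", static_cast<long long>(in_numel));   // 3,774,873,600
  std::printf("output elements: %lld\n", static_cast<long long>(out_numel));  // 30,198,988,800
  std::printf("INT32_MAX:       %d\n", std::numeric_limits<int32_t>::max());  // 2,147,483,647
  return 0;
}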

0 commit comments