@@ -55,12 +55,12 @@ __global__ void upsample_nearest3d_out_frame(
55
55
float height_scale,
56
56
float width_scale) {
57
57
58
- int dst_idx = blockIdx .x * blockDim .x + threadIdx .x ;
58
+ int64_t dst_idx = static_cast < int64_t >( blockIdx .x ) * blockDim .x + threadIdx .x ;
59
59
if (dst_idx >= dim_c * dst_dim_d * dst_dim_h * dst_dim_w)
60
60
return ;
61
61
62
- int dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
63
- int src_c_stride = src_dim_d * src_dim_h * src_dim_w;
62
+ int64_t dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
63
+ int64_t src_c_stride = src_dim_d * src_dim_h * src_dim_w;
64
64
65
65
int c = (dst_idx / (dst_c_stride)) % dim_c;
66
66
@@ -72,7 +72,7 @@ __global__ void upsample_nearest3d_out_frame(
72
72
int dst_x = dst_idx % dst_dim_w;
73
73
int src_x = nn_compute_source_index_fn (width_scale, dst_x, src_dim_w);
74
74
75
- int src_idx = c * src_c_stride + src_z * src_dim_h * src_dim_w +
75
+ int64_t src_idx = c * src_c_stride + src_z * src_dim_h * src_dim_w +
76
76
src_y * src_dim_w + src_x;
77
77
for (int b = 0 ; b < dim_b; b++) {
78
78
output[dst_idx] = input[src_idx];
@@ -100,12 +100,12 @@ __global__ void upsample_nearest3d_backward_out_frame(
100
100
float height_scale,
101
101
float width_scale) {
102
102
103
- int dst_idx = blockIdx .x * blockDim .x + threadIdx .x ;
103
+ int64_t dst_idx = static_cast < int64_t >( blockIdx .x ) * blockDim .x + threadIdx .x ;
104
104
if (dst_idx >= dim_c * dst_dim_d * dst_dim_h * dst_dim_w)
105
105
return ;
106
106
107
- int dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
108
- int src_c_stride = src_dim_d * src_dim_h * src_dim_w;
107
+ int64_t dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
108
+ int64_t src_c_stride = src_dim_d * src_dim_h * src_dim_w;
109
109
110
110
int c = (dst_idx / (dst_c_stride)) % dim_c;
111
111
@@ -132,7 +132,7 @@ __global__ void upsample_nearest3d_backward_out_frame(
132
132
for (int z = src_z; z < src_z_up; z++) {
133
133
for (int y = src_y; y < src_y_up; y++) {
134
134
for (int x = src_x; x < src_x_up; x++) {
135
- int src_idx = b * dim_c * src_c_stride + c * src_c_stride +
135
+ int64_t src_idx = b * dim_c * src_c_stride + c * src_c_stride +
136
136
z * src_dim_h * src_dim_w + y * src_dim_w + x;
137
137
grad += grad_o[src_idx];
138
138
}
@@ -180,9 +180,9 @@ static void upsample_nearest3d_out_cuda_template(
180
180
dim3 bdim{std::min<unsigned int >(
181
181
at::cuda::getCurrentDeviceProperties ()->maxThreadsPerBlock , MAX_THREADS)};
182
182
dim3 gdim{ceil_div (n, bdim.x )};
183
- // safe check for int32 indexing; implicitly restrict launch config for kernel
184
- TORCH_CHECK (output.numel () <= std::numeric_limits<int32_t >::max (),
185
- " upsample_nearest3d only supports output tensors with less than INT_MAX elements, but got " , output.sizes ());
183
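+ // NOTE(review): numel() is presumably int64_t, so the check below can never fail and no longer restricts the launch config -- confirm that n and gdim cannot overflow for large outputs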
184
+ TORCH_CHECK (output.numel () <= std::numeric_limits<int64_t >::max (),
185
+ " upsample_nearest3d only supports output tensors with less than INT64_MAX elements, but got " , output.sizes ());
186
186
187
187
cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
188
188
AT_DISPATCH_FLOATING_TYPES_AND3 (ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte,input.scalar_type (), " upsample_nearest3d_out_frame" , [&] {
@@ -254,11 +254,11 @@ static void upsample_nearest3d_backward_out_cuda_template(
254
254
dim3 bdim{std::min<unsigned int >(
255
255
at::cuda::getCurrentDeviceProperties ()->maxThreadsPerBlock , MAX_THREADS)};
256
256
dim3 gdim{ceil_div (n, bdim.x )};
257
- // safe check for int32 indexing; implicitly restrict launch config for kernel
258
- TORCH_CHECK (grad_input.numel () <= std::numeric_limits<int32_t >::max (),
259
- " upsample_nearest3d_backward only supports input tensors with less than INT_MAX elements, but got " , grad_input.sizes ());
260
- TORCH_CHECK (grad_output.numel () <= std::numeric_limits<int32_t >::max (),
261
- " upsample_nearest3d_backward only supports output tensors with less than INT_MAX elements, but got " , grad_output.sizes ());
257
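+ // NOTE(review): numel() is presumably int64_t, so the checks below can never fail and no longer restrict the launch config -- confirm that n and gdim cannot overflow for large tensors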
258
+ TORCH_CHECK (grad_input.numel () <= std::numeric_limits<int64_t >::max (),
259
+ " upsample_nearest3d_backward only supports input tensors with less than INT64_MAX elements, but got " , grad_input.sizes ());
260
+ TORCH_CHECK (grad_output.numel () <= std::numeric_limits<int64_t >::max (),
261
+ " upsample_nearest3d_backward only supports output tensors with less than INT64_MAX elements, but got " , grad_output.sizes ());
262
262
263
263
cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
264
264
AT_DISPATCH_FLOATING_TYPES_AND3 (ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type (), " upsample_nearest3d_backward_out_frame" , [&] {
0 commit comments