[CUDA][CUDNN] Dispatch to cuDNN for non-batch-splittable 64-bit NCHW … · pytorch/pytorch@ced90d2 · GitHub
[go: up one dir, main page]

Skip to content

Commit ced90d2

Browse files
eqy authored and pytorchmergebot committed
[CUDA][CUDNN] Dispatch to cuDNN for non-batch-splittable 64-bit NCHW convolutions (#153101)
For #152816 Pull Request resolved: #153101 Approved by: https://github.com/Skylion007
1 parent 0ce941f commit ced90d2

File tree

2 files changed

+26
-24
lines changed

2 files changed

+26
-24
lines changed

aten/src/ATen/native/Convolution.cpp

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -467,39 +467,31 @@ struct ConvParams {
467467
// always use cudnn_depthwise for channels_last format
468468
return true;
469469
}
470-
if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) {
471-
long cudnn_version = detail::getCUDAHooks().versionCuDNN();
472-
if (cudnn_version >= 8200) {
473-
bool kernel_cond = (use_cudnn(input, weight) &&
474-
input.scalar_type() == kHalf && // only for FP16
475-
weight.scalar_type() == kHalf &&
476-
is_depthwise(input, weight) &&
477-
input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
478-
!is_dilated() && // no dilation supported
479-
(stride[0] == stride[1] || at::symint::size<T>(input, 2) == 1) && // square or 1d
480-
at::symint::size<T>(input, 1) >= 32); // min 32 channels supported)
481-
if (kernel_cond) {
482-
return check_cudnn_depthwise_workload_with_filter<T>(input, stride[1], weight);
483-
}
470+
// native kernel doesn't support 64-bit non-splittable case
471+
if (cudnn_enabled && needs_64bit_indexing_no_split(input, weight)) {
472+
static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
473+
if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) {
474+
TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions"
475+
" if the V8 API is not enabled or before cuDNN version 9.3+."
476+
" Upgrade cuDNN or enable the V8 API to use cuDNN for 64-bit depthwise convolutions.");
477+
return false;
478+
} else {
479+
return true;
484480
}
485-
// keep (7600 <= cudnn < 8200) code unchanged
486-
bool kernel_cond = (cudnn_version >= 7600 &&
487-
use_cudnn(input, weight) &&
481+
}
482+
if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) {
483+
bool kernel_cond = (use_cudnn(input, weight) &&
488484
input.scalar_type() == kHalf && // only for FP16
489485
weight.scalar_type() == kHalf &&
490486
is_depthwise(input, weight) &&
491487
input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
492-
at::symint::size<T>(weight, 2) == at::symint::size<T>(weight, 3) && // only square kernels
493-
at::symint::size<T>(input, 2) >= 7 && // min width/height 7
494488
!is_dilated() && // no dilation supported
495-
stride[0] == stride[1] && // equal strides
496-
((at::symint::size<T>(weight, 3) == 3) || (at::symint::size<T>(weight, 3) == 1)) &&
489+
(stride[0] == stride[1] || at::symint::size<T>(input, 2) == 1) && // square or 1d
497490
at::symint::size<T>(input, 1) >= 32); // min 32 channels supported)
498491
if (kernel_cond) {
499-
return check_cudnn_depthwise_workload<T>(input, stride[0]);
500-
} else {
501-
return false;
492+
return check_cudnn_depthwise_workload_with_filter<T>(input, stride[1], weight);
502493
}
494+
return false;
503495
} else {
504496
return false;
505497
}

test/nn/test_convolution.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4052,6 +4052,16 @@ def test_conv3d_64bit_indexing(self, device):
40524052
y = m.to(device=device)(x.to(device=device))
40534053
self.assertEqual(yref, y)
40544054

4055+
@onlyCUDA
4056+
@largeTensorTest("20GB")
4057+
@largeTensorTest("80GB", "cpu")
4058+
def test_depthwise_conv_64bit_indexing(self, device):
4059+
x = torch.randn(1, 2, 32800, 32800)
4060+
c = nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1, groups=2)
4061+
yref = c(x)
4062+
y = c.to(device=device)(x.to(device=device))
4063+
self.assertEqual(yref, y)
4064+
40554065

40564066
instantiate_device_type_tests(TestConvolutionNNDeviceType, globals(), allow_mps=True)
40574067
instantiate_parametrized_tests(TestConvolutionNN)

0 commit comments

Comments
 (0)
0