[CUDA][CUDNN] Dispatch to cuDNN for non-batch-splittable 64-bit NCHW convolutions by eqy · Pull Request #153101 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[CUDA][CUDNN] Dispatch to cuDNN for non-batch-splittable 64-bit NCHW convolutions #153101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 16 additions & 24 deletions aten/src/ATen/native/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,39 +467,31 @@ struct ConvParams {
// always use cudnn_depthwise for channels_last format
return true;
}
if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) {
long cudnn_version = detail::getCUDAHooks().versionCuDNN();
if (cudnn_version >= 8200) {
bool kernel_cond = (use_cudnn(input, weight) &&
input.scalar_type() == kHalf && // only for FP16
weight.scalar_type() == kHalf &&
is_depthwise(input, weight) &&
input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
!is_dilated() && // no dilation supported
(stride[0] == stride[1] || at::symint::size<T>(input, 2) == 1) && // square or 1d
at::symint::size<T>(input, 1) >= 32); // min 32 channels supported)
if (kernel_cond) {
return check_cudnn_depthwise_workload_with_filter<T>(input, stride[1], weight);
}
// native kernel doesn't support 64-bit non-splittable case
if (cudnn_enabled && needs_64bit_indexing_no_split(input, weight)) {
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1;
if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) {
TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions"
" if the V8 API is not enabled or before cuDNN version 9.3+."
" Upgrade cuDNN or enable the V8 API to use cuDNN for 64-bit depthwise convolutions.");
return false;
} else {
return true;
}
// keep (7600 <= cudnn < 8200) code unchanged
bool kernel_cond = (cudnn_version >= 7600 &&
use_cudnn(input, weight) &&
}
if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) {
bool kernel_cond = (use_cudnn(input, weight) &&
input.scalar_type() == kHalf && // only for FP16
weight.scalar_type() == kHalf &&
is_depthwise(input, weight) &&
input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
at::symint::size<T>(weight, 2) == at::symint::size<T>(weight, 3) && // only square kernels
at::symint::size<T>(input, 2) >= 7 && // min width/height 7
!is_dilated() && // no dilation supported
stride[0] == stride[1] && // equal strides
((at::symint::size<T>(weight, 3) == 3) || (at::symint::size<T>(weight, 3) == 1)) &&
(stride[0] == stride[1] || at::symint::size<T>(input, 2) == 1) && // square or 1d
at::symint::size<T>(input, 1) >= 32); // min 32 channels supported)
if (kernel_cond) {
return check_cudnn_depthwise_workload<T>(input, stride[0]);
} else {
return false;
return check_cudnn_depthwise_workload_with_filter<T>(input, stride[1], weight);
}
return false;
} else {
return false;
}
Expand Down
11 changes: 11 additions & 0 deletions test/nn/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -4052,6 +4052,17 @@ def test_conv3d_64bit_indexing(self, device):
y = m.to(device=device)(x.to(device=device))
self.assertEqual(yref, y)

# Regression test: a depthwise Conv2d whose input has more elements than a
# signed 32-bit index can address, and whose batch size is 1 (so the work
# cannot be split along the batch dimension), must give the same result on
# `device` as the reference computed before moving to `device`.
# NOTE(review): the decorators presumably restrict this to non-ROCm CUDA devices — confirm.
@skipCUDAIfRocm
@onlyCUDA
# NOTE(review): presumably memory requirements (20GB on the device, 80GB on the host) — confirm against largeTensorTest.
@largeTensorTest("20GB")
@largeTensorTest("80GB", "cpu")
def test_depthwise_conv_64bit_indexing(self, device):
# 1 * 2 * 32800 * 32800 = 2,151,680,000 elements, which exceeds 2**31 - 1.
x = torch.randn(1, 2, 32800, 32800)
# groups == in_channels == out_channels == 2, i.e. one filter per channel (depthwise).
c = nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1, groups=2)
# Reference output, computed before the module and input are moved to `device`.
yref = c(x)
y = c.to(device=device)(x.to(device=device))
self.assertEqual(yref, y)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, is this really numerically stable enough for assertEqual instead of assertClose?

Copy link
Collaborator Author
@eqy eqy May 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assertEqual does use tolerances under the hood



instantiate_device_type_tests(TestConvolutionNNDeviceType, globals(), allow_mps=True)
instantiate_parametrized_tests(TestConvolutionNN)
Expand Down
Loading
0