From 21264361021a0ab0857f1a353eb71ee0549ca11e Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Wed, 7 May 2025 21:40:40 +0000 Subject: [PATCH 01/10] check in --- aten/src/ATen/native/Convolution.cpp | 13 ++++++++++++- test/nn/test_convolution.py | 9 +++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index aa46b72a8b012..bd15571cdc1f9 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -467,8 +467,19 @@ struct ConvParams { // always use cudnn_depthwise for channels_last format return true; } + static long cudnn_version = detail::getCUDAHooks().versionCuDNN(); + // native kernel doesn't support 64-bit non-splittable case + if (needs_64bit_indexing_no_split(input, weight)) { + if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) { + TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions" + " if the V8 API is not enabled or before cuDNN version 9.3+." + " Upgrade cuDNN or enable the V8 API to use cuDNN for 64-bit depthwise convolutions."); + return false; + } else { + return true; + } + } if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) { - long cudnn_version = detail::getCUDAHooks().versionCuDNN(); if (cudnn_version >= 8200) { bool kernel_cond = (use_cudnn(input, weight) && input.scalar_type() == kHalf && // only for FP16 diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index a4b7fe3d40f74..ada8e2f1a1039 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -4052,6 +4052,15 @@ def test_conv3d_64bit_indexing(self, device): y = m.to(device=device)(x.to(device=device)) self.assertEqual(yref, y) + @onlyCUDA + @largeTensorTest("20GB") + @largeTensorTest("80GB", "cpu") + def test_depthwise_conv_64bit_indexing(self, device): + x = torch.randn(1, 2, 32800, 32800) + c = nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1, groups=2) + yref = c(x) + y = c.to(device=device)(x.to(device=device)) + self.assertEqual(yref, y) instantiate_device_type_tests(TestConvolutionNNDeviceType, globals(), allow_mps=True) instantiate_parametrized_tests(TestConvolutionNN) From 8085503275ff60fbb9e2a792cf3b59f3c1edafd6 Mon Sep 17 00:00:00 2001 From: eqy Date: Thu, 8 May 2025 10:05:11 -0700 Subject: [PATCH 02/10] lint --- test/nn/test_convolution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index ada8e2f1a1039..477f9c09ef76f 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -4052,6 +4052,7 @@ def test_conv3d_64bit_indexing(self, device): y = m.to(device=device)(x.to(device=device)) self.assertEqual(yref, y) + @onlyCUDA @largeTensorTest("20GB") @largeTensorTest("80GB", "cpu") From 55dc302f81ef91e7e5193e54474f52d0ebebec56 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Thu, 8 May 2025 19:00:38 +0000 Subject: [PATCH 03/10] lint --- test/nn/test_convolution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index 477f9c09ef76f..3a0094c6c98d3 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -4052,7 +4052,6 @@ def test_conv3d_64bit_indexing(self, device): y = m.to(device=device)(x.to(device=device)) self.assertEqual(yref, y) - @onlyCUDA @largeTensorTest("20GB") @largeTensorTest("80GB", "cpu") @@ -4063,6 +4062,7 @@ def test_depthwise_conv_64bit_indexing(self, device): y = c.to(device=device)(x.to(device=device)) self.assertEqual(yref, y) + instantiate_device_type_tests(TestConvolutionNNDeviceType, globals(), allow_mps=True) instantiate_parametrized_tests(TestConvolutionNN) From a0279c5eec7b0e1ed7023d5c820f834cdbf960fc Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Mon, 12 May 2025 23:13:13 +0000 Subject: [PATCH 04/10] make sure enabled --- aten/src/ATen/native/Convolution.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index bd15571cdc1f9..6292f8700acc8 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -469,7 +469,7 @@ struct ConvParams { } static long cudnn_version = detail::getCUDAHooks().versionCuDNN(); // native kernel doesn't support 64-bit non-splittable case - if (needs_64bit_indexing_no_split(input, weight)) { + if (needs_64bit_indexing_no_split(input, weight) && detail::getCUDAHooks().compiledWithCuDNN() && cudnn_enabled) { if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) { TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions" " if the V8 API is not enabled or before cuDNN version 9.3+." From 0d3a436f51e50f95f085b0312c86528d2054bd65 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Tue, 13 May 2025 18:00:01 +0000 Subject: [PATCH 05/10] fix condition and move version check inside --- aten/src/ATen/native/Convolution.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 6292f8700acc8..9bdb5c6dad9a6 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -467,9 +467,9 @@ struct ConvParams { // always use cudnn_depthwise for channels_last format return true; } - static long cudnn_version = detail::getCUDAHooks().versionCuDNN(); // native kernel doesn't support 64-bit non-splittable case - if (needs_64bit_indexing_no_split(input, weight) && detail::getCUDAHooks().compiledWithCuDNN() && cudnn_enabled) { + if (cudnn_enabled && needs_64bit_indexing_no_split(input, weight)) { + static long cudnn_version = detail::getCUDAHooks().versionCuDNN(); if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) { TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions" " if the V8 API is not enabled or before cuDNN version 9.3+." From 9d0ae34f951e0c880eff15183be68e609f59d8d1 Mon Sep 17 00:00:00 2001 From: eqy Date: Tue, 13 May 2025 19:24:59 -0700 Subject: [PATCH 06/10] Update Convolution.cpp --- aten/src/ATen/native/Convolution.cpp | 28 +++++----------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 9bdb5c6dad9a6..fb020965900de 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -467,6 +467,7 @@ struct ConvParams { // always use cudnn_depthwise for channels_last format return true; } + // native kernel doesn't support 64-bit non-splittable case if (cudnn_enabled && needs_64bit_indexing_no_split(input, weight)) { static long cudnn_version = detail::getCUDAHooks().versionCuDNN(); @@ -480,37 +481,18 @@ struct ConvParams { } } if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) { - if (cudnn_version >= 8200) { - bool kernel_cond = (use_cudnn(input, weight) && - input.scalar_type() == kHalf && // only for FP16 - weight.scalar_type() == kHalf && - is_depthwise(input, weight) && - input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks - !is_dilated() && // no dilation supported - (stride[0] == stride[1] || at::symint::size(input, 2) == 1) && // square or 1d - at::symint::size(input, 1) >= 32); // min 32 channels supported) - if (kernel_cond) { - return check_cudnn_depthwise_workload_with_filter(input, stride[1], weight); - } - } - // keep (7600 <= cudnn < 8200) code unchanged - bool kernel_cond = (cudnn_version >= 7600 && - use_cudnn(input, weight) && + bool kernel_cond = (use_cudnn(input, weight) && input.scalar_type() == kHalf && // only for FP16 weight.scalar_type() == kHalf && is_depthwise(input, weight) && input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks - at::symint::size(weight, 2) == at::symint::size(weight, 3) && // only square kernels - at::symint::size(input, 2) >= 7 && // min width/height 7 !is_dilated() && // no dilation supported - stride[0] == stride[1] && // equal strides - ((at::symint::size(weight, 3) == 3) || (at::symint::size(weight, 3) == 1)) && + (stride[0] == stride[1] || at::symint::size(input, 2) == 1) && // square or 1d at::symint::size(input, 1) >= 32); // min 32 channels supported) if (kernel_cond) { - return check_cudnn_depthwise_workload(input, stride[0]); - } else { - return false; + return check_cudnn_depthwise_workload_with_filter(input, stride[1], weight); } + return false; } else { return false; } From 52ce0b6ff981a7b96b63c11be8c6abccf4c14e7a Mon Sep 17 00:00:00 2001 From: eqy Date: Tue, 13 May 2025 19:48:05 -0700 Subject: [PATCH 07/10] lint --- aten/src/ATen/native/Convolution.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index fb020965900de..628941cbdab87 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -467,7 +467,6 @@ struct ConvParams { // always use cudnn_depthwise for channels_last format return true; } - // native kernel doesn't support 64-bit non-splittable case if (cudnn_enabled && needs_64bit_indexing_no_split(input, weight)) { static long cudnn_version = detail::getCUDAHooks().versionCuDNN(); From 0acce5009569983e8cbced173e2e2355e1b7819f Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Wed, 14 May 2025 20:59:11 +0000 Subject: [PATCH 08/10] try to fix condition --- aten/src/ATen/native/Convolution.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 628941cbdab87..ffe3f56e55505 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -469,7 +469,7 @@ struct ConvParams { } // native kernel doesn't support 64-bit non-splittable case if (cudnn_enabled && needs_64bit_indexing_no_split(input, weight)) { - static long cudnn_version = detail::getCUDAHooks().versionCuDNN(); + static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1; if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) { TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions" " if the V8 API is not enabled or before cuDNN version 9.3+." From 73a621660b1a4651e2cd48ca6944d5beace338a6 Mon Sep 17 00:00:00 2001 From: eqy Date: Wed, 14 May 2025 19:55:49 -0700 Subject: [PATCH 09/10] skip on rocm --- test/nn/test_convolution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index 3a0094c6c98d3..0d0ae2210bacf 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -4042,6 +4042,7 @@ def test_conv_double_backward_strided_with_3D_input_and_weight(self, device): self.assertEqual(grad_input.shape, input.shape) self.assertEqual(grad_weight.shape, weight.shape) + @skipCUDAIfRocm @onlyCUDA @largeTensorTest("40GB") @largeTensorTest("24GB", "cpu") From 0a9c6daa7b209350c17c235fa01a9d615f96b891 Mon Sep 17 00:00:00 2001 From: eqy Date: Thu, 15 May 2025 07:54:38 -0700 Subject: [PATCH 10/10] fix decorator location --- test/nn/test_convolution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index 0d0ae2210bacf..38238405d7bd8 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -4042,7 +4042,6 @@ def test_conv_double_backward_strided_with_3D_input_and_weight(self, device): self.assertEqual(grad_input.shape, input.shape) self.assertEqual(grad_weight.shape, weight.shape) - @skipCUDAIfRocm @onlyCUDA @largeTensorTest("40GB") @largeTensorTest("24GB", "cpu") @@ -4053,6 +4052,7 @@ def test_conv3d_64bit_indexing(self, device): y = m.to(device=device)(x.to(device=device)) self.assertEqual(yref, y) + @skipCUDAIfRocm @onlyCUDA @largeTensorTest("20GB") @largeTensorTest("80GB", "cpu")