Depthwise Separable Convolutions with Large Tensors (> 2**31 Elements) Fail Despite cuDNN 64-bit Indexing Support · Issue #152816 · pytorch/pytorch · GitHub

Depthwise Separable Convolutions with Large Tensors (> 2**31 Elements) Fail Despite cuDNN 64-bit Indexing Support #152816

Closed
lely475 opened this issue May 5, 2025 · 6 comments
Labels:
  • module: convolution: Problems related to convolutions (THNN, THCUNN, CuDNN)
  • module: cuda: Related to torch.cuda, and CUDA support in general
  • module: cudnn: Related to torch.backends.cudnn, and CuDNN support
  • module: 64-bit: Problems related to incorrectly using 32-bit integers when 64-bit is needed (e.g., 8G tensors)
  • triaged: This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

lely475 commented May 5, 2025

🐛 Describe the bug

The forward pass on a 2D convolutional layer using grouped convolutions (e.g., depthwise separable convolutions) fails when operating on tensors with more than 2**31 elements. This limitation persists even when cuDNN v9.7.1 is used, which should theoretically support 64-bit indexing for large tensors since PR #134890 ([cuDNN][64-bit indexing] cuDNN v9.3+ supports non-batch-splittable convolutions with > 2**31 elements). Below is a minimal example to reproduce the issue.

import torch
import torch.nn as nn

device = torch.device("cuda")

# Define an extremely large input tensor (exceeding 2**31 elements for a single sample), use grouped (depthwise separable) convolutions
# For example: Batch size = 1, Channels = 2, Height = 32,800, Width = 32,800
# Total elements = 1 * 2 * 32,800 * 32,800 = 2,151,680,000 > 2**31 (2,147,483,648)
num_channels=2
input_tensor = torch.randn(1, num_channels, 32800, 32800, device=device)

# Define a convolution layer
conv_layer = nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1, groups=num_channels).to(device)

# Perform the forward pass
try:
    output_tensor = conv_layer(input_tensor)
    print("Convolution operation completed successfully. Output shape:", output_tensor.shape)
except RuntimeError as e:
    print("Error occurred:", e)

Running the above code produces the following error:

Error occurred: Expected canUse32BitIndexMath(input) && canUse32BitIndexMath(output) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

Additional Context:

  • The issue specifically occurs when using depthwise separable convolutions (i.e., groups > 1 in nn.Conv2d). Regular convolutions (groups=1) appear to work as expected with tensors exceeding 2**31 elements.
  • This suggests that the fix in PR #134890 does not fully account for grouped convolutions or depthwise separable convolutions.
  • Splitting the tensor further along the batch or channel dimensions is not an option in this case due to the nature of the operation.
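
The element count quoted in the reproduction can be checked with plain integer arithmetic. The sketch below only compares the total number of elements against 2**31; it is an illustration of the numbers in this report, and the helper name element_count is hypothetical. The actual canUse32BitIndexMath check inside PyTorch may be computed differently.

```python
import math

INT32_LIMIT = 2**31  # 2,147,483,648, the threshold quoted in the report


def element_count(shape):
    """Total number of elements in a dense tensor of the given shape.

    Hypothetical helper for illustration; not a PyTorch API.
    """
    return math.prod(shape)


# N, C, H, W from the reproduction above
shape = (1, 2, 32800, 32800)
numel = element_count(shape)

print(f"{numel:,} elements; exceeds 2**31: {numel > INT32_LIMIT}")
print(f"over the limit by {numel - INT32_LIMIT:,} elements")
```

Because the batch size is 1 and the overflow comes from a single sample, splitting along the batch dimension cannot bring the tensor under the limit.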

Versions

PyTorch version: 2.7.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.2 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.10.17 | packaged by conda-forge | (main, Apr 10 2025, 22:19:12) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-55-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.8.93
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H200

Nvidia driver version: 570.124.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 224
On-line CPU(s) list: 0-223
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8480C
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 56
Socket(s): 2
Stepping: 8
CPU(s) scaling MHz: 29%
CPU max MHz: 3800.0000
CPU min MHz: 800.0000
BogoMIPS: 4000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect user_shstk avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr ibt amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 5.3 MiB (112 instances)
L1i cache: 3.5 MiB (112 instances)
L2 cache: 224 MiB (112 instances)
L3 cache: 210 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-55,112-167
NUMA node1 CPU(s): 56-111,168-223
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] nvidia-cublas-cu12==12.8.3.14
[pip3] nvidia-cuda-cupti-cu12==12.8.57
[pip3] nvidia-cuda-nvrtc-cu12==12.8.61
[pip3] nvidia-cuda-runtime-cu12==12.8.57
[pip3] nvidia-cudnn-cu12==9.7.1.26
[pip3] nvidia-cufft-cu12==11.3.3.41
[pip3] nvidia-curand-cu12==10.3.9.55
[pip3] nvidia-cusolver-cu12==11.7.2.55
[pip3] nvidia-cusparse-cu12==12.5.7.53
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.8.61
[pip3] nvidia-nvtx-cu12==12.8.55
[pip3] optree==0.15.0
[pip3] torch==2.7.0+cu128
[pip3] torchvision==0.21.0
[pip3] torchvision-extra-decoders==0.0.2
[pip3] triton==3.3.0
[conda] cuda-cudart 12.8.90 h5888daf_1 conda-forge
[conda] cuda-cudart_linux-64 12.8.90 h3f2d84a_1 conda-forge
[conda] cuda-cupti 12.8.90 h5888daf_1 conda-forge
[conda] cuda-nvrtc 12.8.93 h5888daf_1 conda-forge
[conda] cuda-nvtx 12.8.90 h5888daf_1 conda-forge
[conda] cudnn 9.8.0.87 h81d5506_1 conda-forge
[conda] libblas 3.9.0 31_hfdb39a5_mkl conda-forge
[conda] libcblas 3.9.0 31_h372d94f_mkl conda-forge
[conda] libcublas 12.8.4.1 h9ab20c4_1 conda-forge
[conda] libcufft 11.3.3.83 h5888daf_1 conda-forge
[conda] libcurand 10.3.9.90 h9ab20c4_1 conda-forge
[conda] libcusolver 11.7.3.90 h9ab20c4_1 conda-forge
[conda] libcusparse 12.5.8.93 h5888daf_1 conda-forge
[conda] liblapack 3.9.0 31_hc41d3b0_mkl conda-forge
[conda] libmagma 2.9.0 h19665d7_1 conda-forge
[conda] libnvjitlink 12.8.93 h5888daf_1 conda-forge
[conda] libtorch 2.6.0 cuda126_mkl_h99b69db_304 conda-forge
[conda] mkl 2024.2.2 ha957f24_16 conda-forge
[conda] nccl 2.26.2.1 ha44e49d_1 conda-forge
[conda] numpy 2.2.5 py310hefbff90_0 conda-forge
[conda] nvidia-cublas-cu12 12.8.3.14 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.57 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.61 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.57 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.7.1.26 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.41 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.55 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.2.55 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.7.53 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.6.3 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.26.2 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.61 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.55 pypi_0 pypi
[conda] optree 0.15.0 py310h3788b33_0 conda-forge
[conda] torch 2.7.0+cu128 pypi_0 pypi
[conda] torchvision 0.21.0 cuda126_py310_h4459643_1 conda-forge
[conda] torchvision-extra-decoders 0.0.2 py310h9a3ef1b_2 conda-forge
[conda] triton 3.3.0 pypi_0 pypi

cc @csarofeen @ptrblck @xwang233 @eqy @msaroufim @jerryzh168

zou3519 added the module: convolution, triaged, module: 64-bit, module: cudnn, and module: cuda labels May 6, 2025
eqy self-assigned this May 7, 2025
eqy (Collaborator) commented May 7, 2025

Sure, I'll test cuDNN to see if this case is supported.

For reference, the failure occurs because depthwise convolutions are additionally dispatched to a native (non-cuDNN) kernel for performance, and that is where the failing check comes from.

lely475 (Author) commented May 7, 2025

Thank you for investigating and the reference!

eqy (Collaborator) commented May 7, 2025

Opened #153101 for this issue

Also note that, because of the way the heuristics are written, you can probably work around this issue by using the channels-last memory format instead (which should give better performance than channels-first anyway), e.g.:

import torch
import torch.nn as nn

device = torch.device("cuda")

# Define an extremely large input tensor (exceeding 2**31 elements for a single sample), use grouped (depthwise separable) convolutions
# For example: Batch size = 1, Channels = 2, Height = 32,800, Width = 32,800
# Total elements = 1 * 2 * 32,800 * 32,800 = 2,151,680,000 > 2**31 (2,147,483,648)
num_channels=2
input_tensor = torch.randn(1, num_channels, 32800, 32800, device=device).to(memory_format=torch.channels_last)

# Define a convolution layer
conv_layer = nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1, groups=num_channels).to(device, memory_format=torch.channels_last)

# Perform the forward pass
try:
    output_tensor = conv_layer(input_tensor)
    print("Convolution operation completed successfully. Output shape:", output_tensor.shape)
except RuntimeError as e:
    print("Error occurred:", e)

lely475 (Author) commented May 8, 2025

It is working now, thanks a lot! Just to make sure: using .to(memory_format=torch.channels_last) doesn't change anything about the convolution operation itself, but rather how the tensors are stored in memory?

eqy (Collaborator) commented May 8, 2025

> It is working now, thanks a lot! Just to make sure: using .to(memory_format=torch.channels_last) doesn't change anything about the convolution operation itself, but rather how the tensors are stored in memory?

That's correct. The convention can be a bit confusing, as you will observe that even in "channels last" format, the nominal shape is still NCHW. However, the strides, which determine the memory layout, will be different.
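
To make the shape-versus-strides point concrete, here is a minimal pure-Python sketch of how element strides could be derived for the two layouts. The helper names are hypothetical and this is not PyTorch's implementation; it assumes a dense 4D tensor whose shape is reported as (N, C, H, W) in both cases.

```python
def contiguous_strides(shape):
    """Element strides for the default channels-first (NCHW) layout.

    Hypothetical helper for illustration; not a PyTorch API.
    """
    n, c, h, w = shape
    return (c * h * w, h * w, w, 1)


def channels_last_strides(shape):
    """Element strides when the same NCHW shape is stored as NHWC in memory.

    Hypothetical helper for illustration; not a PyTorch API.
    """
    n, c, h, w = shape
    return (h * w * c, 1, w * c, c)


# A small shape keeps the numbers readable; the nominal shape is the same
# for both layouts, only the strides differ.
shape = (1, 2, 4, 5)
print("shape:", shape)
print("channels-first strides:", contiguous_strides(shape))
print("channels-last strides: ", channels_last_strides(shape))
```

In channels-last, the channel stride is 1, so the values for all channels of one pixel sit next to each other in memory. In PyTorch, comparing x.stride() with x.to(memory_format=torch.channels_last).stride() on a small tensor should show the same pattern while x.shape stays unchanged.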

lely475 (Author) commented May 9, 2025

Great, thank you!

lely475 closed this as completed May 9, 2025
pytorchmergebot pushed a commit that referenced this issue May 14, 2025
pytorchmergebot pushed a commit that referenced this issue May 20, 2025