Loading sparse tensors in a DataLoader raises CUDA initialization error since 2.5.0 if you have already initialized CUDA #153143
Adding triage review, as it looks like nobody is looking at the dataloader issue, but this feels like a pretty bad regression. We would accept a change that fixes it (and adds a test), obviously.
This is not a regression; on 2.4 the short snippet above produces the same error.
Once CUDA is initialized, you cannot fork the process and expect any CUDA-related APIs (including is_pinned) to work.
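A minimal sketch of that failure mode (assumes a CUDA-capable machine; the tensor and worker function are illustrative, not from the original report):

```python
import torch
import torch.multiprocessing as mp

def child(t):
    # In a forked child, the CUDA state inherited from the parent is unusable,
    # so even a CUDA-touching query like is_pinned() can raise
    # "CUDA error: initialization error".
    print(t.is_pinned())

if __name__ == "__main__":
    torch.ones(1, device="cuda")  # initialize CUDA in the parent
    t = torch.ones(3)             # a plain CPU tensor
    p = mp.get_context("fork").Process(target=child, args=(t,))
    p.start()
    p.join()
```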
The regression isn't with pin_memory or how CUDA is initialised (as you point out!). I think it's actually with sparse tensors, and it comes from this PR that added support for using pinned memory with sparse tensors.

```python
indices = torch.tensor([[0, 1, 1], [2, 0, 2]], pin_memory=True)
values = torch.tensor([3.0, 4.0, 5.0], pin_memory=False)
size = torch.Size([10] * 2)
tensor = torch.sparse_coo_tensor(indices, values, size)
```

I'm wondering if that TORCH_CHECK above is actually needed. It only applies to sparse tensors being loaded (as far as I can tell), and saving a tensor that was pinned and then loading it returns a tensor that isn't pinned. It's very likely I'm missing something, but can we ever run into a situation, after a tensor is loaded, where only one of the indices and values is in pinned memory and the other isn't, such that this check would actually fail?
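A quick sketch of that save/load round trip (the in-memory buffer is illustrative; requires a CUDA-capable machine for pinning):

```python
import io
import torch

buf = io.BytesIO()
t = torch.ones(3, pin_memory=True)
torch.save(t, buf)
buf.seek(0)
loaded = torch.load(buf)
print(loaded.is_pinned())  # False: pinning is not preserved by serialization
```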
This is because loading sparse tensors enables validation of sparse tensor invariants, which is otherwise disabled by default. Here is the relevant failure:

```python
>>> tensor = torch.sparse_coo_tensor(indices, values, size, check_invariants=True)
RuntimeError: memory pinning of indices (=1) must match memory pinning of values (=0)
```
It is. Validation of sparse tensor invariants is a general consistency check that is disabled by default for performance reasons; when enabled, it should reveal bugs. See also #144687 (comment) and its follow-up comments, which I believe are still relevant here.
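For reference, the same validation can be toggled beyond the per-call check_invariants flag via torch.sparse.check_sparse_tensor_invariants (a sketch; the tensors are the ones from the snippet above):

```python
import torch

indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
values = torch.tensor([3.0, 4.0, 5.0])
size = torch.Size([10, 10])

# As a context manager: constructions inside are validated, so mismatched
# pinning (or out-of-range indices) raise a RuntimeError at creation time.
with torch.sparse.check_sparse_tensor_invariants():
    tensor = torch.sparse_coo_tensor(indices, values, size)

# Or globally, for the rest of the program:
torch.sparse.check_sparse_tensor_invariants.enable()
```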
So what's the way forward, given that there's pretty much no way to make is_pinned work in a forked process?
@ngimel can we expose […]?
Doesn't is_pinned check if the device is already initialized?
@albanD the problem is that fork copies the CUDA initialization state. So, if CUDA is initialized in the main process, then in the fork it wrongly also appears as initialized. (IIUC, the purpose of […])
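A small sketch of that copied state (illustrative; assumes a CUDA machine):

```python
import torch
import torch.multiprocessing as mp

def child():
    # The "CUDA is initialized" flag was inherited from the parent by fork,
    # even though the CUDA context is not actually usable in this process.
    print(torch.cuda.is_initialized())  # True

if __name__ == "__main__":
    torch.ones(1, device="cuda")  # initialize CUDA in the parent
    p = mp.get_context("fork").Process(target=child)
    p.start()
    p.join()
```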
…ext." As in the title. Fixes #153143 cc alexsamardzic nikitaved cpuhrsch amjames bhosmer jcaip andrewkho divyanshk SsnL VitalyFedyunin dzhulgakov [ghstack-poisoned]
@douglas-boubert on a different angle: you should be able to use "forkserver" as your start method for DataLoader, which will be about as fast as regular "fork" but won't suffer from this issue.

```python
multiprocessing.set_start_method("forkserver")
```
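A sketch of that workaround in context (the dataset and loop are stand-ins, not from the report):

```python
import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":
    mp.set_start_method("forkserver")  # must run before any workers start
    torch.ones(1, device="cuda")       # CUDA init in the parent is now harmless
    ds = TensorDataset(torch.randn(8, 2))
    # Workers are spawned from a clean forkserver process that never
    # initialized CUDA, so is_pinned()/pin_memory work inside them.
    for batch in DataLoader(ds, num_workers=2, pin_memory=True):
        pass
```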
🐛 Describe the bug
This is a sequel to the issue described here: #144687
The minimum working example script from the previous issue runs without error on 2.7.0, but adding one line that initialises CUDA before the data loader loop causes the same error, from the same line, as before.
Following the example by @pearu in the previous issue (#144687 (comment)), we were able to create a simpler reproducer, which returns the same CUDA initialization error.
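A sketch of the described pattern, assuming the reproducer follows the shape of the earlier comments (initialize CUDA, save a sparse tensor, load it in a forked worker); tensor contents, file name, and worker are illustrative:

```python
import torch
import torch.multiprocessing as mp

def load_worker(path):
    # On affected versions (>= 2.5), loading a sparse tensor validates its
    # invariants, which calls is_pinned() and trips the CUDA initialization
    # error in the forked child.
    torch.load(path, weights_only=False)

if __name__ == "__main__":
    torch.zeros(1, device="cuda")  # the one extra line: initialize CUDA
    t = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 2.0], (2, 2))
    torch.save(t, "sparse.pt")
    p = mp.get_context("fork").Process(target=load_worker, args=("sparse.pt",))
    p.start()
    p.join()
```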
Versions
Collecting environment information...
PyTorch version: 2.7.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Rocky Linux release 8.10 (Green Obsidian) (x86_64)
GCC version: (GCC) 11.4.0
Clang version: Could not collect
CMake version: version 3.26.5
Libc version: glibc-2.28
Python version: 3.11.12 | packaged by conda-forge | (main, Apr 10 2025, 22:23:25) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-4.18.0-553.27.1.el8_10.x86_64-x86_64-with-glibc2.28
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA A100 80GB PCIe
GPU 2: NVIDIA A100 80GB PCIe
GPU 3: NVIDIA A100 80GB PCIe
Nvidia driver version: 570.124.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 48
On-line CPU(s) list: 0-47
Thread(s) per core: 1
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Gold 6248R CPU @ 3.00GHz
Stepping: 7
CPU MHz: 3599.999
CPU max MHz: 4000.0000
CPU min MHz: 1200.0000
BogoMIPS: 6000.00
Virtualization: VT-x
L1d cache: 32K
L1i cache: 32K
L2 cache: 1024K
L3 cache: 36608K
NUMA node0 CPU(s): 0-23
NUMA node1 CPU(s): 24-47
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] torch==2.7.0
[pip3] triton==3.3.0
[conda] numpy 2.2.5 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.6.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.6.80 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.5.1.17 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.0.4 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.7.77 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.1.2 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.4.2 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.6.3 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.26.2 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.6.85 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.6.77 pypi_0 pypi
[conda] torch 2.7.0 pypi_0 pypi
[conda] triton 3.3.0 pypi_0 pypi
cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer @jcaip @andrewkho @divyanshk @ssnl @VitalyFedyunin @dzhulgakov