Loading sparse tensors in a DataLoader raises CUDA initialization error since 2.5.0 if you have already initialized CUDA · Issue #153143 · pytorch/pytorch
Closed
@douglas-boubert

Description

🐛 Describe the bug

This is a follow-up to the issue described in #144687.

The minimal working example script from the previous issue runs without error on 2.7.0, but adding one line that initialises CUDA before the DataLoader loop causes the same error, raised from the same line as before:

import torch
from torch.utils.data import Dataset, DataLoader


def create_sparse_tensor():
    tensor = torch.randn(5, 5)
    sparse_tensor = tensor.to_sparse().to("cpu")
    torch.save(sparse_tensor, "sparse_tensor.pth")


class OperatorDataset(Dataset):
    def __init__(self):
        self.files = ["sparse_tensor.pth"]
        
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        _ = torch.load(self.files[idx], weights_only=True, map_location="cpu")
        return None


if __name__ == '__main__':
    print(torch.__version__)

    # Comment out this line to avoid CUDA init error
    torch.zeros(1, device="cuda:0")

    create_sparse_tensor()
    
    dataset = OperatorDataset()
    
    dataloader = DataLoader(
        dataset,
        batch_size=None,
        num_workers=1,
        pin_memory=True,
    )
    
    for sparse_tensor in dataloader:
        # Error raised here
        pass

This code raises the following error:

2.7.0+cu126
Traceback (most recent call last):
  File "/home/douglas/projects/research/research-lethe/old_mwe.py", line 40, in <module>
    for sparse_tensor in dataloader:
  File "/home/douglas/miniconda3/envs/torch_sparse/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 733, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/home/douglas/miniconda3/envs/torch_sparse/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1515, in _next_data
    return self._process_data(data, worker_id)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/douglas/miniconda3/envs/torch_sparse/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1550, in _process_data
    data.reraise()
  File "/home/douglas/miniconda3/envs/torch_sparse/lib/python3.11/site-packages/torch/_utils.py", line 750, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/douglas/miniconda3/envs/torch_sparse/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/douglas/miniconda3/envs/torch_sparse/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    data = self.dataset[possibly_batched_index]
           ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/douglas/projects/research/research-lethe/old_mwe.py", line 19, in __getitem__
    _ = torch.load(self.files[idx], weights_only=True, map_location="cpu")
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/douglas/miniconda3/envs/torch_sparse/lib/python3.11/site-packages/torch/serialization.py", line 1516, in load
    return _load(
           ^^^^^^
  File "/home/douglas/miniconda3/envs/torch_sparse/lib/python3.11/site-packages/torch/serialization.py", line 2117, in _load
    torch._utils._validate_loaded_sparse_tensors()
  File "/home/douglas/miniconda3/envs/torch_sparse/lib/python3.11/site-packages/torch/_utils.py", line 280, in _validate_loaded_sparse_tensors
    torch._validate_sparse_coo_tensor_args(
RuntimeError: CUDA error: initialization error
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
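A possible workaround, which we have not verified against this reproducer, is to stop DataLoader workers from being created with `fork` once CUDA is live in the parent, for example through the DataLoader's `multiprocessing_context` argument. The helper below is a hypothetical sketch (`worker_start_method` is not a PyTorch API). It uses only the standard `multiprocessing` module and assumes that a CUDA context created in the parent cannot be reused by a forked child.

```python
import multiprocessing
from typing import Optional


def worker_start_method(cuda_already_initialized: bool) -> Optional[str]:
    """Choose a start method name for DataLoader(multiprocessing_context=...).

    Hypothetical helper, not part of PyTorch. Assumption: a CUDA context
    created in the parent process is unusable in a fork()ed child, so we
    request 'spawn' once CUDA has been initialized. Returning None keeps
    the platform's default start method.
    """
    if cuda_already_initialized and "spawn" in multiprocessing.get_all_start_methods():
        return "spawn"
    return None


if __name__ == "__main__":
    method = worker_start_method(cuda_already_initialized=True)
    # get_context(None) would return the default context; "spawn" returns a spawn context.
    ctx = multiprocessing.get_context(method)
    print("workers would start with:", ctx.get_start_method())
```

In the reproducer above this would correspond to `DataLoader(dataset, batch_size=None, num_workers=1, pin_memory=True, multiprocessing_context=worker_start_method(torch.cuda.is_initialized()))`. With `spawn`, each worker starts a fresh interpreter and re-imports the script, so `OperatorDataset` must stay defined at module level (as it already is) and the `if __name__ == '__main__':` guard is required.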

Following the example by @pearu in the previous issue (#144687 (comment)), we were able to create a simpler reproducer:

import torch
import os
torch.zeros(1, device="cuda:0")
t = torch.tensor([1, 2])
print(f'1: {t.is_pinned()}')
os.fork()
print(f'2: {t.is_pinned()}')

which produces the following output and error:

1: False
2: False
1: False
Traceback (most recent call last):
  File "/home/douglas/projects/research/research-lethe/mini_mwe.py", line 8, in <module>
    print(f'2: {t.is_pinned()}')
                ^^^^^^^^^^^^^
RuntimeError: CUDA error: initialization error
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
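The same failure pattern can be modelled without PyTorch or a GPU. The toy sketch below is purely illustrative and is not how PyTorch or the CUDA driver are implemented. It assumes that some per-process state (standing in for the CUDA context) is only valid in the process that created it, and that a query like `is_pinned()` lazily initializes that state when it is absent but fails when the state was inherited through `fork()`. All names (`init_context`, `fake_is_pinned`, `exit_code_in_forked_child`) are hypothetical, and the sketch needs a platform with `os.fork` (Linux or macOS).

```python
import os
from typing import Callable, Optional

# Hypothetical stand-in for a per-process CUDA context: we only record
# which process created it.
_context_owner_pid: Optional[int] = None


def init_context() -> None:
    """Pretend to initialize the device context in the current process."""
    global _context_owner_pid
    _context_owner_pid = os.getpid()


def fake_is_pinned() -> bool:
    """Toy analogue of Tensor.is_pinned(): needs a context valid in this process."""
    if _context_owner_pid is None:
        init_context()  # nothing inherited: lazy initialization succeeds
    elif _context_owner_pid != os.getpid():
        # State was copied into this child by fork(), but it belongs to the parent.
        raise RuntimeError("initialization error")
    return False


def exit_code_in_forked_child(fn: Callable[[], object]) -> int:
    """Run fn in a fork()ed child; return 0 if it returned normally, 1 if it raised."""
    pid = os.fork()
    if pid == 0:
        code = 1
        try:
            fn()
            code = 0
        finally:
            # Exit immediately so the child never runs the parent's remaining code.
            os._exit(code)
    _, status = os.waitpid(pid, 0)
    return os.waitstatus_to_exitcode(status)


if __name__ == "__main__":
    print(exit_code_in_forked_child(fake_is_pinned))  # 0: nothing initialized before fork
    init_context()  # analogue of torch.zeros(1, device="cuda:0") in the parent
    print(fake_is_pinned())  # False: the parent itself is unaffected
    print(exit_code_in_forked_child(fake_is_pinned))  # 1: child inherited the parent's state
```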

Versions

Collecting environment information...
PyTorch version: 2.7.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Rocky Linux release 8.10 (Green Obsidian) (x86_64)
GCC version: (GCC) 11.4.0
Clang version: Could not collect
CMake version: version 3.26.5
Libc version: glibc-2.28

Python version: 3.11.12 | packaged by conda-forge | (main, Apr 10 2025, 22:23:25) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-4.18.0-553.27.1.el8_10.x86_64-x86_64-with-glibc2.28
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA A100 80GB PCIe
GPU 2: NVIDIA A100 80GB PCIe
GPU 3: NVIDIA A100 80GB PCIe

Nvidia driver version: 570.124.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 48
On-line CPU(s) list: 0-47
Thread(s) per core: 1
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Gold 6248R CPU @ 3.00GHz
Stepping: 7
CPU MHz: 3599.999
CPU max MHz: 4000.0000
CPU min MHz: 1200.0000
BogoMIPS: 6000.00
Virtualization: VT-x
L1d cache: 32K
L1i cache: 32K
L2 cache: 1024K
L3 cache: 36608K
NUMA node0 CPU(s): 0-23
NUMA node1 CPU(s): 24-47
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] torch==2.7.0
[pip3] triton==3.3.0
[conda] numpy 2.2.5 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.6.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.6.80 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.5.1.17 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.0.4 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.7.77 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.1.2 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.4.2 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.6.3 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.26.2 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.6.85 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.6.77 pypi_0 pypi
[conda] torch 2.7.0 pypi_0 pypi
[conda] triton 3.3.0 pypi_0 pypi

cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer @jcaip @andrewkho @divyanshk @ssnl @VitalyFedyunin @dzhulgakov

Labels: module: dataloader, module: sparse, triaged
