NCCL out of memory error after updating to PyTorch 2.7 · Issue #152302 · pytorch/pytorch · GitHub

Open
BaconGabe opened this issue Apr 28, 2025 · 17 comments
Labels
module: nccl (Problems related to nccl support)
module: regression (It used to work, and now it doesn't)
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)

@BaconGabe
BaconGabe commented Apr 28, 2025

🐛 Describe the bug

After updating to PyTorch 2.7, calling init_process_group with the nccl backend and then DDP(model, device_ids=[rank]) results in an out-of-memory error. This makes absolutely no sense, because it happens even when I am using extremely small amounts of memory, and DDP with nccl worked perfectly fine with the same code before the update.

Here is the error:

W0428 00:47:04.140000 51980 .venv/lib/python3.12/site-packages/torch/multiprocessing/spawn.py:169] Terminating process 52051 via signal SIGTERM

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/home/.../.venv/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 90, in _wrap
    fn(i, *args)
  File "/home/.../example.py", line 39, in demo_basic
    ddp_model = DDP(model, device_ids=[rank])
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../.venv/lib/python3.12/site-packages/torch/nn/parallel/distributed.py", line 835, in __init__
    _verify_param_shape_across_processes(self.process_group, parameters)
  File "/home/.../.venv/lib/python3.12/site-packages/torch/distributed/utils.py", line 282, in _verify_param_shape_across_processes
    return dist._verify_params_across_processes(process_group, tensors, logger)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.distributed.DistBackendError: NCCL error in: /pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:3353, unhandled cuda error (run with NCCL_DEBUG=INFO for details), NCCL version 2.26.2
ncclUnhandledCudaError: Call to CUDA function failed.
Last error:
Cuda failure 2 'out of memory'

The demo code on how to use DDP provided by PyTorch produces the same error:

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank]) # HERE IS WHERE THE ERROR OCCURS

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()
    print(f"Finished running basic DDP example on rank {rank}.")


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)
    
if __name__ == "__main__":
    run_demo(demo_basic, 2)

Versions

PyTorch version: 2.7.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.2 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39

Python version: 3.12.3 (main, Feb 4 2025, 14:48:35) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 4090
GPU 1: NVIDIA GeForce RTX 5090
GPU 2: NVIDIA GeForce RTX 4090

Nvidia driver version: 576.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD Ryzen Threadripper 7980X 64-Cores
CPU family: 25
Model: 24
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 1
Stepping: 1
BogoMIPS: 6390.51
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext perfctr_core invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr arat npt nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm
Virtualization: AMD-V
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 2 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 64 MiB (64 instances)
L3 cache: 32 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.8.3.14
[pip3] nvidia-cuda-cupti-cu12==12.8.57
[pip3] nvidia-cuda-nvrtc-cu12==12.8.61
[pip3] nvidia-cuda-runtime-cu12==12.8.57
[pip3] nvidia-cudnn-cu12==9.7.1.26
[pip3] nvidia-cufft-cu12==11.3.3.41
[pip3] nvidia-curand-cu12==10.3.9.55
[pip3] nvidia-cusolver-cu12==11.7.2.55
[pip3] nvidia-cusparse-cu12==12.5.7.53
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.8.61
[pip3] nvidia-nvtx-cu12==12.8.55
[pip3] pytorch-lightning==2.5.1.post0
[pip3] pytorch_optimizer==3.5.1
[pip3] torch==2.7.0+cu128
[pip3] torchaudio==2.7.0+cu128
[pip3] torchmetrics==1.7.1
[pip3] torchvision==0.22.0+cu128
[pip3] triton==3.3.0
[conda] Could not collect

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

@malfet malfet added oncall: distributed Add this issue/PR to distributed oncall triage queue module: nccl Problems related to nccl support module: regression It used to work, and now it doesn't labels Apr 28, 2025
@atalman atalman added this to the 2.7.1 milestone Apr 28, 2025
@fduwjj
Contributor
fduwjj commented Apr 28, 2025

@kwen2501 could this be related to the latest optimization you brought in for DDP? cc: @fegin

@kwen2501
Contributor

Can you please run with NCCL_DEBUG=INFO as the message instructed?
Also, can you please confirm there are no other jobs running on your machine?

If you do not use DDP, i.e. just running the model + backward on a single GPU, do you see the OOM?

@kwen2501
Contributor

NCCL_DEBUG_SUBSYS=ALLOC would be helpful too (to see NCCL's allocation, in case the allocation size is corrupted).
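For anyone who prefers to set these from the script rather than the shell, here is a minimal sketch. It assumes the variables are set in the parent process before mp.spawn and init_process_group run, so that every rank inherits them. The helper names `nccl_debug_env` and `enable_nccl_debug` are hypothetical; only `NCCL_DEBUG` and `NCCL_DEBUG_SUBSYS` come from this thread.

```python
import os

def nccl_debug_env(subsys="ALLOC"):
    """Return the NCCL debug variables requested in this thread."""
    return {"NCCL_DEBUG": "INFO", "NCCL_DEBUG_SUBSYS": subsys}

def enable_nccl_debug(subsys="ALLOC"):
    # Assumption: this runs before the worker processes are spawned,
    # so the child processes inherit the updated environment.
    os.environ.update(nccl_debug_env(subsys))

if __name__ == "__main__":
    enable_nccl_debug()
    print({name: os.environ[name] for name in nccl_debug_env()})
```

Calling `enable_nccl_debug()` as the first statement under `if __name__ == "__main__":` in the repro script should be equivalent to exporting both variables in the shell.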

@fduwjj
Contributor
fduwjj commented Apr 29, 2025

I cannot repro this but maybe let me run a mem profiler.

@BaconGabe
Author

@kwen2501 There are no other jobs running on my machine. There is no error when just running the model + backward on a single GPU without DDP. There is also no error when using gloo instead of nccl.

Here is the debug information as requested

Gabriel:62080:62080 [0] NCCL INFO NCCL version 2.26.2+cuda12.2
Gabriel:62081:62081 [1] NCCL INFO NCCL version 2.26.2+cuda12.2
Gabriel:62081:62081 [1] NCCL INFO init.cc:1670 Cuda Host Alloc Size 4 pointer 0x204a00000
Gabriel:62080:62080 [0] NCCL INFO init.cc:1670 Cuda Host Alloc Size 4 pointer 0x204a00000
Gabriel:62080:62233 [0] NCCL INFO Mem Realloc old size 0, new size 256 pointer 0x7f89c47ae780
Gabriel:62081:62234 [1] NCCL INFO Mem Realloc old size 0, new size 256 pointer 0x7f380c7ae780
Gabriel:62080:62244 [0] NCCL INFO Mem Realloc old size 0, new size 32 pointer 0x7f899c004b80
Gabriel:62081:62245 [0] NCCL INFO Mem Realloc old size 0, new size 32 pointer 0x7f37e4004b80
Gabriel:62080:62233 [0] NCCL INFO misc/utils.cc:233 memory stack hunk malloc(65536)
Gabriel:62081:62234 [1] NCCL INFO misc/utils.cc:233 memory stack hunk malloc(65536)
Gabriel:62080:62233 [0] NCCL INFO channel.cc:42 Cuda Alloc Size 1216 pointer 0x2b00000000
Gabriel:62080:62233 [0] NCCL INFO channel.cc:45 Cuda Alloc Size 40 pointer 0x2b00200000
Gabriel:62081:62234 [1] NCCL INFO channel.cc:42 Cuda Alloc Size 1216 pointer 0x2b00000000
Gabriel:62080:62233 [0] NCCL INFO channel.cc:56 Cuda Alloc Size 8 pointer 0x2b00400000
Gabriel:62081:62234 [1] NCCL INFO channel.cc:45 Cuda Alloc Size 40 pointer 0x2b00200000
Gabriel:62080:62233 [0] NCCL INFO channel.cc:42 Cuda Alloc Size 1216 pointer 0x2b00600000
Gabriel:62081:62234 [1] NCCL INFO channel.cc:56 Cuda Alloc Size 8 pointer 0x2b00400000
Gabriel:62080:62233 [0] NCCL INFO channel.cc:45 Cuda Alloc Size 40 pointer 0x2b00800000
Gabriel:62081:62234 [1] NCCL INFO channel.cc:42 Cuda Alloc Size 1216 pointer 0x2b00600000
Gabriel:62080:62233 [0] NCCL INFO channel.cc:56 Cuda Alloc Size 8 pointer 0x2b00a00000
Gabriel:62081:62234 [1] NCCL INFO channel.cc:45 Cuda Alloc Size 40 pointer 0x2b00800000
Gabriel:62080:62233 [0] NCCL INFO init.cc:436 Cuda Alloc Size 23648 pointer 0x2b00c00000
Gabriel:62081:62234 [1] NCCL INFO channel.cc:56 Cuda Alloc Size 8 pointer 0x2b00a00000
Gabriel:62080:62233 [0] NCCL INFO init.cc:438 Cuda Alloc Size 8 pointer 0x2b00e00000
Gabriel:62080:62233 [0] NCCL INFO init.cc:480 Cuda Host Alloc Size 1048576 pointer 0x204a00200
Gabriel:62080:62233 [0] NCCL INFO init.cc:485 Cuda Host Alloc Size 256 pointer 0x204b00200
Gabriel:62080:62233 [0] NCCL INFO init.cc:492 Cuda Host Alloc Size 512 pointer 0x204b00400
Gabriel:62080:62233 [0] NCCL INFO init.cc:493 Cuda Host Alloc Size 512 pointer 0x204b00600
Gabriel:62081:62234 [1] NCCL INFO init.cc:436 Cuda Alloc Size 23648 pointer 0x2b00c00000
Gabriel:62081:62234 [1] NCCL INFO init.cc:438 Cuda Alloc Size 8 pointer 0x2b00e00000
Gabriel:62081:62234 [1] NCCL INFO init.cc:480 Cuda Host Alloc Size 1048576 pointer 0x204a00200
Gabriel:62081:62234 [1] NCCL INFO init.cc:485 Cuda Host Alloc Size 256 pointer 0x204b00200
Gabriel:62081:62234 [1] NCCL INFO init.cc:492 Cuda Host Alloc Size 512 pointer 0x204b00400
Gabriel:62081:62234 [1] NCCL INFO init.cc:493 Cuda Host Alloc Size 512 pointer 0x204b00600
Gabriel:62081:62081 [1] NCCL INFO misc/utils.cc:233 memory stack hunk malloc(65536)
Gabriel:62080:62080 [0] NCCL INFO misc/utils.cc:233 memory stack hunk malloc(65536)
Gabriel:62081:62248 [1] NCCL INFO Mem Realloc old size 0, new size 8 pointer 0x7f37dc004f20
Gabriel:62080:62246 [0] NCCL INFO Mem Realloc old size 0, new size 8 pointer 0x7f8994004f20

[2025-04-28 21:28:40] Gabriel:62081:62248 [1] include/alloc.h:65 NCCL WARN Cuda failure 2 'out of memory'
Gabriel:62081:62248 [1] NCCL INFO transport/shm.cc:534 -> 1
Gabriel:62081:62248 [1] NCCL INFO transport/shm.cc:491 -> 1
Gabriel:62081:62250 [1] NCCL INFO transport/shm.cc:153 -> 1
Gabriel:62081:62250 [1] NCCL INFO transport.cc:35 -> 1
Gabriel:62081:62250 [1] NCCL INFO transport.cc:147 -> 1
Gabriel:62081:62250 [1] NCCL INFO transport/generic.cc:19 -> 1
Gabriel:62081:62250 [1] NCCL INFO group.cc:148 -> 1
Gabriel:62081:62081 [1] NCCL INFO group.cc:460 -> 1
Gabriel:62081:62081 [1] NCCL INFO group.cc:581 -> 1
Gabriel:62081:62081 [1] NCCL INFO enqueue.cc:2299 -> 1
Gabriel:62080:62246 [0] NCCL INFO CUMEM Host Alloc Size 10485760 pointer 0x2b01000000 handle 7f8994008cd0 numa 0 dev 0 granularity 2097152
Gabriel:62080:62246 [0] NCCL INFO CUMEM Host Alloc Size 10485760 pointer 0x2b02000000 handle 7f899400a6a0 numa 0 dev 0 granularity 2097152
Gabriel:62080:62246 [0] NCCL INFO CUMEM Host Alloc Size 2097152 pointer 0x2b02a00000 handle 7f899400c3e0 numa 0 dev 0 granularity 2097152
Gabriel:62080:62246 [0] NCCL INFO CUMEM Host Alloc Size 2097152 pointer 0x2b02c00000 handle 7f899400dc90 numa 0 dev 0 granularity 2097152
Gabriel:62081:62253 [1] NCCL INFO misc/socket.cc:64 -> 3
Gabriel:62081:62248 [1] NCCL INFO misc/socket.cc:881 -> 3
Gabriel:62081:62253 [1] NCCL INFO misc/socket.cc:80 -> 3
Gabriel:62081:62253 [1] NCCL INFO misc/socket.cc:829 -> 3
W0428 21:28:40.927000 62009 .venv/lib/python3.12/site-packages/torch/multiprocessing/spawn.py:169] Terminating process 62080 via signal SIGTERM
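Logs like the one above get long quickly, so a small filter can help locate the failing call. The following is a rough sketch that assumes the line layout seen in this log (a bracketed index, then either `NCCL WARN` or `NCCL INFO <file>:<line> Cuda Alloc Size <bytes>`). The function name `summarize_nccl_log` is hypothetical, and other NCCL versions may format lines differently.

```python
import re

# "[N]" is the bracketed index NCCL prints before each message.
WARN_RE = re.compile(r"\[(\d+)\] (\S+:\d+) NCCL WARN (.*)$")
# Matches device allocations only; "Cuda Host Alloc Size" lines are skipped.
ALLOC_RE = re.compile(r"\[(\d+)\] NCCL INFO \S+:\d+ Cuda Alloc Size (\d+)")

def summarize_nccl_log(lines):
    """Return (warnings, device allocation bytes per bracketed index)."""
    warnings, alloc_bytes = [], {}
    for line in lines:
        line = line.rstrip()
        m = WARN_RE.search(line)
        if m:
            warnings.append(
                {"index": int(m.group(1)), "where": m.group(2), "msg": m.group(3)}
            )
            continue
        m = ALLOC_RE.search(line)
        if m:
            idx = int(m.group(1))
            alloc_bytes[idx] = alloc_bytes.get(idx, 0) + int(m.group(2))
    return warnings, alloc_bytes

if __name__ == "__main__":
    import sys
    found, totals = summarize_nccl_log(sys.stdin)
    for w in found:
        print(f"[{w['index']}] {w['where']}: {w['msg']}")
    print("device bytes allocated per index:", totals)
```

Piping the captured output through this script prints each WARN with its source location, which in this log points at include/alloc.h:65, plus a per-index total of the small device allocations that preceded it.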

@fduwjj
Contributor
fduwjj commented Apr 29, 2025

I did a brief memory profiling of your example code. Below is what is shown in the result:

PT 2.7 above:
Image

PT 2.6 release:
Image

I didn't see any memory allocation difference. We are using A100 for this benchmark.

@fduwjj fduwjj added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Apr 29, 2025
@kwen2501
Contributor

Thanks for the logs.
At this point it may be helpful to involve the NCCL team.
cc @eqy @kiskra-nvidia :
The user started seeing the OOM after PyTorch upgraded from NCCL 2.21 to 2.26.
This is SHM, two GPUs, over PCI-e. Thanks!

@kiskra-nvidia

@BaconGabe Is this running under WSL or is it bare-metal Linux? Is anything like Docker being used?

Can you run the following commands in the same environment that you use to launch Python?

nvidia-smi topo -m
numactl -H

Given where it fails, running with NCCL_CUMEM_HOST_ENABLE=0 should unblock you. I'm just a little concerned by the nature of the failure though. We've been seeing more of these cuMem host allocation failures lately and in NCCL 2.26.5 (out later today) we added transparent detection code with a fall back. But that new code won't trigger in your case because with things like Docker we normally see it fail earlier, in cuMemCreate, whereas for you it fails in cuMemSetAccess, and we don't check for that...

Actually, it's even worse -- I see that some cuMem host allocations succeed... Could it be GPU-dependent? We definitely need to see the output of the above commands!
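As a sketch of how the suggested workaround could be applied from inside a training script: the snippet below sets `NCCL_CUMEM_HOST_ENABLE=0` only when the kernel release string looks like WSL, and only if the user has not already chosen a value. The WSL check is a heuristic and an assumption on my part (it relies on "microsoft" appearing in the release string, as in the platform line of the original report), and both function names are hypothetical.

```python
import os
import platform

def running_under_wsl(release=None):
    """Heuristic: WSL kernels report 'microsoft' in their release string."""
    release = platform.release() if release is None else release
    return "microsoft" in release.lower()

def apply_cumem_host_workaround(environ=os.environ, release=None):
    """Disable NCCL cuMem host allocations on WSL unless a value is already set.

    Returns the effective value of NCCL_CUMEM_HOST_ENABLE, or None if unset.
    """
    if running_under_wsl(release):
        # setdefault keeps an explicit user setting intact.
        environ.setdefault("NCCL_CUMEM_HOST_ENABLE", "0")
    return environ.get("NCCL_CUMEM_HOST_ENABLE")

if __name__ == "__main__":
    # Assumption: call this before init_process_group / mp.spawn.
    print("NCCL_CUMEM_HOST_ENABLE =", apply_cumem_host_workaround())
```

Exporting the variable in the shell before launching Python achieves the same thing; the in-script version only saves remembering to do so on affected machines.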

@kwen2501
Contributor

> We've been seeing more of these cuMem host allocation failures lately and in NCCL 2.26.5 (out later today) we added transparent detection code with a fall back.

@atalman @malfet:
If a torch 2.7.1 is planned, shall we consider this NCCL patch version 2.26.5? We are currently on 2.26.2.

@kwen2501
Contributor
kwen2501 commented Apr 29, 2025

@kiskra-nvidia Thanks for the information!

In the info collection, I do see:
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39

So yes, it seems WSL is used.

@BaconGabe
Author

@kiskra-nvidia Yes, it is WSL running Ubuntu; Docker is not being used. Here are the results of the commands as requested:

NCCL_CUMEM_HOST_ENABLE=0

Running basic DDP example on rank 0.
Running basic DDP example on rank 1.
Finished running basic DDP example on rank 1.
Finished running basic DDP example on rank 0.

It completed successfully after setting this variable without an error!

nvidia-smi topo -m

        GPU0    GPU1    GPU2    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      SYS     SYS                             N/A
GPU1    SYS      X      SYS                             N/A
GPU2    SYS     SYS      X                              N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

numactl -H

No NUMA available on this system

@kwen2501 kwen2501 self-assigned this Apr 29, 2025
@kiskra-nvidia

@BaconGabe Could you do me a favor and compile/run the below in your environment? I'm trying to understand what does and what does not work...

Compile it with something like: g++ -I$CUDA_HOME/include -L$CUDA_HOME/lib64 -o chm chm.cc -lcuda -lcudart

#include <cstdio>
#include <cstdlib>
#include <cuda.h>
#include <cuda_runtime.h>

#define CUDACHECK(cmd) do {                                   \
    cudaError_t err = cmd;                                    \
    if( err != cudaSuccess ) {                                \
        printf("Cuda failure '%s'", cudaGetErrorString(err)); \
        exit(1);                                              \
    }                                                         \
} while(false)

#define CUCHECK(cmd) do {                                     \
    CUresult err = cmd;                                       \
    if( err != CUDA_SUCCESS ) {                               \
      const char *errStr;                                     \
      (void)cuGetErrorString(err, &errStr);                   \
      printf("Cuda failure %d '%s'", err, errStr);            \
      exit(1);                                                \
    }                                                         \
} while(false)

#define ALIGN_SIZE(size, align) \
  size = ((size + (align) - 1) / (align)) * (align);

int main() {
  int nDev;
  CUDACHECK(cudaGetDeviceCount(&nDev));
  printf("nDev: %d\n", nDev);
  for (int i = 0; i < nDev; i++) {
    CUdevice currentDev;
    int cpuNumaNodeId = -2;
    CUmemAllocationProp prop = {};
    CUmemAccessDesc accessDesc = {};
    CUmemGenericAllocationHandle handle;
    void* ptr;
    size_t size = 10485760, granularity;
    CUresult res;

    CUCHECK(cuDeviceGet(&currentDev, i));
    CUCHECK(cuDeviceGetAttribute(&cpuNumaNodeId, CU_DEVICE_ATTRIBUTE_HOST_NUMA_ID, currentDev));
    printf("i %d, currentDev %d, cpuNumaNodeId %d\n", i, currentDev, cpuNumaNodeId);
    if (cpuNumaNodeId < 0) cpuNumaNodeId = 0;

    prop.location.type = CU_MEM_LOCATION_TYPE_HOST_NUMA;
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
    prop.location.id = cpuNumaNodeId;
    CUCHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
    ALIGN_SIZE(size, granularity);
    printf("size %zd, granularity %zd\n", size, granularity);

    /* Allocate the physical memory on the device */
    CUCHECK(cuMemCreate(&handle, size, &prop, 0));
    /* Reserve a virtual address range */
    CUCHECK(cuMemAddressReserve((CUdeviceptr*)&ptr, size, granularity, 0, 0));
    /* Map the virtual address range to the physical allocation */
    CUCHECK(cuMemMap((CUdeviceptr)ptr, size, 0, handle, 0));
    printf("cuMemMap successful; ptr %p\n", ptr);

    /* Now allow RW access to the newly mapped memory for local GPU */
    accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    accessDesc.location.id = i;
    accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    CUCHECK(cuMemSetAccess((CUdeviceptr)ptr, size, &accessDesc, 1));
    printf("cuMemSetAccess successful device-side\n");

    /* Now allow RW access to the newly mapped memory from the CPU */
    accessDesc.location.type = CU_MEM_LOCATION_TYPE_HOST_NUMA;
    accessDesc.location.id = cpuNumaNodeId;
    accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
retry:
    if ((res = cuMemSetAccess((CUdeviceptr)ptr, size, &accessDesc, 1)) == CUDA_SUCCESS) {
      printf("cuMemSetAccess successful host-side\n");
    } else {
      printf("cuMemSetAccess host-side failed with %d\n", res);
      if (accessDesc.location.id != 0) {
        accessDesc.location.id = 0;
        printf("retrying with location.id 0\n");
        goto retry;
      }
      if (accessDesc.location.type == CU_MEM_LOCATION_TYPE_HOST_NUMA) {
        accessDesc.location.type = CU_MEM_LOCATION_TYPE_HOST;
        printf("retrying with location.type CU_MEM_LOCATION_TYPE_HOST\n");
        goto retry;
      }
    }
    printf("\n");
  }
  return 0;
}

@BaconGabe
Author

@kiskra-nvidia Here is the output of the compiled program as requested.

nDev: 3
i 0, currentDev 0, cpuNumaNodeId 0
size 10485760, granularity 2097152
cuMemMap successful; ptr 0x1f00000000
cuMemSetAccess successful device-side
cuMemSetAccess successful host-side

i 1, currentDev 1, cpuNumaNodeId 0
size 10485760, granularity 2097152
cuMemMap successful; ptr 0x1f00a00000
Cuda failure 2 'out of memory'

@kiskra-nvidia

Sorry, it failed in an earlier spot than I expected so it didn't quite complete all the tasks that I wanted it to. Could you do it one more time with a (hopefully) improved version below? Thanks!

#include <cstdio>
#include <cstdlib>
#include <cuda.h>
#include <cuda_runtime.h>

#define CUDACHECK(cmd) do {                                   \
    cudaError_t err = cmd;                                    \
    if( err != cudaSuccess ) {                                \
        printf("Cuda failure '%s'", cudaGetErrorString(err)); \
        exit(1);                                              \
    }                                                         \
} while(false)

#define CUCHECK(cmd) do {                                     \
    CUresult err = cmd;                                       \
    if( err != CUDA_SUCCESS ) {                               \
      const char *errStr;                                     \
      (void)cuGetErrorString(err, &errStr);                   \
      printf("Cuda failure %d '%s'", err, errStr);            \
      exit(1);                                                \
    }                                                         \
} while(false)

#define ALIGN_SIZE(size, align) \
  size = ((size + (align) - 1) / (align)) * (align);

int main() {
  int nDev;
  CUDACHECK(cudaGetDeviceCount(&nDev));
  printf("nDev: %d\n", nDev);
  for (int i = 0; i < nDev; i++) {
    CUdevice currentDev;
    int cpuNumaNodeId = -2;
    CUmemAllocationProp prop = {};
    CUmemAccessDesc accessDesc = {};
    CUmemGenericAllocationHandle handle;
    void* ptr;
    size_t size = 10485760, granularity;
    CUresult res;

    CUCHECK(cuDeviceGet(&currentDev, i));
    CUCHECK(cuDeviceGetAttribute(&cpuNumaNodeId, CU_DEVICE_ATTRIBUTE_HOST_NUMA_ID, currentDev));
    printf("i %d, currentDev %d, cpuNumaNodeId %d\n", i, currentDev, cpuNumaNodeId);
    if (cpuNumaNodeId < 0) cpuNumaNodeId = 0;

    prop.location.type = CU_MEM_LOCATION_TYPE_HOST_NUMA;
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
    prop.location.id = cpuNumaNodeId;
    CUCHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
    ALIGN_SIZE(size, granularity);
    printf("size %zd, granularity %zd\n", size, granularity);

    /* Allocate the physical memory on the device */
    CUCHECK(cuMemCreate(&handle, size, &prop, 0));
    /* Reserve a virtual address range */
    CUCHECK(cuMemAddressReserve((CUdeviceptr*)&ptr, size, granularity, 0, 0));
    /* Map the virtual address range to the physical allocation */
    CUCHECK(cuMemMap((CUdeviceptr)ptr, size, 0, handle, 0));
    printf("cuMemMap successful; ptr %p\n", ptr);

    /* Now allow RW access to the newly mapped memory for local GPU */
    accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    accessDesc.location.id = i;
    accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
retry1:
    if ((res = cuMemSetAccess((CUdeviceptr)ptr, size, &accessDesc, 1)) == CUDA_SUCCESS) {
      printf("cuMemSetAccess successful device-side\n");
    } else {
      printf("cuMemSetAccess device-side failed with %d\n", res);
      if (accessDesc.location.id != 0) {
        accessDesc.location.id = 0;
        printf("retrying with location.id 0\n");
        goto retry1;
      }
    }

    /* Now allow RW access to the newly mapped memory from the CPU */
    accessDesc.location.type = CU_MEM_LOCATION_TYPE_HOST_NUMA;
    accessDesc.location.id = cpuNumaNodeId;
    accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
retry2:
    if ((res = cuMemSetAccess((CUdeviceptr)ptr, size, &accessDesc, 1)) == CUDA_SUCCESS) {
      printf("cuMemSetAccess successful host-side\n");
    } else {
      printf("cuMemSetAccess host-side failed with %d\n", res);
      if (accessDesc.location.id != 0) {
        accessDesc.location.id = 0;
        printf("retrying with location.id 0\n");
        goto retry2;
      }
      if (accessDesc.location.type == CU_MEM_LOCATION_TYPE_HOST_NUMA) {
        accessDesc.location.type = CU_MEM_LOCATION_TYPE_HOST;
        printf("retrying with location.type CU_MEM_LOCATION_TYPE_HOST\n");
        goto retry2;
      }
    }
    printf("\n");
  }
  return 0;
}

@BaconGabe
Author

@kiskra-nvidia Sorry for the late reply; here is the output:

nDev: 3
i 0, currentDev 0, cpuNumaNodeId 0
size 10485760, granularity 2097152
cuMemMap successful; ptr 0x1f00000000
cuMemSetAccess successful device-side
cuMemSetAccess successful host-side

i 1, currentDev 1, cpuNumaNodeId 0
size 10485760, granularity 2097152
cuMemMap successful; ptr 0x1f00a00000
cuMemSetAccess device-side failed with 2
retrying with location.id 0
cuMemSetAccess successful device-side
cuMemSetAccess successful host-side

i 2, currentDev 2, cpuNumaNodeId 0
size 10485760, granularity 2097152
cuMemMap successful; ptr 0x1f01400000
cuMemSetAccess device-side failed with 2
retrying with location.id 0
cuMemSetAccess successful device-side
cuMemSetAccess successful host-side
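To compare runs of this diagnostic across machines, the output can be grouped per device. This is only a sketch that assumes the exact print format of the program above; `summarize_chm_output` and `devices_with_failures` are hypothetical helper names.

```python
def summarize_chm_output(text):
    """Group cuMemSetAccess result lines by the device index printed as 'i N'."""
    results, dev = {}, None
    for raw in text.splitlines():
        line = raw.strip()
        if line.startswith("i "):
            # Header line, e.g. "i 1, currentDev 1, cpuNumaNodeId 0".
            dev = int(line.split(",")[0].split()[1])
            results[dev] = []
        elif dev is not None and line.startswith("cuMemSetAccess"):
            results[dev].append(line)
    return results

def devices_with_failures(results):
    """Return the sorted device indices that logged at least one failed call."""
    return sorted(
        d for d, lines in results.items() if any("failed" in l for l in lines)
    )

if __name__ == "__main__":
    import sys
    print("devices with a failed cuMemSetAccess:",
          devices_with_failures(summarize_chm_output(sys.stdin.read())))
```

Fed the output above, this reports devices 1 and 2, the two for which the device-side cuMemSetAccess only succeeded after retrying with location.id 0.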

@btrude
btrude commented May 10, 2025

I have recently upgraded to 2x 5090 from 2x 4090. FSDP2 code that I was previously using for the 4090s fails with this same error as the OP for any and all collective communications, including calls to dist.barrier(). I would post the logs here, but they are quite literally exactly the same as the OP. Note that this is WSL w/o Docker as well and that there are no issues when running with only one GPU (although this seems obvious given that no distributed calls are made). I have tried 2.7 and the current nightly build and the issue persists.

@btrude
btrude commented May 10, 2025

I'll also note for anyone else experiencing this that while NCCL_CUMEM_HOST_ENABLE=0 unblocks the out of memory error, you will also need to run, e.g., pip install nvidia-nccl-cu12==2.26.2.post1, or you will encounter an illegal memory access when saving checkpoints with FSDP.
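A quick way to confirm which NCCL wheel is actually installed in the environment is to query the package metadata. This is a minimal sketch using only the standard library; the function name `installed_nccl_wheel` is hypothetical, and the distribution name is the one used in the pip command above.

```python
from importlib import metadata

def installed_nccl_wheel(dist_name="nvidia-nccl-cu12"):
    """Return the installed version of the NCCL wheel, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None

if __name__ == "__main__":
    print("nvidia-nccl-cu12:", installed_nccl_wheel())
```

If PyTorch is importable, `torch.cuda.nccl.version()` can be used alongside this to see which NCCL version PyTorch itself reports at runtime.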

7 participants