[FSDP2] OOM when using an integer `reshard_after_forward` smaller than DP size · Issue #147179 · pytorch/pytorch · GitHub

Open
FindDefinition opened this issue Feb 14, 2025 · 5 comments
Labels
module: fsdp · oncall: distributed (Add this issue/PR to distributed oncall triage queue) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@FindDefinition
FindDefinition commented Feb 14, 2025

🐛 Describe the bug

When we use an FSDP2 module for inference only with reshard_after_forward set, we found that reshard_after_forward=True and reshard_after_forward=False both work fine, but with an integer reshard_after_forward=4 and world_size=8, an OOM happens in the second inference step. torch.cuda.memory_* also reports wrong memory stats during the second inference step.

Code:

import torch 
from torch.distributed.device_mesh import init_device_mesh
from torch.nn import functional as F
import os
from torch.distributed.fsdp import (
    fully_shard,
    MixedPrecisionPolicy,
)
from torch.distributed.device_mesh import DeviceMesh
import torch.distributed as dist

_world_size = int(os.environ["WORLD_SIZE"])
assert _world_size == 8, "you must run this script with world size 8 to reproduce the bug"
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
class FFN(torch.nn.Module):

    def __init__(self, dim, inter_dim):
        super().__init__()

        self.w1 = torch.nn.Linear(dim, inter_dim)
        self.w2 = torch.nn.Linear(inter_dim, dim)
        self.w3 = torch.nn.Linear(dim, inter_dim)

    def forward(self, x) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class VeryLargeFFN(torch.nn.Module):
    def __init__(self, num_layers, dim, inter_dim):
        super().__init__()
        ffns = {}
        for i in range(num_layers):
            ffns[str(i)] = FFN(dim, inter_dim)
        self.ffns = torch.nn.ModuleDict(ffns)

    def forward(self, x, show_wrong_memory_stats: bool = False) -> torch.Tensor:
        for block in self.ffns.values():
            if dist.get_rank() == 0 and show_wrong_memory_stats:
                stat = torch.cuda.memory_stats()
                # Note: these are the *current* values reported by the caching allocator, not peaks.
                active_cur = stat.get("active_bytes.all.current", 0) / (1024 * 1024 * 1024)
                alloc_cur = stat.get("allocated_bytes.all.current", 0) / (1024 * 1024 * 1024)
                reserved_cur = stat.get("reserved_bytes.all.current", 0) / (1024 * 1024 * 1024)
                print(f"active: {active_cur:.2f}GB, allocated: {alloc_cur:.2f}GB, reserved: {reserved_cur:.2f}GB")
            # print(cur_alloc)
            x = block(x)
        return x

def fsdp_mod(net: VeryLargeFFN, mesh: DeviceMesh, reshard: int):
    # Map the integer argument onto fully_shard's reshard_after_forward:
    #   -1 -> True (reshard to the full mesh after forward)
    #    0 -> False (keep parameters unsharded after forward)
    #   >0 -> reshard to a smaller mesh of that size after forward
    full_shard: bool | int = reshard == -1
    if reshard > 0:
        full_shard = reshard
    mixed_fsdp2 = MixedPrecisionPolicy(reduce_dtype=torch.float32, param_dtype=torch.bfloat16, cast_forward_inputs=False)
    for block in net.ffns.values():
        fully_shard(block, mesh=mesh, reshard_after_forward=full_shard, mp_policy=mixed_fsdp2)
    fully_shard(net, mesh=mesh, reshard_after_forward=full_shard, mp_policy=mixed_fsdp2)

mod = VeryLargeFFN(32, 2048, 8192).cuda().eval().to(torch.bfloat16)

# fsdp_mod(mod, device_mesh, 0) # if we use 8GPUs with no reshard, no problem
fsdp_mod(mod, device_mesh, 4) # if we use 8GPUs with 4 reshard, OOM happens

for i in range(2):
    sample_inp = torch.randn(64, 16384, 2048).cuda().to(torch.bfloat16)
    if dist.get_rank() == 0:
        print(f"-----i={i}-----")   
    with torch.no_grad():
        mod(sample_inp, show_wrong_memory_stats=True)
    torch.cuda.synchronize()

# print(torch.cuda.memory_summary())
dist.barrier()
dist.destroy_process_group()
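For reference, the repro assumes a single-node launch with 8 GPUs, e.g. `torchrun --standalone --nproc_per_node=8 repro.py` (the script filename is arbitrary; the launch command is not part of the original report).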

Message:

You need to use `watch nvidia-smi` to check memory usage; torch.cuda.memory_* does not show the growth. No traceback is needed.
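As a side note, device-level memory can also be polled from inside the script with torch.cuda.mem_get_info, which wraps cudaMemGetInfo and therefore also reflects NCCL buffers and other allocations the caching allocator does not track. The helper below is a sketch (not part of the original repro) that could be called on rank 0 inside the inference loop:

import torch

def report_memory(tag: str) -> None:
    # Device-level view: includes NCCL communicators, CUDA context, etc.
    free_b, total_b = torch.cuda.mem_get_info()
    device_used = (total_b - free_b) / 1024**3
    # Allocator-level view: only what the CUDA caching allocator manages.
    reserved = torch.cuda.memory_reserved() / 1024**3
    allocated = torch.cuda.memory_allocated() / 1024**3
    print(f"[{tag}] device used: {device_used:.2f}GB, "
          f"reserved: {reserved:.2f}GB, allocated: {allocated:.2f}GB")

The gap between "device used" and "reserved" approximates the memory that nvidia-smi reports but torch.cuda.memory_* does not.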

Versions

Reproduced on both 2.6.0 and 2.7.0.dev20250212+cu124.

Versions of relevant libraries:
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pytorch-triton==3.2.0+git4b3bb1f8
[pip3] torch==2.7.0.dev20250212+cu124
[pip3] torchaudio==2.6.0.dev20250212+cu124
[pip3] torchvision==0.22.0.dev20250212+cu124
[pip3] triton==3.1.0
[conda] numpy                     2.1.2                    pypi_0    pypi
[conda] nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
[conda] nvidia-cusparselt-cu12    0.6.2                    pypi_0    pypi
[conda] nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
[conda] pytorch-triton            3.2.0+git4b3bb1f8          pypi_0    pypi
[conda] torch                     2.7.0.dev20250212+cu124          pypi_0    pypi
[conda] torchaudio                2.6.0.dev20250212+cu124          pypi_0    pypi
[conda] torchvision               0.22.0.dev20250212+cu124          pypi_0    pypi
[conda] triton                    3.1.0                    pypi_0    pypi

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @zhaojuanmao @mrshenli @rohan-varma @chauhang @mori360 @kwen2501 @c-p-i-o

@awgu
Collaborator
awgu commented Feb 14, 2025

cc: @weifengpy would you have time to check the memory snapshot on this? I would not be too surprised if there is a memory bug in the reshard_after_forward: int path for multiple forward passes.
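For reference, one way to capture such a snapshot is PyTorch's memory-snapshot tooling; the sketch below is illustrative (run_repro() stands in for the two inference steps of the repro script, and the output filename is arbitrary):

import torch
import torch.distributed as dist

# Start recording caching-allocator events.
torch.cuda.memory._record_memory_history(max_entries=100_000)
try:
    run_repro()  # placeholder for the inference loop in the repro above
finally:
    # Dump a per-rank snapshot; it can be inspected at https://pytorch.org/memory_viz
    torch.cuda.memory._dump_snapshot(f"fsdp2_reshard_rank{dist.get_rank()}.pickle")
    torch.cuda.memory._record_memory_history(enabled=None)

Note that this only records allocations made through the caching allocator, so NCCL-side memory will not appear in the snapshot.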

@jbschlosser jbschlosser added the oncall: distributed and module: fsdp labels Feb 14, 2025
@weifengpy
Contributor

cc: @weifengpy would you have time to check the memory snapshot on this? I would not be too surprised if there is a memory bug in the reshard_after_forward: int path for multiple forward passes.

taking a look

@weifengpy
Contributor

would you have time to check the memory snapshot on this?

I dumped the memory snapshot at OOM. I can see active memory stabilize at 60GB (snapshot 1) and reserved memory stabilize at 82GB (snapshot 2), but nvidia-smi reported 96GB of memory usage. The OOM seems to be coming from NCCL memory space or extra PGs created for communicating on the new device mesh of size 4, but those allocations are not tracked by the PyTorch caching allocator. @awgu do you happen to know how to track them?

[Snapshot 1: active memory over time. Snapshot 2: reserved memory over time.]

@awgu
Collaborator
awgu commented Feb 17, 2025

@weifengpy one idea would be to run a separate program that constructs the same PGs but without much of the activation memory, so that you can pin the amount of PyTorch-managed memory to an exact expected value. I am not sure if/how much NCCL does lazy channel creation, which would lead to buffers being allocated lazily.

That being said, I think it is already likely that NCCL is the one using the extra memory. Intra-node communication requires more NCCL memory, so the issue here could well be coming from the extra PGs.
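A minimal version of that experiment might look like the sketch below: it creates the same kind of size-4 subgroups that the reshard_after_forward=4 path implies (one pair per layer), forces their NCCL communicators to initialize with a tiny collective, and reports device-level memory growth. NUM_GROUPS and the rank split are illustrative assumptions, not taken from FSDP2 internals:

import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

def device_used_gb() -> float:
    free_b, total_b = torch.cuda.mem_get_info()
    return (total_b - free_b) / 1024**3

NUM_GROUPS = 32  # one per FFN block in the repro above
baseline = device_used_gb()
for _ in range(NUM_GROUPS):
    # new_group must be called by all ranks with the same arguments, so every
    # rank creates both halves and keeps the one it belongs to; each call
    # returns a distinct ProcessGroup.
    g_lo = dist.new_group(ranks=list(range(0, 4)))
    g_hi = dist.new_group(ranks=list(range(4, 8)))
    g = g_lo if rank < 4 else g_hi
    # A tiny collective forces lazy NCCL communicator initialization.
    dist.all_reduce(torch.ones(1, device="cuda"), group=g)

torch.cuda.synchronize()
if rank == 0:
    print(f"device memory: {baseline:.2f}GB -> {device_used_gb():.2f}GB "
          f"after initializing {NUM_GROUPS} pairs of size-4 subgroups")
dist.destroy_process_group()

If the per-group growth here matches the unexplained gap in the repro, that would point at the subgroup PGs rather than activation memory.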

@fduwjj fduwjj added the triaged label Apr 23, 2025
@weifengpy
Contributor

Finally got the root cause: each fully_shard(reshard_after_forward=int) creates a new PG. Many layers create many PGs, and each PG takes 1GB of memory.

The short-term workaround is to share the post-forward mesh:

for block in net.ffns.values():
    fully_shard(block, mesh=mesh, reshard_after_forward=full_shard, mp_policy=mixed_fsdp2)
fully_shard(net, mesh=mesh, reshard_after_forward=full_shard, mp_policy=mixed_fsdp2)

post_forward_device_mesh_info = net._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info
for block in net.ffns.values():
    block._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info = post_forward_device_mesh_info
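A quick sanity check along these lines (a sketch reusing the same private FSDP2 attributes as above) confirms that every block now reshards to the same post-forward mesh:

mesh_infos = {
    id(block._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info)
    for block in net.ffns.values()
}
assert len(mesh_infos) == 1, "all blocks should share one post-forward mesh (and one PG)"

With a single shared post-forward mesh, the layers reuse one size-4 PG instead of each creating their own, which removes the per-layer 1GB of extra PG memory described above.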

will figure out a long-term solution in #153302
