[FSDP2] avoid GPU OOM for reshard_after_forward=int with shared post_forward_mesh #153302
Open
@weifengpy

Description

🐛 Describe the bug

Got 2 OSS issue reports about GPU OOM when using reshard_after_forward=int

We are creating a new device mesh for every fully_shard call. Each device mesh creates a new PG, and each PG takes 1GB of GPU memory. The problematic lines are:

post_forward_mesh = DeviceMesh(
    mesh_info.mesh.device_type, post_forward_mesh_tensor
)
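
One way to see the overhead (a diagnostic sketch, not from the issue): NCCL communicator buffers are allocated outside PyTorch's caching allocator, so they do not show up in torch.cuda.memory_allocated() but do reduce the free memory reported by torch.cuda.mem_get_info(). Comparing free device memory before the fully_shard calls and after one forward/backward step (NCCL communicators may only be created at the first collective) makes the per-layer PG cost visible.

import torch

def free_gib(device: str) -> float:
    """Free device memory in GiB as reported by cudaMemGetInfo (driver-level,
    so it includes NCCL buffers that the caching allocator does not track)."""
    free_bytes, _total_bytes = torch.cuda.mem_get_info(torch.device(device))
    return free_bytes / 2**30

# Usage inside the repro below, e.g.:
#   before = free_gib(device)
#   ... fully_shard calls + one forward/backward ...
#   print(f"free memory dropped by {before - free_gib(device):.2f} GiB")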

A short-term workaround is

# reuse the first layer's post_forward_mesh_info for every layer so only one PG is created
post_forward_device_mesh_info = model[0]._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info
for layer in model:
    layer._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info = post_forward_device_mesh_info
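
Note that this workaround touches private FSDP2 state (_get_fsdp_state()._fsdp_param_group), so it may break across releases, and, as in the repro below, it needs to run after all fully_shard calls and before the first forward pass.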

Long-term fixes under discussion:

  • require the user to init a post-forward mesh and pass it to fully_shard(post_forward_mesh=) (see the sketch after this list)
  • reuse the device mesh / PG inside DeviceMesh
    # We temporarily revert the re-use subgroup, since it breaks two internal tests.
    # Temporarily reverting to resolve test timeout while root-causing.
    # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
    if bound_device_id is None or not has_split_group:
        dim_group = new_group(
            ranks=subgroup_ranks,
            backend=backend,
            pg_options=pg_options,
            group_desc=group_desc,
        )
  • harden the PG itself to merge smartly; see [c10d] ProcessGroupNCCL cuda streams got merged in nightly #153296
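
For the first option, a user-side sketch might look like the following. fully_shard has no post_forward_mesh parameter today; the keyword and the 2x2 mesh shape are assumptions for illustration, chosen to match the 4 ranks and reshard_after_forward=2 of the repro below.

from torch.distributed.device_mesh import init_device_mesh

# 4 ranks arranged as 2 x 2; the inner dim of size 2 matches reshard_after_forward=2
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("outer", "post_forward"))
shared_post_forward_mesh = mesh_2d["post_forward"]  # one sub-mesh / PG, created once

# Hypothetical keyword, shown only to illustrate the proposal:
# for layer in model:
#     fully_shard(layer, reshard_after_forward=2,
#                 post_forward_mesh=shared_post_forward_mesh)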

repro: torchrun --nproc-per-node 4 test_reshard_after_forward.py

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard

def main():
    dist.init_process_group(backend="nccl")
    gpu_id = int(os.environ["LOCAL_RANK"])
    device = f"cuda:{gpu_id}"
    torch.cuda.set_device(device)
    model = nn.Sequential(
        nn.Linear(10, 10), 
        nn.Linear(10, 10),
        nn.Linear(10, 10),
        nn.Linear(10, 10),
        nn.Linear(10, 10),
    )
    for layer in model:
        fully_shard(layer, reshard_after_forward=2)
    fully_shard(model, reshard_after_forward=2)

    # share PG to reduce memory usage
    post_forward_device_mesh_info = model[0]._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info
    for layer in model:
        layer._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info = post_forward_device_mesh_info
    
    x = torch.rand(1, 10, device=device)
    model(x).sum().backward()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
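
A quick sanity check (my addition, not part of the original repro) can be dropped into main() right after the workaround loop; it uses the same private attributes the workaround already relies on:

    # every layer should now point at the same post_forward_mesh_info object,
    # so only one shared post-forward mesh / PG is used
    infos = {
        id(layer._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info)
        for layer in model
    }
    assert len(infos) == 1, "layers do not share a post-forward mesh"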

Versions

skip

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

Labels

oncall: distributed, triaged
