[FSDP2] avoid GPU OOM for reshard_after_forward=int with shared post_forward_mesh #153302

Open
weifengpy opened this issue May 10, 2025 · 4 comments
Labels
oncall: distributed Add this issue/PR to distributed oncall triage queue triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@weifengpy
Contributor
weifengpy commented May 10, 2025

🐛 Describe the bug

We got 2 OSS issue reports about GPU OOM for reshard_after_forward=int.

We are creating a new device mesh for every fully_shard call. Each device mesh creates a new PG, and each PG takes about 1 GB of GPU memory. The problematic lines are:

post_forward_mesh = DeviceMesh(
    mesh_info.mesh.device_type, post_forward_mesh_tensor
)
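
To make that memory cost concrete, here is a minimal sketch (not from the issue; the exact per-PG cost depends on the NCCL version and buffer configuration) that watches free GPU memory around the creation of one extra ProcessGroup. Run it with torchrun --nproc-per-node 2; the filename is arbitrary.

import os

import torch
import torch.distributed as dist


def free_gib() -> float:
    # Free device memory as reported by the CUDA driver, in GiB.
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    return free_bytes / 2**30


def main():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    before = free_gib()
    group = dist.new_group(ranks=list(range(dist.get_world_size())))
    # NCCL communicators are typically created lazily, so run one collective
    # on the new group to force initialization before measuring again.
    dist.all_reduce(torch.ones(1, device="cuda"), group=group)
    torch.cuda.synchronize()
    print(f"[rank {dist.get_rank()}] free GiB before={before:.2f} after={free_gib():.2f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()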

A short-term workaround is

post_forward_device_mesh_info = model[0]._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info
for layer in model:
    layer._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info = post_forward_device_mesh_info

Discussing long-term fixes here:

  • require the user to init the post-forward mesh and pass it to fully_shard(post_forward_mesh=) (see the sketch after this list)
  • reuse the device mesh / PG inside DeviceMesh:
    # We temporarily revert the re-use subgroup, since it breaks two internal tests.
    # Temporarily reverting to resolve test timeout while root-causing.
    # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
    if bound_device_id is None or not has_split_group:
        dim_group = new_group(
            ranks=subgroup_ranks,
            backend=backend,
            pg_options=pg_options,
            group_desc=group_desc,
        )
  • harden the PG itself so it merges smartly; see [c10d] ProcessGroupNCCL cuda streams got merged in nightly #153296
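
For option 1, a rough sketch of what the user-facing side could look like. Note that post_forward_mesh= is only a proposed argument, not an existing fully_shard parameter; the 2x2 mesh shape is illustrative for 4 ranks with reshard_after_forward=2; model is the nn.Sequential from the repro below; and how the proposed argument would interact with reshard_after_forward is an open design question.

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

# Build the post-forward mesh once and share it across every fully_shard call
# instead of letting each call construct its own mesh (and PG).
shared_post_forward_mesh = init_device_mesh("cuda", (2, 2))
for layer in model:
    fully_shard(layer, reshard_after_forward=2,
                post_forward_mesh=shared_post_forward_mesh)  # proposed argument
fully_shard(model, reshard_after_forward=2,
            post_forward_mesh=shared_post_forward_mesh)  # proposed argument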

repro: torchrun --nproc-per-node 4 test_reshard_after_forward.py

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard

def main():
    dist.init_process_group(backend="nccl")
    gpu_id = int(os.environ["LOCAL_RANK"])
    device = f"cuda:{gpu_id}"
    torch.cuda.set_device(device)
    model = nn.Sequential(
        nn.Linear(10, 10), 
        nn.Linear(10, 10),
        nn.Linear(10, 10),
        nn.Linear(10, 10),
        nn.Linear(10, 10),
    )
    for layer in model:
        fully_shard(layer, reshard_after_forward=2)
    fully_shard(model, reshard_after_forward=2)

    # share PG to reduce memory usage
    post_forward_device_mesh_info = model[0]._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info
    for layer in model:
        layer._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info = post_forward_device_mesh_info
    
    x = torch.rand(1, 10, device=device)
    model(x).sum().backward()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()

Versions

skip

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

@weifengpy
Contributor Author

cc @awgu for preferences, fully_shard(post_forward_mesh=) vs reuse subgroup in DeviceMesh

@awgu
Collaborator
awgu commented May 10, 2025

I do not have a strong preference between passing a mesh vs. having DeviceMesh try to reuse smartly. The former is more explicit. I personally am not a big fan of the reshard_after_forward: int code path to begin with 🤔

@colesbury colesbury added the oncall: distributed Add this issue/PR to distributed oncall triage queue label May 13, 2025
@fegin
Contributor
fegin commented May 13, 2025

PG reuse for DeviceMesh was on the plan last year, afaik. We need to be careful about users who do not want to reuse the PG; otherwise, it should also be beneficial for other use cases.
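
As an illustration only of the PG-reuse idea (not DeviceMesh's actual implementation), subgroups could be cached by their rank tuple so that meshes built from identical ranks share one ProcessGroup instead of allocating a new NCCL communicator each time; get_or_create_subgroup below is a hypothetical helper.

from typing import Dict, Tuple

import torch.distributed as dist

_subgroup_cache: Dict[Tuple[int, ...], dist.ProcessGroup] = {}


def get_or_create_subgroup(ranks, backend=None) -> dist.ProcessGroup:
    # Key on the sorted ranks so identical subgroups map to the same entry.
    key = tuple(sorted(ranks))
    if key not in _subgroup_cache:
        # new_group is a collective: every rank must call it, in the same order.
        _subgroup_cache[key] = dist.new_group(ranks=list(key), backend=backend)
    return _subgroup_cache[key]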

@weifengpy
Contributor Author

Talked with @wz337 , she is open to PG reuse in device mesh

@weifengpy weifengpy added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label May 13, 2025