🐛 Describe the bug
We got 2 OSS issues around GPU OOM for `reshard_after_forward=int`:
- [FSDP2] OOM when use integer reshard_after_forward that smaller than DP size (#147179)
- reshard_after_forward does not work as expected in FSDP2 (#149029)
We are creating a new device mesh for every `fully_shard` call. Each device mesh creates a new PG, and each PG takes ~1GB of GPU memory. The problematic lines are:
`pytorch/torch/distributed/fsdp/_fully_shard/_fsdp_init.py`, lines 57 to 59 at `658aea9`
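For context, a rough sketch of why an integer `reshard_after_forward` ends up allocating a new PG per layer. This is a conceptual illustration, not the exact code at the lines above; the helper name `build_post_forward_mesh` and the 2D `(replicate, shard)` layout are assumptions:

```python
# Conceptual sketch only: with reshard_after_forward=2 on 4 ranks, each
# fully_shard call builds a fresh 2x2 (replicate, shard) device mesh for the
# post-forward state. Every new DeviceMesh materializes its own process
# groups, and each NCCL PG holds roughly 1GB of GPU memory, so N layers
# means N extra PGs.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def build_post_forward_mesh(shard_size: int):
    # Hypothetical helper; assumes the default process group is initialized.
    world_size = dist.get_world_size()
    assert world_size % shard_size == 0
    return init_device_mesh(
        "cuda",
        (world_size // shard_size, shard_size),
        mesh_dim_names=("replicate", "shard"),
    )
```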
A short-term workaround is:

```python
post_forward_device_mesh_info = model[0]._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info
for layer in model:
    layer._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info = post_forward_device_mesh_info
```
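This makes every layer reuse the first layer's post-forward mesh info, so only one extra device mesh / PG is created instead of one per `fully_shard` call.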
Discussing long-term fixes here:
- require the user to init the post-forward mesh and pass it to `fully_shard(post_forward_mesh=)`
- reuse device mesh / PG inside `DeviceMesh` (`pytorch/torch/distributed/device_mesh.py`, lines 589 to 598 at `658aea9`); a rough sketch of this idea is below the list
- harden the PG itself to merge smartly; see [c10d] ProcessGroupNCCL cuda streams got merged in nightly (#153296)
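A very rough illustration of what the "reuse device mesh / PG inside `DeviceMesh`" option could look like; the cache layout and the function name here are assumptions for discussion, not the actual `DeviceMesh` internals:

```python
# Sketch only: key sub-meshes by (device_type, shape, ranks) so repeated
# fully_shard calls that ask for the same layout get back the same DeviceMesh
# (and therefore the same PGs) instead of allocating new NCCL communicators.
from typing import Dict, Tuple

import torch
from torch.distributed.device_mesh import DeviceMesh

_mesh_cache: Dict[Tuple[str, Tuple[int, ...], Tuple[int, ...]], DeviceMesh] = {}


def get_or_create_mesh(device_type: str, shape, ranks) -> DeviceMesh:
    key = (device_type, tuple(shape), tuple(ranks))
    if key not in _mesh_cache:
        mesh_tensor = torch.tensor(ranks, dtype=torch.int).view(*shape)
        _mesh_cache[key] = DeviceMesh(device_type, mesh_tensor)
    return _mesh_cache[key]
```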
Repro: `torchrun --nproc-per-node 4 test_reshard_after_forward.py`
```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard


def main():
    dist.init_process_group(backend="nccl")
    gpu_id = int(os.environ["LOCAL_RANK"])
    device = f"cuda:{gpu_id}"
    torch.cuda.set_device(device)
    model = nn.Sequential(
        nn.Linear(10, 10),
        nn.Linear(10, 10),
        nn.Linear(10, 10),
        nn.Linear(10, 10),
        nn.Linear(10, 10),
    )
    for layer in model:
        fully_shard(layer, reshard_after_forward=2)
    fully_shard(model, reshard_after_forward=2)

    # share PG to reduce memory usage
    post_forward_device_mesh_info = model[0]._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info
    for layer in model:
        layer._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info = post_forward_device_mesh_info

    x = torch.rand(1, 10, device=device)
    model(x).sum().backward()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
```
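To see the effect of the workaround in the repro, one option (my own suggestion, not part of the original report) is to compare driver-level free memory around the sharding setup, since NCCL communicator memory is not tracked by the caching allocator:

```python
import torch

# Assumes the CUDA device is already set (e.g., inside main() of the repro above).
free_before, _ = torch.cuda.mem_get_info()
# ... apply fully_shard to all layers here ...
free_after, _ = torch.cuda.mem_get_info()
print(f"device memory consumed by sharding setup: {(free_before - free_after) / 1e9:.2f} GB")
```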
Versions
skip
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k