Description
🐛 Describe the bug
We got 2 OSS issues about GPU OOM when using `reshard_after_forward=int`:
- [FSDP2] OOM when use integer reshard_after_forward that smaller than DP size #147179
- reshard_after_forward does not work as expected in FSDP2 #149029
The root cause: we create a new DeviceMesh for every `fully_shard` call. Each DeviceMesh creates a new PG, and each PG takes 1GB of memory. The problematic lines are:
pytorch/torch/distributed/fsdp/_fully_shard/_fsdp_init.py, lines 57 to 59 in 658aea9:

```python
post_forward_mesh = DeviceMesh(
    mesh_info.mesh.device_type, post_forward_mesh_tensor
)
```
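To illustrate the cost, here is a simplified sketch (my own, not code from the report): as of the quoted commit, building a DeviceMesh from a raw rank tensor does not reuse an existing PG with the same ranks, so every wrapped module pays for fresh PGs.

```python
# Simplified sketch -- assumes torch.distributed is already initialized with NCCL
# and an even world size. As of commit 658aea9, each DeviceMesh built from a rank
# tensor initializes its own ProcessGroup per mesh dim instead of reusing an
# existing group with the same ranks.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh

ranks = torch.arange(dist.get_world_size()).view(-1, 2)
mesh_a = DeviceMesh("cuda", ranks)  # creates new PGs for both mesh dims
mesh_b = DeviceMesh("cuda", ranks)  # same ranks, but creates new PGs again
```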
A short-term workaround is:

```python
post_forward_device_mesh_info = model[0]._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info
for layer in model:
    layer._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info = post_forward_device_mesh_info
```
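A minimal sketch generalizing the same workaround to any FSDP2-wrapped model (the helper name and the guard checks are my additions, not from the report): it copies one module's `post_forward_mesh_info` to every other wrapped module so they all share the same post-forward PG.

```python
def share_post_forward_mesh(model):
    # Hypothetical helper, not a PyTorch API: reuse one post_forward_mesh_info
    # across all fully_shard-wrapped submodules so only one post-forward PG exists.
    shared = None
    for module in model.modules():
        get_state = getattr(module, "_get_fsdp_state", None)
        if get_state is None:
            continue  # not wrapped by fully_shard
        group = get_state()._fsdp_param_group
        if group is None or group.post_forward_mesh_info is None:
            continue  # no managed params or no post-forward resharding
        if shared is None:
            shared = group.post_forward_mesh_info  # first wrapped module wins
        else:
            group.post_forward_mesh_info = shared
```

Call it after all `fully_shard(...)` calls and before the first forward, like the sharing loop in the repro below.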
Discussing the long-term fix here; options:
- require the user to init the post-forward mesh and pass it to `fully_shard(post_forward_mesh=)` (see the sketch after this list)
- reuse the device mesh / PG inside DeviceMesh:
pytorch/torch/distributed/device_mesh.py, lines 589 to 598 in 658aea9:

```python
# We temporarily revert the re-use subgroup, since it breaks two internal tests.
# Temporarily reverting to resolve test timeout while root-causing.
# TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
if bound_device_id is None or not has_split_group:
    dim_group = new_group(
        ranks=subgroup_ranks,
        backend=backend,
        pg_options=pg_options,
        group_desc=group_desc,
    )
```

- harden the PG itself to merge smartly; see [c10d] ProcessGroupNCCL cuda streams got merged in nightly #153296
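For the first option, a rough caller-side sketch of what the API could look like. The `post_forward_mesh=` kwarg is only a proposal and does not exist in `fully_shard` today, the mesh layout is my assumption about what FSDP2 would expect, and `model` refers to the module being wrapped as in the repro below.

```python
# Hypothetical API sketch -- `post_forward_mesh=` is a proposed kwarg, not a real
# fully_shard argument today. The idea: the user builds ONE post-forward mesh up
# front and every fully_shard call reuses it, so only one extra PG is created.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

reshard_size = 2  # the integer reshard_after_forward value
world_size = dist.get_world_size()
# Assumed layout: outer dim = replicate-after-forward, inner dim = post-forward shard.
post_forward_mesh = init_device_mesh("cuda", (world_size // reshard_size, reshard_size))
for layer in model:
    fully_shard(
        layer,
        reshard_after_forward=reshard_size,
        post_forward_mesh=post_forward_mesh,  # hypothetical kwarg
    )
```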
Repro: `torchrun --nproc-per-node 4 test_reshard_after_forward.py`
```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard


def main():
    dist.init_process_group(backend="nccl")
    gpu_id = int(os.environ["LOCAL_RANK"])
    device = f"cuda:{gpu_id}"
    torch.cuda.set_device(device)
    model = nn.Sequential(
        nn.Linear(10, 10),
        nn.Linear(10, 10),
        nn.Linear(10, 10),
        nn.Linear(10, 10),
        nn.Linear(10, 10),
    )
    for layer in model:
        fully_shard(layer, reshard_after_forward=2)
    fully_shard(model, reshard_after_forward=2)

    # share PG to reduce memory usage
    post_forward_device_mesh_info = model[0]._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info
    for layer in model:
        layer._get_fsdp_state()._fsdp_param_group.post_forward_mesh_info = post_forward_device_mesh_info

    x = torch.rand(1, 10, device=device)
    model(x).sum().backward()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
```
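To check whether the PG sharing actually saves memory (this measurement is my suggestion, not part of the original report), one can compare free device memory around the first forward/backward inside `main()`. `torch.cuda.mem_get_info()` queries the CUDA driver, so it also captures NCCL communicator buffers that bypass the caching allocator.

```python
# Drop-in replacement for the forward/backward lines in main() above; `device`,
# `model`, and `gpu_id` come from main(). NCCL communicators may be created
# lazily, so measure around the first step.
free_before, _ = torch.cuda.mem_get_info()
x = torch.rand(1, 10, device=device)
model(x).sum().backward()
torch.cuda.synchronize()
free_after, _ = torch.cuda.mem_get_info()
print(f"rank {gpu_id}: device memory used by first step (incl. comms): "
      f"{(free_before - free_after) / 1e9:.2f} GB")
```

Running the repro once with and once without the sharing loop should show the difference coming from the extra post-forward PGs.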
Versions
skip
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k