[FSDP2] avoid GPU OOM for reshard_after_forward=int with shared post_forward_mesh #153302
Labels
oncall: distributed
triaged
🐛 Describe the bug
Got 2 OSS issues around GPU OOM for `reshard_after_forward=int`; one is "`reshard_after_forward` smaller than DP size" (#147179). We create a new device mesh for every `fully_shard` call. Each device mesh creates a new PG, and each PG takes ~1GB of GPU memory. The problematic lines are:
pytorch/torch/distributed/fsdp/_fully_shard/_fsdp_init.py
Lines 57 to 59 in 658aea9
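For context, a minimal sketch of the pattern that hits this (the toy model and sizes are assumptions, not from the issue): on 4 GPUs with `reshard_after_forward=2`, every `fully_shard` call builds its own smaller post-forward mesh (roughly a (2, 2) mesh here).

```python
import torch.nn as nn
from torch.distributed.fsdp import fully_shard

# Hypothetical toy model: one fully_shard call per layer.
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(32)])

for layer in model:
    # With world size 4 and reshard_after_forward=2, each call here
    # creates its own post-forward device mesh; every new mesh creates
    # a new PG, and each PG takes ~1GB of GPU memory, so memory grows
    # with the number of fully_shard calls.
    fully_shard(layer, reshard_after_forward=2)
fully_shard(model, reshard_after_forward=2)
```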
A short-term workaround is `fully_shard(post_forward_mesh=)` with a post-forward mesh shared across calls. Discussing the long-term fix here:
pytorch/torch/distributed/device_mesh.py
Lines 589 to 598 in 658aea9
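A sketch of that workaround, assuming `post_forward_mesh=` accepts a pre-built mesh as the issue text suggests (the parameter name is taken from this issue and may not exist under that spelling in a given release; mesh dim names and sizes are assumptions):

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

# Build the post-forward mesh once so every fully_shard call reuses the
# same PGs instead of allocating a fresh ~1GB PG per call.
world_size, reshard_size = 4, 2  # assumed 4-GPU run, reshard_after_forward=2
post_forward_mesh = init_device_mesh(
    "cuda",
    (world_size // reshard_size, reshard_size),
    mesh_dim_names=("replicate", "shard"),
)

model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(32)])
for layer in model:
    # post_forward_mesh= as named in this issue; exact signature may
    # differ across PyTorch versions.
    fully_shard(layer, post_forward_mesh=post_forward_mesh)
fully_shard(model, post_forward_mesh=post_forward_mesh)
```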
Repro: `torchrun --nproc-per-node 4 test_reshard_after_forward.py`
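The repro script itself isn't inlined above; below is a hypothetical reconstruction of what `test_reshard_after_forward.py` likely exercises (file contents, model, and sizes are assumptions, not the original):

```python
# test_reshard_after_forward.py -- hypothetical reconstruction of the
# repro: many fully_shard calls with reshard_after_forward=int, each
# allocating its own post-forward mesh/PG.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard


def main() -> None:
    rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl")

    model = nn.Sequential(*[nn.Linear(2048, 2048) for _ in range(64)]).cuda()
    for layer in model:
        # reshard_after_forward=2 < world size 4 -> a new post-forward
        # mesh (and PG) per call.
        fully_shard(layer, reshard_after_forward=2)
    fully_shard(model, reshard_after_forward=2)

    if rank == 0:
        # NCCL communicator buffers live outside the caching allocator,
        # so check free device memory directly.
        free, total = torch.cuda.mem_get_info()
        print(f"free: {free / 1e9:.1f} GB / total: {total / 1e9:.1f} GB")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```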
Versions
skip
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k