[FSDP2] OOM when using an integer reshard_after_forward smaller than DP size
#147179
Comments
cc: @weifengpy would you have time to check the memory snapshot on this? I would not be too surprised if there is a memory bug with this.
Taking a look.
I dumped the memory snapshot at OOM. I can see active memory stabilize at 60 GB (snapshot 1) and reserved memory stabilize at 82 GB (snapshot 2). [memory snapshot screenshots omitted]
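For reference, here is a minimal sketch of how such an allocator snapshot can be captured with PyTorch's memory-history tooling; run_inference is a hypothetical stand-in for the workload that hits the OOM, and the resulting pickle can be inspected at https://pytorch.org/memory_viz. This is not the exact command sequence used above.

```python
# Minimal sketch: record CUDA allocator history around the failing workload
# and dump a snapshot that can be opened in the memory_viz tool.
import torch

torch.cuda.memory._record_memory_history(max_entries=100_000)
try:
    run_inference()  # hypothetical: the inference loop that reproduces the OOM
finally:
    torch.cuda.memory._dump_snapshot("oom_snapshot.pickle")
    torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```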
@weifengpy one idea would be to run a separate program with the same PGs constructed but without much of the activation memory (so you can control the amount of PyTorch memory to be an exact, expected amount). I am not sure if/how much NCCL does lazy channel creation, leading to lazy buffer allocation. That being said, I already think it is likely that NCCL is the one using the extra memory. Intra-node requires more NCCL memory. The issue here could then likely be from the PGs.
Finally got the root cause: each fully_shard(reshard_after_forward=int) creates a new PG. Many layers create many PGs, and each PG takes ~1 GB of memory. A short-term workaround is sharing the post-forward mesh;
will figure out a long-term solution in #153302.
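One quick way to confirm this on a given setup is to count live process groups before and after wrapping. Note that _world is an internal, unstable API, and wrap_model_with_fsdp2 below is a hypothetical stand-in for the per-layer fully_shard(..., reshard_after_forward=4) calls.

```python
# Rough diagnostic sketch: the process-group count should grow roughly with
# the number of layers wrapped using an integer reshard_after_forward.
from torch.distributed.distributed_c10d import _world  # internal API, may change

def count_process_groups() -> int:
    # One entry per process group created in this process (default PG included).
    return len(_world.pg_map)

before = count_process_groups()
wrap_model_with_fsdp2()  # hypothetical: per-layer fully_shard(reshard_after_forward=4)
after = count_process_groups()
print(f"live process groups: {before} -> {after}")
```

With roughly 1 GB of NCCL buffers per extra group, a model with dozens of such layers can plausibly account for the tens of gigabytes of unexplained memory seen in the snapshots above.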
🐛 Describe the bug
When we use an FSDP2 module to do inference only with reshard_after_forward set, we found that if we use reshard_after_forward=True or reshard_after_forward=False, FSDP2 works fine, but if we use an integer reshard_after_forward=4 with world_size=8, an OOM happens in the second step of inference. torch.cuda.memory_* also shows wrong memory stats during the second inference step.

Code:
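The reporter's script is not included in this excerpt; the following is only a minimal sketch of the configuration described above (hypothetical model, sizes, and batch shape), intended to be launched with torchrun --nproc_per_node=8.

```python
# Sketch of the reported setup: per-layer FSDP2 wrapping with an integer
# reshard_after_forward=4 while world_size=8, followed by two inference steps.
import os
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (8,))  # world_size=8 as in the report

model = nn.Sequential(*[nn.Linear(4096, 4096, device="cuda") for _ in range(32)])
for layer in model:
    # Integer reshard_after_forward: reshard to a world size of 4 after forward.
    fully_shard(layer, mesh=mesh, reshard_after_forward=4)
fully_shard(model, mesh=mesh, reshard_after_forward=4)

x = torch.randn(16, 4096, device="cuda")
with torch.no_grad():
    for step in range(2):  # the OOM reportedly appears on the second step
        model(x)
        torch.cuda.synchronize()
        print(
            f"step {step}: allocated={torch.cuda.memory_allocated() / 2**30:.1f} GiB, "
            f"reserved={torch.cuda.memory_reserved() / 2**30:.1f} GiB"
        )
```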
Message:
You need to use watch nvidia-smi to check memory usage; torch.cuda.memory_* does not work. No traceback is needed.

Versions
Both 2.6.0 and 2.7.0.dev20250212+cu124.
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @zhaojuanmao @mrshenli @rohan-varma @chauhang @mori360 @kwen2501 @c-p-i-o