FSDP state dict OOM during model saving #98823
The recommended solution is to turn on
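Presumably this refers to the `offload_to_cpu` and `rank0_only` options that appear in the fix shared later in this thread; a minimal sketch of that configuration (assuming `torch.distributed.fsdp` from PyTorch 2.0+):

```python
from torch.distributed.fsdp import FullStateDictConfig

# Offload gathered shards to CPU and materialize the full state dict
# only on rank 0, so no single GPU has to hold every unsharded parameter.
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
```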
I am closing this given @fegin's comment.
@alanxmay Many thanks to you! Awesome!
Still facing this problem, my process hangs forever.
@alanxmay I would like to understand this issue in more detail. The code you attached is
@JulioZhao97 Just modify the line
@fegin With the

```python
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
```

Without the

```python
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    with FSDP.FullyShardedDataParallel.state_dict_type(
        trainer.model,
        StateDictType.LOCAL_STATE_DICT,  # or any other StateDictType
        LocalStateDictConfig(offload_to_cpu=True),  # or without this line
        LocalOptimStateDictConfig(offload_to_cpu=True),  # or without this line
    ):
        state_dict = trainer.model.state_dict()
        if trainer.args.should_save:
            cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
            del state_dict
            trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
```
Thanks for sharing, I will try right away.
Finally, my issue was fixed. My simplified code is as follows:
It turns out that the cause of the hang is
I further tested the behavior of fsdp_model.state_dict() as follows:
And the output is this:
It seems that all other processes are blocked when state_dict() is called (or skip past the state_dict() call). This is the weirdest thing; can anyone enlighten me on this?
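One plausible reading, sketched below: under `FULL_STATE_DICT`, calling `state_dict()` on an FSDP-wrapped model triggers all-gathers, so it is effectively a collective that every rank must enter; if only one rank reaches the call, the others block it (or it blocks on them). The `fsdp_model` and file name here are placeholders, not code from this thread.

```python
import torch
import torch.distributed as dist

# Hang-prone: only rank 0 enters the collective; the other ranks never
# participate in the all-gather, so rank 0 waits forever.
# if dist.get_rank() == 0:
#     state_dict = fsdp_model.state_dict()

# Safe pattern: every rank calls state_dict(); only rank 0 writes to disk.
state_dict = fsdp_model.state_dict()
if dist.get_rank() == 0:
    torch.save(state_dict, "model.pt")
```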
@JulioZhao97 That is very interesting. Given the code you showed, I cannot understand how that can happen. I am not sure why a barrier in your previous code snippet would cause hangs either. I wonder if the two issues are related somehow.
@alanxmay Can you confirm that the entire model is wrapped by FSDP? Or are only parts of the model wrapped by FSDP?
My bad, I just noticed that there is a

Case closed: the hanging is because of the decorator
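The decorator itself is not shown above, but a hypothetical illustration of how a rank-0-only decorator around the saving routine produces exactly this hang (every name below is made up for illustration):

```python
import functools

import torch
import torch.distributed as dist


def rank0_only(fn):
    """Hypothetical decorator that runs fn on rank 0 and skips it elsewhere."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if dist.get_rank() == 0:
            return fn(*args, **kwargs)
        return None
    return wrapper


@rank0_only
def save_model(fsdp_model, path):
    # Under FULL_STATE_DICT this is a collective all-gather. Ranks != 0 never
    # reach it because of the decorator, so rank 0 blocks here indefinitely.
    state_dict = fsdp_model.state_dict()
    torch.save(state_dict, path)
```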
@fegin Thanks for your valuable suggestion, I will check later.
Did you apply the
I also applied the fix you mentioned, but did not dive in deeper.
origin:

```python
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
    cpu_state_dict = {
        key: value.cpu()
        for key, value in state_dict.items()
    }
    del state_dict
    trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
```

fixed:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

model = trainer.model
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    cpu_state_dict = model.state_dict()
if trainer.args.should_save:
    trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
```
@JACKHAHA363 Did you solve it via KooSung's method? I tried @alanxmay's approach but it also raised an OOM error:

```
File "/miniconda3/envs/LLAMA2/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1154, in _optim_state_dict_impl
```
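The traceback points into FSDP's optimizer state-dict path, so a model-only fix may not be enough if optimizer state is also gathered. A sketch of applying the same offload/rank0 policy to the optimizer state (assuming PyTorch 2.0+; `model` and `optimizer` are placeholders):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    StateDictType,
    FullStateDictConfig,
    FullOptimStateDictConfig,
)

model_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
optim_cfg = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)

with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, model_cfg, optim_cfg):
    model_state = model.state_dict()
    # Gathered according to optim_cfg: offloaded to CPU, kept only on rank 0.
    optim_state = FSDP.optim_state_dict(model, optimizer)
```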
Mark. Thanks @ntlm1686.
🐛 Describe the bug
See related user-reported issues in tatsu-lab/stanford_alpaca#81 and lm-sys/FastChat#256.
A workaround that the community is applying is:
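A sketch of the kind of patch being referred to, mirroring the fix shared in the comments above (`trainer` and `output_dir` come from the Hugging Face Trainer-based snippets in this thread):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(trainer.model, StateDictType.FULL_STATE_DICT, save_policy):
    cpu_state_dict = trainer.model.state_dict()
if trainer.args.should_save:
    trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
```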
This is pretty manual monkey-patching, and we should really fix this in PyTorch directly.
@fegin @awgu @rohan-varma @zhaojuanmao
Versions
This has happened since PyTorch 1.13, and I don't think we have fixed it so far.
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu