
FSDP state dict OOM during model saving #98823


Closed
wanchaol opened this issue Apr 11, 2023 · 23 comments
Assignees
Labels
module: fsdp oncall: distributed Add this issue/PR to distributed oncall triage queue triaged This issue has been looked at by a team member and triaged and prioritized into an appropriate module

Comments

@wanchaol
Collaborator
wanchaol commented Apr 11, 2023

🐛 Describe the bug

See related user-reported issues in tatsu-lab/stanford_alpaca#81 and lm-sys/FastChat#256.

A workaround that the community is applying is:

Assuming you are using torch==1.13.0, change python/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:2224 from state_dict[fqn] = state_dict[fqn].clone().detach() to state_dict[fqn] = state_dict[fqn].cpu().clone().detach()
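
In diff form (fully_sharded_data_parallel.py, line 2224 in torch 1.13.0):

- state_dict[fqn] = state_dict[fqn].clone().detach()
+ state_dict[fqn] = state_dict[fqn].cpu().clone().detach()

Moving each tensor to CPU before cloning keeps the gathered full parameters from accumulating in GPU memory.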

This is pretty manual monkey patching, and we should really fix this in PyTorch directly.

@fegin @awgu @rohan-varma @zhaojuanmao

Versions

This has happened since PyTorch 1.13, and I don't think we have fixed it yet.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu

@wanchaol wanchaol added oncall: distributed Add this issue/PR to distributed oncall triage queue module: fsdp labels Apr 11, 2023
@fegin
Contributor
fegin commented Apr 12, 2023

The recommended solution is to turn on cpu_offload for state_dict. An example can be found at https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type
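
A minimal sketch of that approach (assuming model is already FSDP-wrapped, a process group is initialized, and torch >= 2.0; the checkpoint path is a placeholder):

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

# Offload each gathered tensor to CPU as it is unsharded, and only
# materialize the full state dict on rank 0, to avoid OOM.
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    cpu_state = model.state_dict()

if dist.get_rank() == 0:
    torch.save(cpu_state, "model_checkpoint.pt")  # placeholder path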

@awgu awgu added the triaged This issue has been looked at by a team member and triaged and prioritized into an appropriate module label Apr 12, 2023
@awgu
Collaborator
awgu commented Apr 24, 2023

I am closing this given @fegin's comment.

@alanxmay

@fegin @awgu I tried using:

import torch.distributed.fsdp as FSDP
from torch.distributed.fsdp import LocalStateDictConfig, StateDictType
from torch.distributed.fsdp.api import LocalOptimStateDictConfig

with FSDP.FullyShardedDataParallel.state_dict_type(
        trainer.model,
        StateDictType.LOCAL_STATE_DICT,  # or any other StateDictType
        LocalStateDictConfig(offload_to_cpu=True),  # or without this line
        LocalOptimStateDictConfig(offload_to_cpu=True),  # or without this line
        ):
    state_dict = trainer.model.state_dict()

The program gets stuck in this state for a very long time:

[screenshot omitted]

It eventually times out.

I finally managed to save the model with Python 3.10 and torch==2.0 by changing line 309 of python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py from state_dict[fqn] = state_dict[fqn].clone().detach() to state_dict[fqn] = state_dict[fqn].cpu().clone().detach(). It works really well.
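
For reference, this is the same one-line change as above, at its torch 2.0 location (_state_dict_utils.py, line 309):

- state_dict[fqn] = state_dict[fqn].clone().detach()
+ state_dict[fqn] = state_dict[fqn].cpu().clone().detach()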

@laiqinghan

@alanxmay Many thanks to you! Absolutely awesome!!

@JulioZhao97

Still facing this problem; my process hangs forever.

@JulioZhao97

(quoting @alanxmay's comment above in full)

Can I ask how you called model.state_dict()? Did you use with FSDP.FullyShardedDataParallel.state_dict_type, or just change the line you mentioned?