@awgu When enabling the reshard_after_forward flag, parameters appear to remain unsharded even after the forward pass completes. While this works as expected for simple networks, the text encoder module from HuggingFace Transformers exhibits a memory increase after forward propagation even within a torch.no_grad() context. Manually invoking reshard() after the forward pass reduces memory usage, suggesting that automatic resharding is not occurring as intended (a rough repro sketch follows the observations below).
Observations:
- Minimal example works: basic networks behave correctly with reshard_after_forward.
- Transformer text encoder fails: memory usage grows after forward passes in no_grad mode, implying parameters are retained in an unsharded state.
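Roughly, the check looks like the following sketch (the toy model, sizes, and launch setup are illustrative stand-ins, not the actual HF text encoder):

```python
# Hypothetical repro sketch; run with torchrun --nproc_per_node=<N>.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Toy stand-in for the text encoder, wrapped as a single FSDP unit (the root).
model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)]).cuda()
fully_shard(model, reshard_after_forward=True)

x = torch.randn(16, 4096, device="cuda")
with torch.no_grad():
    model(x)
print(f"after forward: {torch.cuda.memory_allocated() / 2**20:.0f} MiB")  # stays high

model.reshard()  # manually free the unsharded parameters
print(f"after reshard: {torch.cuda.memory_allocated() / 2**20:.0f} MiB")  # drops

dist.destroy_process_group()
```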
The main thing to keep in mind is that any FSDP root module has its reshard_after_forward: bool overridden to False. An FSDP root module is any FSDPModule that does not have an FSDPModule parent above it. This is a design choice inherited from FSDP1, but I recognize that it is confusing.
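For example (a hypothetical module, not the one from the report): if each block is wrapped separately, the non-root blocks do reshard after forward, while any parameters left to the root FSDP module (the embedding in this sketch) stay unsharded until reshard() is called explicitly.

```python
# Sketch of the root-module override with nested wrapping (hypothetical module).
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(32000, 1024)  # ends up managed by the root FSDP module
        self.blocks = nn.ModuleList(nn.Linear(1024, 1024) for _ in range(4))

    def forward(self, ids):
        h = self.embed(ids)
        for blk in self.blocks:
            h = blk(h)
        return h

model = Encoder().cuda()
for blk in model.blocks:
    fully_shard(blk, reshard_after_forward=True)  # non-root: the flag is honored
fully_shard(model, reshard_after_forward=True)    # root: treated as False regardless

with torch.no_grad():
    model(torch.randint(0, 32000, (8, 128), device="cuda"))
# The Linear blocks have been resharded at this point, but the embedding (owned
# by the root's parameter group) is still unsharded; an explicit call frees it:
model.reshard()

dist.destroy_process_group()
```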
Related to this issue, is it possible to change reshard_after_forward after the model has been wrapped?
The scenario is as follows:
from torch.distributed.fsdp import fully_shard  # FSDP2

model = fully_shard(model)
# Set `reshard_after_forward` to False
with torch.no_grad():
    for _ in range(5):
        x = model(x)
# Set `reshard_after_forward` to True
model(x)
Since inference under no_grad consumes less memory, I want to set reshard_after_forward to False so that repeated forward passes do not have to re-gather parameters, saving computation time. After multiple inference steps under no_grad, I plan to set reshard_after_forward back to True and proceed with normal training.
Initially, I tried setting reshard_after_forward to 8 (assuming 8 GPUs per node) during wrapping, but I encountered an OOM error, which seems similar to issue #147179. For now, I set reshard_after_forward to False during wrapping and explicitly call model.reshard() when necessary. Is there a better way to handle this?
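Concretely, the current workaround looks roughly like the following sketch (the toy model and launch details are placeholders; only fully_shard() and FSDPModule.reshard() are used):

```python
# Workaround sketch (hypothetical toy model; run with torchrun as above):
# keep parameters unsharded across several no_grad forwards, then reshard
# explicitly before switching back to training.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Single fully_shard call on the root; with nested fully_shard calls, each
# wrapped submodule would likely need its own reshard() call as well.
model = nn.Sequential(*[nn.Linear(2048, 2048) for _ in range(4)]).cuda()
fully_shard(model, reshard_after_forward=False)

x = torch.randn(8, 2048, device="cuda")
with torch.no_grad():
    for _ in range(5):
        x = model(x)          # parameters stay unsharded; no repeated all-gather

model.reshard()               # explicitly free the unsharded parameters

# Back to normal training: the forward pass all-gathers again as needed.
out = model(torch.randn(8, 2048, device="cuda"))
out.sum().backward()

dist.destroy_process_group()
```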
Versions
2.7.0.dev20250107+cu124
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @zhaojuanmao @mrshenli @rohan-varma @chauhang