reshard_after_forward does not work as expected in FSDP2 #149029

Closed

caiqi opened this issue Mar 12, 2025 · 3 comments
Labels: module: fsdp, oncall: distributed

@caiqi
caiqi commented Mar 12, 2025

🐛 Describe the bug

@awgu When enabling the reshard_after_forward flag, parameters appear to remain unsharded even after the forward pass completes. While this works as expected for simple networks, the text encoder module from HuggingFace Transformers exhibits a memory increase after forward propagation, even within a torch.no_grad() context. Manually invoking reshard() after the forward pass reduces memory usage, which suggests that automatic resharding is not occurring as intended.

Observations:

  • Minimal Example Works: Basic networks behave correctly with reshard_after_forward.
  • Transformer Text Encoder Fails: Memory usage grows after forward passes in no_grad mode, implying parameters are retained in an unsharded state.
  • Manual Intervention Resolves It: Explicitly calling reshard() after the forward pass reduces memory.
  • Reproducibility: A minimal reproducible example is provided below.

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, FSDPModule
import os
from diffusers import DiffusionPipeline, StableDiffusion3Pipeline
from transformers.models.t5.modeling_t5 import T5Block

class SimpleNet(nn.Module):
    def __init__(self, *args, **kwargs):
        super(SimpleNet, self).__init__()
        self.nets = nn.ModuleList()
        for i in range(40):
            self.nets.append(nn.Conv2d(4096, 4096, 3, padding=1))
        self.attn_stream = torch.cuda.Stream()
    def forward(self, x):
        for layer in self.nets:
            x = layer(x)
        return x

def print_memory(desp):
    rank = int(os.environ['RANK'])
    torch.cuda.empty_cache()
    if rank == 0:
        print(f"{desp} Memory: ", torch.cuda.memory_reserved() / 1024 / 1024, "MB")


def recursive_reshard(module: nn.Module):
    for n, m in reversed(list(module.named_modules())):
        if isinstance(m, FSDPModule):
            m.reshard()
    module.reshard()

if "__main__" == __name__:
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    torch.cuda.set_device(local_rank)
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        output_dtype=torch.bfloat16,
        cast_forward_inputs=True
    )
    model = SimpleNet()
    model = model.to("cuda", torch.bfloat16)
    model_params = sum(p.numel() for p in model.parameters()) / 1e6
    print_memory(f"Model params: {model_params}M")
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            fully_shard(module, reshard_after_forward=True, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)
    pipeline_utils = DiffusionPipeline.from_pretrained(
        "./stable-diffusion-3-medium-diffusers",
        text_encoder=None, text_encoder_2=None, vae=None, transformer=None
    )
    for module in pipeline_utils.text_encoder_3.modules():
        if isinstance(module, T5Block):
            fully_shard(module, reshard_after_forward=True, mp_policy=mp_policy)
    fully_shard(pipeline_utils.text_encoder_3, mp_policy=mp_policy)
    text_encoder_params = sum(p.numel() for p in pipeline_utils.text_encoder_3.parameters()) / 1e6
    print_memory(f"Text encoder params: {text_encoder_params}M")
    model.requires_grad_(False)
    print_memory("after init model with fsdp")
    fake_x = torch.randn(1, 4096, 16, 16, device="cuda", dtype=torch.bfloat16)
    with torch.no_grad():
        target = model(fake_x)
    print_memory("SimpleNet forward finished")
    model.reshard()
    print_memory("SimpleNet reshard finished")
    with torch.no_grad():
        text_inputs = pipeline_utils.tokenizer_3(
            "a prompt",
            padding="max_length",
            max_length=256,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        ).input_ids
        prompt_embeds = pipeline_utils.text_encoder_3(text_inputs.to("cuda"))[0]
    print_memory("Encode prompt finished")
    pipeline_utils.text_encoder_3.reshard()
    print_memory("Text encoder reshard finished")
    dist.destroy_process_group()
    print_memory("Done")

[Screenshot attached in the original issue showing the memory usage printed at each step.]

Versions

2.7.0.dev20250107+cu124

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @zhaojuanmao @mrshenli @rohan-varma @chauhang

@awgu
Collaborator
awgu commented Mar 12, 2025

The main thing to keep in mind is that any FSDP root module has its reshard_after_forward overridden to False. An FSDP root module is any FSDPModule that does not have an FSDPModule parent above it. This is a design choice inherited from FSDP1, but I recognize the confusion.
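
To make the override concrete, here is a minimal sketch of the wrapping pattern from the repro above, assuming the process group, SimpleNet model, and fake_x input from that script are already set up:

import torch
from torch.distributed.fsdp import fully_shard

# Child wraps honor reshard_after_forward=True and reshard after their own forward.
for conv in model.nets:
    fully_shard(conv, reshard_after_forward=True)
# The outermost call creates the FSDP root; its reshard_after_forward is forced to False,
# so parameters it manages directly (those not covered by any child wrap) stay unsharded.
fully_shard(model)

with torch.no_grad():
    out = model(fake_x)
# Explicitly reshard the root to free any still-unsharded parameters, as the repro does.
model.reshard()

In SimpleNet every parameter lives inside a child-wrapped Conv2d, so the root owns nothing and the override is invisible; for the T5 text encoder, parameters outside the T5Blocks (such as the token embedding) belong to the root and therefore stay unsharded after forward until reshard() is called.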

@caiqi
Author
caiqi commented Mar 12, 2025

Thanks. I missed the params in the root module.

Related to this issue, is it possible to change reshard_after_forward after the model has been wrapped?

The scenario is as follows:

model = fully_shard(model)

# Set `reshard_after_forward` to False
with torch.no_grad():
    for _ in range(5):
        x = model(x)
# Set `reshard_after_forward` to True

model(x)

Since inference under no_grad already consumes less memory, I want to set reshard_after_forward to False for those steps so parameters are not re-gathered on every forward pass, which saves time. After the no_grad inference steps, I plan to set reshard_after_forward back to True and proceed with normal training.

Initially, I tried setting reshard_after_forward to 8 (assuming 8 GPUs per node) during wrapping, but I ran into an OOM error that seems similar to issue #147179. For now, I set reshard_after_forward to False during wrapping and explicitly call model.reshard() when necessary (see the sketch below). Is there a better way to handle this?
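
A minimal sketch of that workaround, reusing the SimpleNet model and the recursive_reshard helper from the repro above (x stands for an input of the appropriate shape):

import torch
from torch.distributed.fsdp import fully_shard

# Wrap with reshard_after_forward=False so parameters stay unsharded after forward.
for conv in model.nets:
    fully_shard(conv, reshard_after_forward=False)
fully_shard(model)

# Repeated no_grad forwards: parameters remain unsharded between iterations,
# so the per-layer all-gathers are not repeated.
with torch.no_grad():
    for _ in range(5):
        x = model(x)

# Free all unsharded parameters before switching back to training.
recursive_reshard(model)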

@awgu
Collaborator
awgu commented Mar 12, 2025

I think it is a reasonable request to make reshard_after_forward configurable (i.e. add a setter).
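
For illustration, a hypothetical usage sketch of such a setter, reusing model and x from the scenario above (the method name set_reshard_after_forward is only a placeholder for the proposed setter, not an existing API at the time of this issue):

with torch.no_grad():
    model.set_reshard_after_forward(False)  # hypothetical setter: keep params unsharded during inference
    for _ in range(5):
        x = model(x)
model.reshard()                             # free the currently unsharded parameters
model.set_reshard_after_forward(True)       # hypothetical setter: reshard after forward again for training
model(x)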
