reshard_after_forward does not work as expected in FSDP2 #149029
@caiqi

🐛 Describe the bug

@awgu With the reshard_after_forward flag enabled, parameters appear to remain unsharded even after the forward pass completes. A simple network behaves as expected, but the text encoder module from HuggingFace Transformers shows a memory increase after forward propagation, even within a torch.no_grad() context. Manually invoking reshard() after the forward pass reduces memory usage, suggesting that automatic resharding is not occurring as intended.

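For context, here is a sketch of the expected behavior (my understanding of the documented fully_shard semantics, not code from the repro; run under torchrun with an initialized NCCL process group):

import os
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import fully_shard

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = nn.Sequential(nn.Linear(1024, 1024, device="cuda"),
                      nn.Linear(1024, 1024, device="cuda"))
for layer in model:
    fully_shard(layer, reshard_after_forward=True)  # shard each child
fully_shard(model)  # root group

x = torch.randn(8, 1024, device="cuda")
out = model(x)
# With reshard_after_forward=True, each child's all-gathered (unsharded)
# parameters should be freed as soon as that child's forward finishes.
dist.destroy_process_group()
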
Observations:

  • Minimal Example Works: Basic networks reshard correctly with reshard_after_forward.
  • Transformer Text Encoder Fails: Memory usage grows after forward passes in no_grad mode, implying parameters are retained in the unsharded state (a sketch for probing this directly follows the list).
  • Manual Intervention Resolves: Explicitly calling reshard() after the forward pass reduces memory.
  • Reproducibility: A minimal reproducible example is provided below.
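A minimal probing sketch for the sharded state (this assumes a recent PyTorch where sharded FSDP2 parameters are DTensors importable from torch.distributed.tensor, and where unsharded parameters are swapped in as plain tensors; count_unsharded_params is a hypothetical helper, not an FSDP API):

import torch
from torch import nn
from torch.distributed.tensor import DTensor

def count_unsharded_params(module: nn.Module) -> int:
    # Assumption: a sharded FSDP2 parameter is a DTensor, so any plain-tensor
    # parameter suggests the module is still holding all-gathered weights.
    return sum(1 for p in module.parameters() if not isinstance(p, DTensor))

A nonzero count after the forward pass would corroborate the memory readings below.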

import os

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard
from diffusers import DiffusionPipeline
from transformers.models.t5.modeling_t5 import T5Block

class SimpleNet(nn.Module):
    """A stack of large Conv2d layers used as the minimal FSDP2 test case."""

    def __init__(self):
        super().__init__()
        self.nets = nn.ModuleList(
            nn.Conv2d(4096, 4096, 3, padding=1) for _ in range(40)
        )

    def forward(self, x):
        for layer in self.nets:
            x = layer(x)
        return x

def print_memory(desc):
    # Release cached blocks so memory_reserved reflects live allocations,
    # then report reserved CUDA memory on rank 0.
    rank = int(os.environ['RANK'])
    torch.cuda.empty_cache()
    if rank == 0:
        print(f"{desc} Memory: ", torch.cuda.memory_reserved() / 1024 / 1024, "MB")


def recursive_reshard(module: nn.Module):
    # Manual workaround: reshard every FSDP-wrapped module, children before
    # parents. named_modules() yields `module` itself first, so reversing the
    # list reshards the root last; no extra root call is needed.
    for _, m in reversed(list(module.named_modules())):
        if isinstance(m, FSDPModule):
            m.reshard()

if "__main__" == __name__:
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    torch.cuda.set_device(local_rank)
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        output_dtype=torch.bfloat16,
        cast_forward_inputs=True
    )
    # Shard the simple model: one FSDP group per Conv2d plus a root group.
    model = SimpleNet().to("cuda", torch.bfloat16)
    model_params = sum(p.numel() for p in model.parameters()) / 1e6
    print_memory(f"Model params: {model_params}M")
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            fully_shard(module, reshard_after_forward=True, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)
    # Shard only the T5 text encoder from the SD3 pipeline: one FSDP group
    # per T5Block plus a root group.
    pipeline_utils = DiffusionPipeline.from_pretrained(
        "./stable-diffusion-3-medium-diffusers",
        text_encoder=None,
        text_encoder_2=None,
        vae=None,
        transformer=None,
    )
    for module in pipeline_utils.text_encoder_3.modules():
        if isinstance(module, T5Block):
            fully_shard(module, reshard_after_forward=True, mp_policy=mp_policy)
    fully_shard(pipeline_utils.text_encoder_3, mp_policy=mp_policy)
    text_encoder_params = sum(p.numel() for p in pipeline_utils.text_encoder_3.parameters()) / 1e6
    print_memory(f"Text encoder params: {text_encoder_params}M")
    model.requires_grad_(False)
    print_memory("after init model with fsdp")
    fake_x = torch.randn(1, 4096, 16, 16, device="cuda", dtype=torch.bfloat16)
    # SimpleNet: reserved memory drops back to the sharded level after forward.
    with torch.no_grad():
        target = model(fake_x)
    print_memory("SimpleNet forward finished")
    model.reshard()
    print_memory("SimpleNet reshard finished")
    # T5 text encoder: reserved memory stays high after forward, even under
    # no_grad, until reshard() is called manually below.
    with torch.no_grad():
        text_inputs = pipeline_utils.tokenizer_3(
            "a prompt",
            padding="max_length",
            max_length=256,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        ).input_ids
        prompt_embeds = pipeline_utils.text_encoder_3(text_inputs.to("cuda"))[0]
    print_memory("Encode prompt finished")
    pipeline_utils.text_encoder_3.reshard()
    print_memory("Text encoder reshard finished")
    dist.destroy_process_group()
    print_memory("Done")

[Screenshot attached in the original issue showing the memory printouts from the script above.]

Versions

2.7.0.dev20250107+cu124

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @zhaojuanmao @mrshenli @rohan-varma @chauhang
