🐛 Describe the bug
@awgu With the reshard_after_forward flag enabled, parameters appear to remain unsharded even after the forward pass completes. Resharding works as expected for simple networks, but the T5 text encoder module from HuggingFace Transformers shows a memory increase after forward propagation, even within a torch.no_grad() context. Manually invoking reshard() after the forward pass reduces memory usage, suggesting that automatic resharding is not occurring as intended.
Observations:
- Minimal Example Works: Basic networks behave correctly with reshard_after_forward.
- Transformer Text Encoder Fails: Memory usage grows after forward passes in no_grad mode, implying parameters are retained in an unsharded state.
- Manual Intervention Resolves: Explicitly calling reshard() post-forward reduces memory (a condensed sketch of this check follows the list).
- Reproducibility: A minimal reproducible example is provided below.
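For reference, the check performed by the full script is just a comparison of reserved CUDA memory after the no_grad forward pass versus after an explicit reshard(). A condensed sketch, where module is assumed to be any FSDPModule wrapped with fully_shard(..., reshard_after_forward=True) and inputs is a CUDA tensor (both are placeholders, not names from the repro):

import torch

def reserved_mb():
    # Release cached blocks so the reserved number tracks live allocations more closely.
    torch.cuda.empty_cache()
    return torch.cuda.memory_reserved() / 1024 / 1024

baseline = reserved_mb()
with torch.no_grad():
    out = module(inputs)
after_forward = reserved_mb()   # with reshard_after_forward=True this should stay near `baseline`
module.reshard()                # manual workaround from the observations above
after_reshard = reserved_mb()   # for the T5 encoder, memory only drops back here
print(baseline, after_forward, after_reshard)

The full repro script follows.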
import os

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard

from diffusers import DiffusionPipeline, StableDiffusion3Pipeline
from transformers.models.t5.modeling_t5 import T5Block


class SimpleNet(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.nets = nn.ModuleList()
        for i in range(40):
            self.nets.append(nn.Conv2d(4096, 4096, 3, padding=1))
        self.attn_stream = torch.cuda.Stream()

    def forward(self, x):
        for layer in self.nets:
            x = layer(x)
        return x


def print_memory(desp):
    rank = int(os.environ['RANK'])
    torch.cuda.empty_cache()
    if rank == 0:
        print(f"{desp} Memory: ", torch.cuda.memory_reserved() / 1024 / 1024, "MB")


def recursive_reshard(module: nn.Module):
    # Reshard every FSDP-wrapped submodule bottom-up, then the root module itself.
    for n, m in reversed(list(module.named_modules())):
        if isinstance(m, FSDPModule):
            m.reshard()
    module.reshard()


if __name__ == "__main__":
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    torch.cuda.set_device(local_rank)

    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        output_dtype=torch.bfloat16,
        cast_forward_inputs=True,
    )

    # Case 1: a simple conv stack, sharded per layer with reshard_after_forward=True.
    model = SimpleNet()
    model = model.to("cuda", torch.bfloat16)
    model_params = sum(p.numel() for p in model.parameters()) / 1e6
    print_memory(f"Model params: {model_params}M")
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            fully_shard(module, reshard_after_forward=True, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)

    # Case 2: the T5 text encoder from the Stable Diffusion 3 pipeline,
    # sharded per T5Block with reshard_after_forward=True.
    pipeline_utils = DiffusionPipeline.from_pretrained(
        "./stable-diffusion-3-medium-diffusers",
        text_encoder=None,
        text_encoder_2=None,
        vae=None,
        transformer=None,
    )
    for module in pipeline_utils.text_encoder_3.modules():
        if isinstance(module, T5Block):
            fully_shard(module, reshard_after_forward=True, mp_policy=mp_policy)
    fully_shard(pipeline_utils.text_encoder_3, mp_policy=mp_policy)
    text_encoder_params = sum(p.numel() for p in pipeline_utils.text_encoder_3.parameters()) / 1e6
    print_memory(f"Text encoder params: {text_encoder_params}M")

    model.requires_grad_(False)
    print_memory("after init model with fsdp")

    # SimpleNet: memory returns to the sharded level after forward, as expected.
    fake_x = torch.randn(1, 4096, 16, 16, device="cuda", dtype=torch.bfloat16)
    with torch.no_grad():
        target = model(fake_x)
    print_memory("SimpleNet forward finished")
    model.reshard()
    print_memory("SimpleNet reshard finished")

    # Text encoder: memory stays elevated after forward until reshard() is called manually.
    with torch.no_grad():
        text_inputs = pipeline_utils.tokenizer_3(
            "a prompt",
            padding="max_length",
            max_length=256,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        ).input_ids
        prompt_embeds = pipeline_utils.text_encoder_3(text_inputs.to("cuda"))[0]
    print_memory("Encode prompt finished")
    pipeline_utils.text_encoder_3.reshard()
    print_memory("Text encoder reshard finished")

    dist.destroy_process_group()
    print_memory("Done")
Versions
2.7.0.dev20250107+cu124
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @zhaojuanmao @mrshenli @rohan-varma @chauhang