Fix Flux Context Parallel Bug (Incoherent Image Generation) by mali-afridi · Pull Request #12443 · huggingface/diffusers · GitHub

Conversation

mali-afridi

What does this PR do?

Fix Context Parallelism: Implement Ring Attention Pattern for Coherent Multi-GPU Generation

🐛 Problem

I did some testing of https://huggingface.co/docs/diffusers/main/training/distributed_inference on the main branch.
Context parallelism in diffusers was producing fragmented/split images when using multiple GPUs. Instead of generating a single coherent image, each GPU was independently generating its own portion, resulting in visible seams or completely different content in each image segment.

Example: Running with torchrun --nproc-per-node=2 would produce an image that looked like two different images side-by-side rather than one unified image.

🔍 Root Cause Analysis

The issue stems from how attention was computed in context parallel mode:

Before (Broken):

# Each GPU only sees its local chunk
GPU 0: Q=[B, S/2, H, D], K=[B, S/2, H, D], V=[B, S/2, H, D]  # First half only
GPU 1: Q=[B, S/2, H, D], K=[B, S/2, H, D], V=[B, S/2, H, D]  # Second half only

# Result: Each GPU's attention can only see S/2 context, missing the other half

Each GPU was computing attention using only its local sequence chunk for Q, K, and V. This meant:

  • GPU 0's queries could only attend to the first half of the sequence
  • GPU 1's queries could only attend to the second half of the sequence
  • No cross-GPU attention was happening → broken context → independent generations (the sketch below demonstrates this)
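
To make this concrete, here is a minimal single-process sketch (shapes are illustrative, not Flux's actual dimensions) showing that attention computed over only a local K/V chunk does not reproduce the corresponding slice of full attention:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, D = 1, 2, 8, 16  # batch, heads, sequence, head_dim
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# Reference: full attention, every query sees all S keys/values
full = F.scaled_dot_product_attention(q, k, v)

# Broken CP pattern: "GPU 0" attends only over its local half of K/V
half = S // 2
local = F.scaled_dot_product_attention(q[:, :, :half], k[:, :, :half], v[:, :, :half])

# The softmax normalization and value mixing miss the second half,
# so the outputs diverge (prints False)
print(torch.allclose(local, full[:, :, :half], atol=1e-6))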

✅ Solution: Ring Attention Pattern

This PR implements the Ring Attention pattern where:

  • Query (Q) remains local to each GPU (for computational efficiency)
  • Key (K) and Value (V) are gathered from all GPUs (for full context); the check after the diagram below verifies the equivalence

After (Fixed):

# Each GPU's Q is local, but K and V are global
GPU 0: Q=[B, S/2, H, D], K=[B, S_full, H, D], V=[B, S_full, H, D]
GPU 1: Q=[B, S/2, H, D], K=[B, S_full, H, D], V=[B, S_full, H, D]

# Result: Each GPU's queries can attend to the FULL sequence context!
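
This works because non-causal attention is row-wise in Q: each query's output depends only on that query and the full set of keys/values, so computing local Q against gathered K/V reproduces the full result exactly. A quick single-process check, under the same illustrative shapes as above:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, D = 1, 2, 8, 16
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

full = F.scaled_dot_product_attention(q, k, v)

# Fixed CP pattern: each "rank" keeps its Q chunk but sees full K/V
half = S // 2
out0 = F.scaled_dot_product_attention(q[:, :, :half], k, v)  # rank 0
out1 = F.scaled_dot_product_attention(q[:, :, half:], k, v)  # rank 1

# Concatenating the per-rank outputs recovers full attention (prints True)
print(torch.allclose(torch.cat([out0, out1], dim=2), full, atol=1e-6))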

📝 Implementation Details

The fix is applied directly in the attention processors after rotary embeddings but before attention computation:

FluxAttnProcessor (transformer_flux.py):

# After rotary embeddings, gather K and V from all GPUs
world_size = torch.distributed.get_world_size()
key_list = [torch.empty_like(key) for _ in range(world_size)]
value_list = [torch.empty_like(value) for _ in range(world_size)]

torch.distributed.all_gather(key_list, key.contiguous())
torch.distributed.all_gather(value_list, value.contiguous())

# Concatenate along the sequence dimension
key = torch.cat(key_list, dim=1)      # now holds the full sequence
value = torch.cat(value_list, dim=1)  # now holds the full sequence
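
For single-GPU runs the processor should skip the collective entirely. Here is a sketch of how the gather could be wrapped with a fallback guard (the helper name is mine, not part of this PR):

import torch
import torch.distributed as dist

def gather_kv_full_sequence(key: torch.Tensor, value: torch.Tensor, seq_dim: int = 1):
    """Hypothetical helper: all-gather K/V along the sequence dimension,
    falling back to a no-op when not running under torch.distributed."""
    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
        return key, value

    world_size = dist.get_world_size()
    key_list = [torch.empty_like(key) for _ in range(world_size)]
    value_list = [torch.empty_like(value) for _ in range(world_size)]
    dist.all_gather(key_list, key.contiguous())
    dist.all_gather(value_list, value.contiguous())
    return torch.cat(key_list, dim=seq_dim), torch.cat(value_list, dim=seq_dim)

One tradeoff worth noting: gathering the full K/V means every rank holds O(S) keys/values instead of O(S / world_size), whereas textbook ring attention streams K/V blocks around the ring to keep per-rank memory bounded.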

🧪 Testing

For testing, run the following with torchrun --nproc-per-node=2:

import torch, time
import torch.distributed as dist
from diffusers import AutoModel, FluxPipeline, ContextParallelConfig

dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device("cuda", rank)
torch.cuda.set_device(device)

transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=ContextParallelConfig(ring_degree=world_size) if world_size > 1 else None
).to(device)

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to(device)

prompt = "A beautiful sunset over mountains"


generator = torch.Generator(device=device).manual_seed(42)

start_time = time.time()
image = pipeline(
    prompt=prompt,
    num_inference_steps=28,
    generator=generator,
).images[0]
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")
if rank == 0:
    image.save(f"flux_{world_size}.png")
    print(f"Saved flux_{world_size}.png")

dist.destroy_process_group()

Before Fix:


Result: Two different images side-by-side in output

After Fix:


Result: Single coherent image matching single-GPU output

Summary: This PR fixes context parallelism by ensuring each GPU's attention queries can access the full key-value context from all GPUs, implementing the Ring Attention pattern for coherent multi-GPU image generation.

Note: I have observed that some tensors in QwenImage cannot be divided evenly by world_size (encoder_hidden_states, encoder_hidden_mask, etc.). I am also willing to open a follow-up PR adding QwenImage support for context parallelism by padding those tensors to be divisible by world size, similar to chengzeyi/ParaAttention#53, if you want.
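
As a rough illustration of that follow-up idea, here is a sketch of right-padding a tensor's sequence dimension so it splits evenly across ranks (the helper name is hypothetical; a real implementation would also extend the attention mask so the padded positions are ignored, which is presumably where the SSIM change comes from):

import torch
import torch.nn.functional as F

def pad_to_divisible(t: torch.Tensor, world_size: int, seq_dim: int = 1):
    """Right-pad seq_dim with zeros until its length divides world_size.
    Returns the padded tensor and the pad count so the caller can mask
    the extra positions and slice them off afterwards."""
    pad = (-t.shape[seq_dim]) % world_size
    if pad == 0:
        return t, 0
    # F.pad's spec runs from the last dim backwards: (0, 0) pairs for the
    # dims after seq_dim, then (0, pad) for seq_dim itself
    pad_spec = [0, 0] * (t.dim() - 1 - seq_dim) + [0, pad]
    return F.pad(t, pad_spec), pad

# e.g. an encoder_hidden_states-like tensor whose text length is odd
ehs = torch.randn(2, 333, 3072)
padded, n_pad = pad_to_divisible(ehs, world_size=2)
print(padded.shape, n_pad)  # torch.Size([2, 334, 3072]) 1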

Fixes # (issue)


Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @a-r-r-o-w

@mali-afridi marked this pull request as ready for review on October 6, 2025, 23:56
@sayakpaul
Member
sayakpaul commented Oct 7, 2025

Thanks for the PR. From the looks of it, it seems to be fully LLM-generated. Also, FWIW, we strive to keep our modeling implementations simple, so I am not yet sure whether the changes align with that philosophy. @DN6 WDYT?

@DN6
Collaborator
DN6 commented Oct 7, 2025

Hi @mali-afridi, the issue seems to be that an unsupported attention backend is being used with CP.

This snippet should work

import torch
from diffusers import FluxPipeline
from diffusers import ContextParallelConfig

try:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)

    device = torch.device("cuda")
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to(device)
    pipe.transformer.set_attention_backend("_native_cudnn")
    pipe.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))
    prompt = "A picture of a cat holding a sign that says hello"

    # Must specify generator so all ranks start with same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
    image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0, generator=generator).images[0]

    if rank == 0:
        image.save("output.png")

except Exception as e:
    raise e

finally:
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

I've opened a PR to raise an error when an incompatible backend is used: #12446

@mali-afridi
Author

Interesting, yeah, _native_cudnn worked. What should we do about QwenImage, though? Its tensors don't divide evenly by world size. Padding seems to work, but it causes an SSIM change, as mentioned in the link in the description.

@mali-afridi
Author

For QwenImage reproducibility:

import torch, time
from diffusers import QwenImagePipeline, ContextParallelConfig

try:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)
    
    pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
    pipeline.to(device)
    pipeline.transformer.set_attention_backend("_native_cudnn")
    pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))

    # pipeline.transformer.set_attention_backend("flash")
    # positive_magic = {
    #     "en": "Ultra HD, 4K, cinematic composition.", # for english prompt,
    #     "zh": "超清,4K,电影级构图" # for chinese prompt,
    # }

    prompt = '''A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition'''

    negative_prompt = " " # using an empty string if you do not have specific concept to remove


    # Generate with different aspect ratios
    aspect_ratios = {
        "1:1": (1328, 1328),
        "16:9": (1664, 928),
        "9:16": (928, 1664),
        "4:3": (1472, 1140),
        "3:4": (1140, 1472),
        "3:2": (1584, 1056),
        "2:3": (1056, 1584),
    }
    width, height = aspect_ratios["16:9"]

    # Must specify generator so all ranks start with the same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
    start_time = time.time()
    image = pipeline(
        prompt=prompt,
        # negative_prompt=negative_prompt,
        # width=width,
        # height=height,
        num_inference_steps=50,
        # true_cfg_scale=4.0,
        generator=generator,
    ).images[0]
    end_time = time.time()
    print(f"Time taken: {end_time - start_time} seconds")
    if rank == 0:
        image.save("output1.png")

except Exception as e:
    import traceback
    print(f"An error occurred: {e}")
    print("Traceback (most recent call last):")
    traceback.print_exc()
    raise

finally:
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

Error: https://pastebin.com/NhcEhb3s
