First Block Cache by a-r-r-o-w · Pull Request #11180 · huggingface/diffusers · GitHub

First Block Cache #11180

Open · wants to merge 29 commits into main

Conversation

a-r-r-o-w (Member) commented Mar 31, 2025

FBC Reference: https://github.com/chengzeyi/ParaAttention

Minimal example

import torch
from diffusers import CogView4Pipeline
from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.3))

prompt = "A photo of an astronaut riding a horse on mars"
image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
image.save("output.png")

Benchmark scripts

Threshold vs. generation time (in seconds) for each model. In general, values below 0.2 work well for First Block Cache, depending on the model. Higher values lead to blurring and artifacting.

Threshold | CogView4 | HunyuanVideo | LTX Video | Wan    | Flux
0.00      | 40.51    | 121.85       | 33.64     | 222.17 | 16.47
0.03      | -        | -            | 27.14     | -      | -
0.05      | 24.08    | 62.47        | 21.73     | 139.26 | 12.63
0.10      | 17.55    | 41.84        | 15.10     | 89.99  | 9.27
0.20      | 12.99    | 28.11        | 10.29     | 57.01  | 5.91
0.40      | -        | 18.93        | 7.25      | -      | 3.70
0.50      | -        | 16.65        | -         | -      | 3.13
CogView4
import argparse
import pathlib

import torch
from diffusers import CogView4Pipeline
from diffusers.hooks import FirstBlockCacheConfig
from diffusers.utils.logging import set_verbosity_debug


def main(args):
    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    set_verbosity_debug()
    pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    prompt = "A photo of the eye of agomotto, a mystical object from the Marvel universe, with intricate details and a glowing effect, in a fantasy style. The eye sits at the center of a mystical landscape surrounded by green trees, vibrant flowers and orbs of light. The scene is illuminated by a soft, ethereal glow, creating a magical atmosphere. The eye itself is detailed with swirling patterns and a radiant light"
    negative_prompt = "bad anatomy, ugly, blurry, out of focus, low quality, worst quality, normal quality, jpeg artifacts, signature, watermark, username, artist name, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, artist name, text, error, missing fingers, extra digit, fewer digits, cropped"

    # Warmup
    pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=2, generator=torch.Generator().manual_seed(42)).images[0]

    # Benchmark
    for threshold in [0, 0.05, 0.1, 0.2, 0.4, 0.5]:
        print(f"Using threshold: {threshold}")
        filename = f"output_{threshold:.5f}.png"

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        if threshold == 0:
            image = pipe(prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(42)).images[0]
        else:
            pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=threshold))
            image = pipe(prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(42)).images[0]
            pipe.transformer.disable_cache()
        end.record()
        torch.cuda.synchronize()
        
        elapsed_time = start.elapsed_time(end)
        print(f"Elapsed time: {elapsed_time / 1000:.2f}s")
        print(f"Output saved to {output_dir / filename}")
        
        image.save((output_dir / filename).as_posix())


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output_cogview4")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)
HunyuanVideo
import argparse
import pathlib

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.hooks import FirstBlockCacheConfig
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug


def main(args):
    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    set_verbosity_debug()
    model_id = "hunyuanvideo-community/HunyuanVideo"
    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=torch.bfloat16
    )
    pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
    pipe.vae.enable_tiling()
    pipe.to("cuda")

    prompt = "A cat walks on the grass, realistic"

    # Warmup
    pipe(prompt=prompt, height=320, width=512, num_frames=61, num_inference_steps=2, generator=torch.Generator().manual_seed(42)).frames[0]

    # Benchmark
    for threshold in [0, 0.05, 0.1, 0.2, 0.4, 0.5]:
        print(f"Using threshold: {threshold}")
        filename = f"output_{threshold:.5f}.mp4"

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        if threshold == 0:
            video = pipe(prompt=prompt, height=320, width=512, num_frames=61, generator=torch.Generator().manual_seed(42)).frames[0]
        else:
            pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=threshold))
            video = pipe(prompt=prompt, height=320, width=512, num_frames=61, generator=torch.Generator().manual_seed(42)).frames[0]
            pipe.transformer.disable_cache()
        end.record()
        torch.cuda.synchronize()
        
        elapsed_time = start.elapsed_time(end)
        print(f"Elapsed time: {elapsed_time / 1000:.2f}s")
        print(f"Output saved to {output_dir / filename}")
        
        export_to_video(video, (output_dir / filename).as_posix(), fps=16)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output_hunyuanvideo")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)
LTX Video
import argparse
import pathlib

import torch
from diffusers import LTXPipeline
from diffusers.hooks import FirstBlockCacheConfig
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug


def main(args):
    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    set_verbosity_debug()
    pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
    negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

    # Warmup
    pipe(prompt=prompt, negative_prompt=negative_prompt, height=480, width=704, num_frames=161, num_inference_steps=2, generator=torch.Generator().manual_seed(42)).frames[0]

    # Benchmark
    for threshold in [0, 0.03, 0.05, 0.1, 0.2, 0.4]:
        print(f"Using threshold: {threshold}")
        filename = f"output_{threshold:.5f}.mp4"

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        if threshold == 0:
            video = pipe(prompt=prompt, negative_prompt=negative_prompt, height=480, width=704, num_frames=161, generator=torch.Generator().manual_seed(42)).frames[0]
        else:
            pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=threshold))
            video = pipe(prompt=prompt, negative_prompt=negative_prompt, height=480, width=704, num_frames=161, generator=torch.Generator().manual_seed(42)).frames[0]
            pipe.transformer.disable_cache()
        end.record()
        torch.cuda.synchronize()
        
        elapsed_time = start.elapsed_time(end)
        print(f"Elapsed time: {elapsed_time / 1000:.2f}s")
        print(f"Output saved to {output_dir / filename}")
        
        export_to_video(video, (output_dir / filename).as_posix(), fps=24)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output_hunyuanvideo")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)
Wan
import argparse
import pathlib

import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.hooks import FirstBlockCacheConfig
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug


def main(args):
    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    set_verbosity_debug()
    model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
    pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

    # Warmup
    pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=2, generator=torch.Generator().manual_seed(42)).frames[0]

    # Benchmark
    for threshold in [0, 0.05, 0.1, 0.2]:
        print(f"Using threshold: {threshold}")
        filename = f"output_{threshold:.5f}.mp4"

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        if threshold == 0:
            video = pipe(prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(42)).frames[0]
        else:
            pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=threshold))
            video = pipe(prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(42)).frames[0]
            pipe.transformer.disable_cache()
        end.record()
        torch.cuda.synchronize()
        
        elapsed_time = start.elapsed_time(end)
        print(f"Elapsed time: {elapsed_time / 1000:.2f}s")
        print(f"Output saved to {output_dir / filename}")
        
        export_to_video(video, (output_dir / filename).as_posix(), fps=16)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output_wan")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)
Flux
import argparse
import pathlib

import torch
from diffusers import FluxPipeline
from diffusers.hooks import FirstBlockCacheConfig
from diffusers.utils.logging import set_verbosity_debug


def main(args):
    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    set_verbosity_debug()
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    prompt = "A photo of the eye of agomotto, a mystical object from the Marvel universe, with intricate details and a glowing effect, in a fantasy style. The eye sits at the center of a mystical landscape surrounded by green trees, vibrant flowers and orbs of light. The scene is illuminated by a soft, ethereal glow, creating a magical atmosphere. The eye itself is detailed with swirling patterns and a radiant light"

    # Warmup
    pipe(prompt=prompt, num_inference_steps=2, generator=torch.Generator().manual_seed(42)).images[0]

    # Benchmark
    for threshold in [0, 0.05, 0.1, 0.2, 0.4, 0.5]:
        print(f"Using threshold: {threshold}")
        filename = f"output_{threshold:.5f}.png"

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        if threshold == 0:
            image = pipe(prompt=prompt, generator=torch.Generator().manual_seed(42)).images[0]
        else:
            pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=threshold))
            image = pipe(prompt=prompt, generator=torch.Generator().manual_seed(42)).images[0]
            pipe.transformer.disable_cache()
        end.record()
        torch.cuda.synchronize()
        
        elapsed_time = start.elapsed_time(end)
        print(f"Elapsed time: {elapsed_time / 1000:.2f}s")
        print(f"Output saved to {output_dir / filename}")
        
        image.save((output_dir / filename).as_posix())


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output_flux")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)

Visual result comparison

CogView4
threshold=0.00 threshold=0.05
threshold=0.10 threshold=0.20
Hunyuan Video
threshold=0.00 threshold=0.05
output_0.00000.mp4
output_0.05000.mp4
threshold=0.10 threshold=0.20
output_0.10000.mp4
output_0.20000.mp4
threshold=0.40 threshold=0.50
output_0.40000.mp4
output_0.50000.mp4
LTX Video
threshold=0.00 threshold=0.03
output_0.00000.mp4
output_0.03000.mp4
threshold=0.05 threshold=0.10
output_0.05000.mp4
output_0.10000.mp4
threshold=0.20 threshold=0.40
output_0.20000.mp4
output_0.40000.mp4
Wan
threshold=0.00 threshold=0.05
output_0.00000.mp4
output_0.05000.mp4
threshold=0.10 threshold=0.20
output_0.10000.mp4
output_0.20000.mp4
Flux
threshold=0.00 threshold=0.05
threshold=0.10 threshold=0.20
threshold=0.40 threshold=0.50

Using with torch.compile

  • There is a forced graph break for the data-dependent control-flow branching. This portion of the code will always run in eager mode (a usage sketch follows after this list).
  • There are a few recompilations triggered at a weird/unexpected location: the attention processor invocation. This only happens when using hooks, so I believe the current hook implementation makes torch.compile tracing add some unnecessary id/type guards. This will be tackled in the future, since I haven't been able to make much progress on rewriting it yet.
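
A minimal usage sketch, assuming the CogView4 setup from the example above (the mode and fullgraph settings here are illustrative choices, not prescribed by the PR):

import torch
from diffusers import CogView4Pipeline
from diffusers.hooks import FirstBlockCacheConfig

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Enable First Block Cache before compiling the denoiser
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))

# fullgraph must stay False: the data-dependent cache branch forces a graph break
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=False)

prompt = "A photo of an astronaut riding a horse on mars"
image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
image.save("output_compiled.png")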


@@ -79,10 +79,14 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
def forward(
Member Author

cc @yiyixuxu for reviewing the changes to the transformer here. The changes were made to simplify the code required to make cache techniques work, without adding even more if-else branching.

Ideally, if we stick to implementing models such that all blocks take in both hidden_states and encoder_hidden_states, and always return (hidden_states, encoder_hidden_states) from the block, a lot of design choices in the hook-based code can be simplified.

For now, I think these changes should be safe and come without any significant overhead to generation time (I haven't benchmarked though).
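
As a rough sketch of that convention (an illustrative toy block, not a diffusers class): every block accepts both hidden_states and encoder_hidden_states and always returns them as a tuple, so hook-based code never needs per-model if-else branching.

import torch

class ExampleTransformerBlock(torch.nn.Module):
    # Illustrative only: shows the uniform signature, not a real diffusers block.
    def __init__(self, dim: int):
        super().__init__()
        self.norm = torch.nn.LayerNorm(dim)
        self.ff = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states, temb=None):
        # ...attention / feed-forward computation elided...
        hidden_states = hidden_states + self.ff(self.norm(hidden_states))
        # Always return both tensors, even if encoder_hidden_states was untouched,
        # so cache hooks can treat every block the same way.
        return hidden_states, encoder_hidden_states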


Collaborator

I'm ok with the change, but I think we can make encoder_hidden_states optional here, no? It's not much trouble and won't break for those using these blocks on their own.

@@ -0,0 +1,222 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
Member Author

cc @chengzeyi Would be super cool if you could give this PR a review since we're trying to integrate FBC to work with all supported models!

Currently, I've only done limited testing on a few models, but it should be easily extendable to all.

Contributor

@a-r-r-o-w I see, let me take a look!

return cls._registry[model_class]


def _register_transformer_blocks_metadata():
Member Author

cc @DN6 For now, this PR only adds metadata for transformer blocks. The information here will be filled in over time for more models, to simplify the assumptions made in hook-based code and make it cleaner to work with. More metadata will need to be maintained for transformer implementations to simplify cache methods like FasterCache, which I'll cover in future PRs.

)


class BaseMarkedState(BaseState):
Member Author

To explain simply, a "marked" state is a copy of the state object for a particular batch of data. In our pipelines, we do one of the following:

  • concatenate the unconditional and conditional batches and perform a single forward pass through the transformer
  • perform individual forward passes for the conditional and unconditional batches

The state variables must track values specific to each batch of data over all inference steps; otherwise you might end up in a situation where the state variable for the conditional batch is used for the unconditional batch, or vice versa.

@@ -917,6 +917,7 @@ def __call__(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)

cc.mark_state("cond")
Member Author

Doing it this way helps us distinguish different batches of data. I believe this design will fit well with the upcoming "guiders" that support multiple guidance methods. Since guidance methods may use 1 (guidance-distilled), 2 (CFG), 3 (PAG), or more latent batches, we can call "mark state" any number of times to distinguish between calls to transformer.forward with different batches of data within the same inference step.
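
Roughly, the pipeline-side usage looks like the schematic below (a sketch with simplified variable names, wrapped in a function so it is self-contained; cc stands for the cache context object used inside the pipeline, and the guidance math is abbreviated):

def denoise_loop(transformer, cc, timesteps, latents, prompt_embeds, negative_prompt_embeds,
                 guidance_scale, do_classifier_free_guidance):
    # Schematic only: each batch gets its own marked state so cached first-block
    # residuals from the conditional pass never leak into the unconditional pass.
    for t in timesteps:
        cc.mark_state("cond")
        noise_pred_cond = transformer(
            hidden_states=latents, encoder_hidden_states=prompt_embeds, timestep=t, return_dict=False
        )[0]

        noise_pred = noise_pred_cond
        if do_classifier_free_guidance:
            cc.mark_state("uncond")
            noise_pred_uncond = transformer(
                hidden_states=latents, encoder_hidden_states=negative_prompt_embeds, timestep=t, return_dict=False
            )[0]
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
        # ...scheduler step updating latents elided...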

Comment on lines +2719 to +2721
# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
# of the box once there is better cache support/implementation
class FirstBlockCacheTesterMixin:
Member Author

Adding new cache tester mixins each time and extending the list of parent mixins for each pipeline test is probably not going to be a clean way of testing. We can refactor this in the future and consolidate all cache methods into a single tester once they are better supported/implemented for most models

a-r-r-o-w requested review from DN6 and yiyixuxu on April 1, 2025


# fmt: off
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
Collaborator

Any specific reason to use this naming convention here? The function is just meant to return combinations of hidden_states/encoder_hidden_states, right?

a-r-r-o-w (Member Author), Apr 14, 2025

Not really. It just spells out what argument index hidden_states is at and what it returns. Do you have any particular recommendation?
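
For context, a hedged reconstruction of what such a helper likely looks like (based on the args/kwargs handling quoted later in this thread, not copied verbatim from the PR): the name encodes that hidden_states is positional argument 0 and that only hidden_states is returned when the block is skipped.

def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
    # "hidden_states_0": hidden_states arrives as positional argument 0 (or as a kwarg);
    # "ret___hidden_states": only hidden_states is returned for a skipped block.
    hidden_states = kwargs.get("hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    return hidden_states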

a-r-r-o-w (Member Author)

@DN6 Addressed the review comments. Could you give it another look?

nitinmukesh

Could this feature please also be implemented for LTX-Video's LTXConditionPipeline (0.9.5 and above), if not already covered?

a-r-r-o-w (Member Author)

@nitinmukesh It's already covered I believe. But I didn't test very rigorously, so it would be super helpful if you wanted to give it a try 🤗

a-r-r-o-w (Member Author)

Gentle ping @DN6

encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
Collaborator

Just a thought: could we refactor all the blocks to always use kwargs, and just enforce that? It would take a lot of guesswork out of building future features like this.

Member Author

It would be no problem for us to enforce passing kwargs. But for outside implementations built on diffusers, if we want out-of-the-box compatibility and want to make it easy to work with custom implementations (for example, ComfyUI-maintained original modeling implementations, or a research repo wanting to use a cache implementation for a demo), we should support args-index based identification. So, in the metadata class for transformer and intermediate blocks, I would say we should maintain this info.
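
A small sketch of the args-index based identification being argued for here (the helper and field names are hypothetical, not the PR's actual metadata class):

from dataclasses import dataclass

@dataclass
class BlockIOMetadata:
    # Positional indices at which a block expects these tensors
    hidden_states_index: int = 0
    encoder_hidden_states_index: int = 1

def resolve_tensor(name, index, args, kwargs):
    # Works whether the caller passed the tensor positionally or by keyword,
    # e.g. block(hs, ehs, temb) or block(hidden_states=hs, encoder_hidden_states=ehs, temb=temb)
    value = kwargs.get(name, None)
    if value is None and len(args) > index:
        value = args[index]
    return value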

DN6 (Collaborator) left a comment

Nicely done! 👍🏽 Good to merge once failing tests are addressed.

a-r-r-o-w (Member Author)

Failing tests should hopefully be fixed now. They were caused by a behavioural divergence introduced by the refactor: previously, a cache context was not strictly necessary because the default state object would have been used, which is no longer the case. The current implementation is the correct behaviour.

yiyixuxu (Collaborator) left a comment

thanks @a-r-r-o-w
I left some comments; I just want to brainstorm whether we can find more flexible and extensible ways to make transformers work better with these techniques.

@maybe_allow_in_graph
@TransformerBlockRegistry.register(
Collaborator

Ohhh, can't we just make an xxTransformerBlockOutput? It could be a named tuple or something else, depending on what type of info you need now and could need in the future.

a-r-r-o-w (Member Author), May 16, 2025

The reason for not changing what is returned is that it might be a big breaking change for people who import and use the blocks. Also, as I've been implementing new methods, requirements come up that I cannot anticipate beforehand; being able to add some metadata information quickly would be much more convenient than committing to a dedicated approach, imo. Once we support a large number of techniques, we could look into refactoring/better design (since these would all be internal-use private attributes that we don't have to maintain BC with), wdyt?

We can refactor this a bit for now, though. Dhruv suggested replacing the centralized registry with attributes instead. So, similar to how we have _no_split_modules, etc. at the ModelMixin level, we can maintain properties at the block level too.

yiyixuxu (Collaborator), May 16, 2025

I don't think it would break though; this is just an example. These two should return the same thing when you accept it as x, y = fun(...):

from typing import NamedTuple
import torch

class XXTransformerBlockOutput(NamedTuple):
    encoder_hidden_states: torch.Tensor
    hidden_states: torch.Tensor

def fun1(x, y):
    # Some processing
    return x, y

def fun2(x, y):
    # Same processing
    return XXTransformerBlockOutput(x, y)

Member Author

Okay I'll update.

How do you suggest we return the output from the cache hooks? If xxTransformerBlockOutput means maintaining an output class per transformer block type, the cache hook should also return the same output class. Do we instead register what the output class should be to the block metadata so that the cache hook can return the correct object? Or do we use a generic TransformerBlockOutput class for all models?

yiyixuxu (Collaborator), May 16, 2025

I just looked through the code, so my understanding could be completely wrong; let me know if that's the case (obviously you know better since you implemented it).

I think you would not need to register anything, no? If you use this new output, you will get the new output object in your hook, and you will know which class it is and all the other info you need (e.g. which value is which). Would this not be the case?

def new_forward(self, module: torch.nn.Module, *args, **kwargs):

    output = self.fn_ref.original_forward(*args, **kwargs)
    
    encoder_hidden_states = fn_to_update(output.encoder_hidden_states)
    
    hidden_states = fn_to_update(output.hidden_states)

    new_output = output.__class__(
        encoder_hidden_states=encoder_hidden_states, 
        hidden_states=hidden_states
    )
    
    return new_output

glide-the (Contributor)

This is a great branch, but I encountered a problem when testing it: the first method (apply_first_block_cache) fails, perhaps because some logic is not implemented.

model: CogView4-6B

Detailed Error


apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.3))

prompt = "A photo of an astronaut riding a horse on mars"
image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
image.save("output.png")

  0%|                                                                                                                            | 0/50 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[12], line 4
      1 apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.3))
      3 prompt = "A photo of an astronaut riding a horse on mars"
----> 4 image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
      5 image.save("output.png")

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/diffusers/pipelines/cogview4/pipeline_cogview4.py:623, in CogView4Pipeline.__call__(self, prompt, negative_prompt, height, width, num_inference_steps, timesteps, sigmas, guidance_scale, num_images_per_prompt, generator, latents, prompt_embeds, negative_prompt_embeds, original_size, crops_coords_top_left, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
    620 timestep = t.expand(latents.shape[0])
    622 with self.transformer.cache_context("cond"):
--> 623     noise_pred_cond = self.transformer(
    624         hidden_states=latent_model_input,
    625         encoder_hidden_states=prompt_embeds,
    626         timestep=timestep,
    627         original_size=original_size,
    628         target_size=target_size,
    629         crop_coords=crops_coords_top_left,
    630         attention_kwargs=attention_kwargs,
    631         return_dict=False,
    632     )[0]
    634 # perform guidance
    635 if self.do_classifier_free_guidance:

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/diffusers/models/transformers/transformer_cogview4.py:740, in CogView4Transformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, original_size, target_size, crop_coords, attention_kwargs, return_dict, attention_mask, image_rotary_emb)
    730         hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
    731             block,
    732             hidden_states,
   (...)
    737             attention_kwargs,
    738         )
    739     else:
--> 740         hidden_states, encoder_hidden_states = block(
    741             hidden_states,
    742             encoder_hidden_states,
    743             temb,
    744             image_rotary_emb,
    745             attention_mask,
    746             attention_kwargs,
    747         )
    749 # 4. Output norm & projection
    750 hidden_states = self.norm_out(hidden_states, temb)

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/diffusers/hooks/hooks.py:189, in HookRegistry.register_hook.<locals>.create_new_forward.<locals>.new_forward(module, *args, **kwargs)
    187 def new_forward(module, *args, **kwargs):
    188     args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
--> 189     output = function_reference.forward(*args, **kwargs)
    190     return function_reference.post_forward(module, output)

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/diffusers/hooks/first_block_cache.py:89, in FBCHeadBlockHook.new_forward(self, module, *args, **kwargs)
     86 else:
     87     hidden_states_residual = output - original_hidden_states
---> 89 shared_state: FBCSharedBlockState = self.state_manager.get_state()
     90 hidden_states = encoder_hidden_states = None
     91 should_compute = self._should_compute_remaining_blocks(hidden_states_residual)

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/diffusers/hooks/hooks.py:44, in StateManager.get_state(self)
     42 def get_state(self):
     43     if self._current_context is None:
---> 44         raise ValueError("No context is set. Please set a context before retrieving the state.")
     45     if self._current_context not in self._state_cache.keys():
     46         self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)

ValueError: No context is set. Please set a context before retrieving the state.

a-r-r-o-w (Member Author)

@glide-the Thanks for testing and reporting the issue!

After the latest refactor, it seems that applying the cache with apply_first_block_cache was broken but worked as expected with pipe.transformer.enable_cache(). This has now been fixed, and both approaches of enabling the cache should be usable (see the snippet below).
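
For reference, either entry point shown in this PR should work (the threshold value here is just an example):

from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig

# 1) Functional helper applied directly to the denoiser
apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))

# 2) Convenience method on the model
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))
# ...run inference...
pipe.transformer.disable_cache()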
