First Block Cache #11180
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
…ithout too much model-specific intrusion code)
@@ -79,10 +79,14 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
def forward(
cc @yiyixuxu for reviewing changes to the transformer here. The changes were made to simplify some of the code required to make cache techniques work somewhat more easily without even more if-else branching.
Ideally, if we stick to implementing models such that all blocks take in both hidden_states and encoder_hidden_states, and always return (hidden_states, encoder_hidden_states)
from the block, a lot of design choices in the hook-based code can be simplified.
For now, I think these changes should be safe and come without any significant overhead to generation time (I haven't benchmarked though).
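As a rough illustration of that convention, here is a minimal sketch (made-up internals, not an actual diffusers block) of a block that always takes and returns both streams:

```python
# Sketch only: every block accepts both streams and returns them in a fixed
# (hidden_states, encoder_hidden_states) order, so hook code never needs to
# special-case the signature.
from typing import Optional, Tuple

import torch
import torch.nn as nn


class ExampleTransformerBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        hidden_states = hidden_states + self.proj(self.norm(hidden_states))
        return hidden_states, encoder_hidden_states
```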
I'm OK with the change, but I think we can make encoder_hidden_states optional here, no? It's not much trouble and won't break for those using these blocks on their own.
@@ -0,0 +1,222 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
cc @chengzeyi Would be super cool if you could give this PR a review since we're trying to integrate FBC to work with all supported models!
Currently, I've only done limited testing on a few models, but it should be easily extendable to all.
@a-r-r-o-w I see, let me take a look!
src/diffusers/hooks/_helpers.py
return cls._registry[model_class]


def _register_transformer_blocks_metadata():
cc @DN6 For now, this PR only adds metadata for transformer blocks. The information here will be filled up over time for more models to simplify the assumptions made in hook-based code and make it cleaner to work with. More metadata needs to be maintained for transformer implementations to simplify cache methods like FasterCache, which I'll cover in future PRs
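For illustration, a minimal sketch of what such a block-metadata registry could look like; the field names and decorator shape here are assumptions rather than the exact contents of src/diffusers/hooks/_helpers.py:

```python
# Illustrative sketch of a block-metadata registry; names and fields are
# assumptions, not the exact implementation in the PR.
from dataclasses import dataclass
from typing import Callable, Dict, Type


@dataclass
class TransformerBlockMetadata:
    # Indices of hidden_states / encoder_hidden_states in the block's output tuple.
    return_hidden_states_index: int = 0
    return_encoder_hidden_states_index: int = 1


class TransformerBlockRegistry:
    _registry: Dict[Type, TransformerBlockMetadata] = {}

    @classmethod
    def register(cls, metadata: TransformerBlockMetadata) -> Callable[[Type], Type]:
        # Used as a class decorator: @TransformerBlockRegistry.register(metadata)
        def decorator(model_class: Type) -> Type:
            cls._registry[model_class] = metadata
            return model_class

        return decorator

    @classmethod
    def get(cls, model_class: Type) -> TransformerBlockMetadata:
        return cls._registry[model_class]
```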
src/diffusers/hooks/hooks.py
)


class BaseMarkedState(BaseState):
To explain simply, a "marked" state is a copy of a state object for a particular batch of data. In our pipelines, we do one of the following:
- concatenate the unconditional and conditional batches and perform a single forward pass through the transformer, or
- perform individual forward passes for the conditional and unconditional batches.
The state variables must track values specific to each batch of data over all inference steps; otherwise, you might end up in a situation where the state variable for the conditional batch is used for the unconditional batch, or vice versa.
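A rough sketch of the idea (simplified, not the actual BaseMarkedState implementation):

```python
# Simplified sketch of a "marked" state: each mark ("cond", "uncond", ...) gets
# its own independent copy of the state, so cached values from one batch never
# leak into another.
from typing import Any, Dict


class MarkedState:
    def __init__(self) -> None:
        self._states: Dict[str, Dict[str, Any]] = {}
        self._current_mark: str = "default"

    def mark_state(self, name: str) -> None:
        # Switch to (and lazily create) the per-batch state for this mark.
        self._current_mark = name
        self._states.setdefault(name, {})

    def set(self, key: str, value: Any) -> None:
        self._states.setdefault(self._current_mark, {})[key] = value

    def get(self, key: str, default: Any = None) -> Any:
        return self._states.setdefault(self._current_mark, {}).get(key, default)
```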
@@ -917,6 +917,7 @@ def __call__(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)

cc.mark_state("cond")
Doing it this way helps us distinguish different batches of data. I believe this design will fit well with the upcoming "guiders" for supporting multiple guidance methods. Since guidance methods may use 1 (guidance-distilled), 2 (CFG), 3 (PAG), or more latent batches, we can call "mark state" any number of times to distinguish between calls to transformer.forward with different batches of data within the same inference step.
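As a toy illustration (reusing the MarkedState sketch above; the model and loop here are stand-ins, not the pipeline's actual code):

```python
# Toy CFG-style loop: each guidance batch is "marked" before its forward pass,
# so the cache hook keeps separate state for conditional and unconditional passes.
import torch

state = MarkedState()  # from the sketch above
latents = torch.randn(1, 16, 64)
cond_embeds = torch.randn(1, 77, 64)
uncond_embeds = torch.randn(1, 77, 64)
guidance_scale = 5.0


def fake_transformer(x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
    # Stand-in for transformer.forward; a real pipeline calls the actual model,
    # and the cache hook reads/writes `state` internally.
    return x + 0.01 * context.mean()


for t in torch.linspace(1.0, 0.0, steps=4):
    state.mark_state("cond")
    noise_pred_cond = fake_transformer(latents, cond_embeds)

    state.mark_state("uncond")
    noise_pred_uncond = fake_transformer(latents, uncond_embeds)

    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
```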
# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
# of the box once there is better cache support/implementation
class FirstBlockCacheTesterMixin:
Adding new cache tester mixins each time and extending the list of parent mixins for each pipeline test is probably not going to be a clean way of testing. We can refactor this in the future and consolidate all cache methods into a single tester once they are better supported/implemented for most models
src/diffusers/hooks/_helpers.py
# fmt: off
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
Any specific reason to use this naming convention here? The function is just meant to return combinations of hidden_states/encoder_hidden_states, right?
Not really. It just spells out what argument index hidden_states
is at and what it returns. Do you have any particular recommendation?
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
@DN6 Addressed the review comments. Could you give it another look?
Could this feature please also be implemented for LTX-Video / LTXConditionPipeline (0.9.5 and above), if not already covered?
@nitinmukesh It's already covered, I believe. But I didn't test very rigorously, so it would be super helpful if you wanted to give it a try 🤗
Gentle ping @DN6
src/diffusers/hooks/_helpers.py
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
    hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
Just a thought: could we refactor all the blocks to always use kwargs, and just enforce that? It would take a lot of the guesswork out of building future features like this.
It would be no problem for us to enforce passing kwargs. But for outside implementations built on diffusers, if we want out-of-the-box compatibility and want to make it easy to work with custom implementations (for example, ComfyUI-maintained original modeling implementations, a custom research repo wanting to use some cache implementation for a demo, ...), we should support args-index-based identification. So, in the metadata class for transformer & intermediate blocks, I would say we should maintain this info.
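As a sketch of what that metadata-driven lookup could do (the metadata fields here are illustrative assumptions, not the exact class in the PR):

```python
# Illustrative helper: resolve hidden_states / encoder_hidden_states from either
# kwargs or positional args using per-block metadata, so custom implementations
# that pass arguments positionally still work.
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class BlockArgMetadata:
    hidden_states_index: int = 0
    encoder_hidden_states_index: int = 1


def resolve_block_inputs(
    metadata: BlockArgMetadata, args: Tuple[Any, ...], kwargs: dict
) -> Tuple[Optional[Any], Optional[Any]]:
    hidden_states = kwargs.get("hidden_states", None)
    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
    if hidden_states is None and len(args) > metadata.hidden_states_index:
        hidden_states = args[metadata.hidden_states_index]
    if encoder_hidden_states is None and len(args) > metadata.encoder_hidden_states_index:
        encoder_hidden_states = args[metadata.encoder_hidden_states_index]
    return hidden_states, encoder_hidden_states
```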
Nicely done! 👍🏽 Good to merge once failing tests are addressed.
Failing tests should hopefully be fixed now. They were caused by a divergence in behaviour from the refactor: previously, a cache context was not really necessary because the default state object would have been used, which is no longer the case. The current implementation is the correct behaviour.
Thanks @a-r-r-o-w! I left some comments; I just want to brainstorm whether we can find a more flexible and extensible way to make transformers work better with these techniques.
@maybe_allow_in_graph
@TransformerBlockRegistry.register(
Ohhh, can't we just make a xxTransformerBlockOutput? It could be a named tuple or something else, depending on what type of info you need now and could need in the future.
The reason for not changing what is returned is that it might be a big breaking change for people who import and use the blocks. Also, in general, as I've been implementing new methods, requirements come up that I cannot anticipate beforehand; being able to add some metadata information quickly would be much more convenient compared to committing to a dedicated approach, imo. Once we support a large number of techniques, we could look into refactoring/better design (since these would all be internal-use private attributes that we don't have to maintain BC with), wdyt?
We can refactor this a bit for now, though. Dhruv suggested replacing the centralized registry with attributes instead. So, similar to how we have _no_split_modules, etc. at the ModelMixin level, we can maintain properties on the blocks too.
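Something along these lines could work for the attribute-based variant (the attribute names here are hypothetical):

```python
# Hypothetical attribute-based variant: each block class carries its own
# metadata, similar in spirit to `_no_split_modules` on ModelMixin, instead of
# querying a centralized registry.
import torch.nn as nn


class SomeTransformerBlock(nn.Module):
    # Hook code would read these class attributes directly.
    _block_hidden_states_arg_index = 0
    _block_encoder_hidden_states_arg_index = 1
    _block_returns_encoder_hidden_states = True

    def forward(self, hidden_states, encoder_hidden_states, temb=None):
        return hidden_states, encoder_hidden_states
```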
I don't think it would break, though; this is just an example. These two should return the same thing when you accept it as x, y = fun(...):
from typing import NamedTuple

import torch


class XXTransformerBlockOutput(NamedTuple):
    encoder_hidden_states: torch.Tensor
    hidden_states: torch.Tensor


def fun1(x, y):
    # Some processing
    return x, y


def fun2(x, y):
    # Same processing
    return XXTransformerBlockOutput(x, y)
Okay I'll update.
How do you suggest we return the output from the cache hooks? If xxTransformerBlockOutput means maintaining an output class per transformer block type, the cache hook should also return the same output class. Do we instead register what the output class should be to the block metadata so that the cache hook can return the correct object? Or do we use a generic TransformerBlockOutput class for all models?
I just looked through the code, so my understanding could be completely wrong; let me know if that's the case (obviously you know better since you implemented it).
I think you would not need to register anything, no? If you use this new output, you will get the new output object in your hook, and you will know which class it is as well as all the other info you need (e.g. which value is which) - would this not be the case?
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
    output = self.fn_ref.original_forward(*args, **kwargs)
    encoder_hidden_states = fn_to_update(output.encoder_hidden_states)
    hidden_states = fn_to_update(output.hidden_states)
    new_output = output.__class__(
        encoder_hidden_states=encoder_hidden_states,
        hidden_states=hidden_states,
    )
    return new_output
This is a great branch, but I encountered this problem when testing it. The first method will fail, perhaps because some logic is not implemented.
Detailed Error
@glide-the Thanks for testing and reporting the issue! After the latest refactor, it seems that applying cache with …
FBC Reference: https://github.com/chengzeyi/ParaAttention
Minimal example
Benchmark scripts
Threshold vs. generation time (in seconds) for each model. In general, values below 0.2 work well for First Block Cache, depending on the model. Higher values lead to blurring and artifacting.
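A rough sketch of a minimal example, assuming the apply_first_block_cache / FirstBlockCacheConfig helpers this PR adds (exact import paths and names may differ):

```python
# Sketch of enabling First Block Cache on a pipeline's transformer; the helper
# names below are assumed from this PR, so check the diff/docs for the exact API.
import torch
from diffusers import FluxPipeline
from diffusers.hooks import FirstBlockCacheConfig, apply_first_block_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Lower thresholds are more conservative (smaller speedup, fewer artifacts);
# values below ~0.2 generally work well, as noted above.
apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))

image = pipe("A cat holding a sign that says hello world", num_inference_steps=28).images[0]
image.save("output.png")
```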
[Threshold vs. generation time plots for CogView4, HunyuanVideo, LTX Video, Wan, and Flux]
Visual result comparison
CogView4
Hunyuan Video: output_0.00000.mp4, output_0.05000.mp4, output_0.10000.mp4, output_0.20000.mp4, output_0.40000.mp4, output_0.50000.mp4
LTX Video: output_0.00000.mp4, output_0.03000.mp4, output_0.05000.mp4, output_0.10000.mp4, output_0.20000.mp4, output_0.40000.mp4
Wan: output_0.00000.mp4, output_0.05000.mp4, output_0.10000.mp4, output_0.20000.mp4
Flux
Using with torch.compile
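As a sketch (not necessarily the exact script used for these benchmarks), the usual pattern is to compile the transformer after the cache is applied, reusing pipe from the minimal example above:

```python
# Sketch only: combine FBC with torch.compile by compiling the transformer
# after the cache hooks are applied; compile settings may differ in practice.
import torch

pipe.transformer = torch.compile(pipe.transformer)
image = pipe("A cat holding a sign that says hello world", num_inference_steps=28).images[0]
image.save("output_compiled.png")
```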