First Block Cache #11180
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
modify flux single blocks to make compatible with cache techniques (without too much model-specific intrusion code)
@@ -79,10 +79,14 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
def forward(
cc @yiyixuxu for reviewing changes to the transformer here. The changes were made to simplify some of the code required to make cache techniques work somewhat more easily without even more if-else branching.
Ideally, if we stick to implementing models such that all blocks take in both hidden_states and encoder_hidden_states, and always return (hidden_states, encoder_hidden_states)
from the block, a lot of design choices in the hook-based code can be simplified.
For now, I think these changes should be safe and come without any significant overhead to generation time (I haven't benchmarked though).
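To illustrate the convention being proposed, here is a rough sketch with hypothetical names (not the actual Flux block):

import torch
from typing import Tuple


class ExampleTransformerBlock(torch.nn.Module):
    # Hypothetical block following the proposed convention: accept both streams
    # and always return (hidden_states, encoder_hidden_states), even if one of
    # them passes through unchanged.
    def __init__(self, dim: int):
        super().__init__()
        self.norm = torch.nn.LayerNorm(dim)
        self.proj = torch.nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        hidden_states = self.proj(self.norm(hidden_states))
        # Returning a uniform tuple lets hook-based code treat every block the
        # same way, without per-model if/else branching.
        return hidden_states, encoder_hidden_states

With a uniform signature like this, a cache hook only needs to handle a single input/output shape regardless of the model.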
I'm ok with the change, but I think we can make encoder_hidden_states optional here, no? It's not much trouble and won't break for those using these blocks on their own.
@@ -0,0 +1,222 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
cc @chengzeyi Would be super cool if you could give this PR a review, since we're trying to integrate FBC to work with all supported models!
Currently, I've only done limited testing on a few models, but it should be easily extendable to all of them.
@a-r-r-o-w I see, let me take a look!
src/diffusers/hooks/hooks.py
Outdated
)


class BaseMarkedState(BaseState):
To explain simply, a "marked" state keeps separate copies of a state object for different batches of data. In our pipelines, we do one of the following:
- concatenate the unconditional and conditional batches and perform a single forward pass through the transformer
- perform individual forward passes for the conditional and unconditional batches
The state variables must track values specific to each batch of data over all inference steps, otherwise you might end up in a situation where the state variable for the conditional batch is used for the unconditional batch, or vice versa.
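A minimal sketch of the idea, with hypothetical names (not the actual implementation in hooks.py):

from typing import Any, Dict


class ExampleMarkedState:
    # Keeps one copy of the tracked values per "mark" (e.g. "cond"/"uncond"),
    # so values recorded for the conditional batch never leak into the
    # unconditional batch across inference steps.
    def __init__(self):
        self._states: Dict[str, Dict[str, Any]] = {}
        self._current_mark = "default"

    def mark(self, name: str) -> None:
        # Switch which per-batch copy subsequent get/set calls operate on
        self._current_mark = name

    def set(self, key: str, value: Any) -> None:
        self._states.setdefault(self._current_mark, {})[key] = value

    def get(self, key: str, default: Any = None) -> Any:
        return self._states.get(self._current_mark, {}).get(key, default)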
@@ -917,6 +917,7 @@ def __call__(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)

cc.mark_state("cond")
Doing it this way helps us distinguish different batches of data. I believe this design will fit well with the upcoming "guiders" to support multiple guidance methods. As guidance methods may use 1 (guidance-distilled), 2 (CFG), 3 (PAG), or more latent batches, we can call "mark state" any number of times to distinguish between calls to transformer.forward with different batches of data for the same inference step.
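As a rough, self-contained illustration of how this looks in a CFG-style denoising loop (dummy stand-ins for the transformer, scheduler, and cache context; only the mark_state calls are the point):

import torch


class DummyCacheContext:
    # Stand-in for the cache context object; only mark_state matters here
    def mark_state(self, name: str) -> None:
        print(f"tracking cache state for batch: {name}")


cc = DummyCacheContext()
transformer = lambda latents, text_embeds, timestep: latents * 0.99  # dummy denoiser
guidance_scale = 5.0

latents = torch.randn(1, 16, 32, 32)
prompt_embeds = torch.randn(1, 77, 768)
negative_prompt_embeds = torch.zeros_like(prompt_embeds)

for t in torch.linspace(1.0, 0.0, 4):
    timestep = t.expand(latents.shape[0]).to(latents.dtype)

    cc.mark_state("cond")    # cache state tracked for the conditional batch
    noise_pred_cond = transformer(latents, prompt_embeds, timestep)

    cc.mark_state("uncond")  # separate cache state for the unconditional batch
    noise_pred_uncond = transformer(latents, negative_prompt_embeds, timestep)

    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
    latents = latents - 0.1 * noise_pred  # placeholder for scheduler.step(...)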
# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
# of the box once there is better cache support/implementation
class FirstBlockCacheTesterMixin:
Adding new cache tester mixins each time and extending the list of parent mixins for each pipeline test is probably not going to be a clean way of testing. We can refactor this in the future and consolidate all cache methods into a single tester once they are better supported/implemented for most models
src/diffusers/hooks/_helpers.py
Outdated
# fmt: off
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
Any specific reason to use this naming convention here? The function is just meant to return combinations of hidden_states/encoder_hidden_states, right?
Not really. It just spells out what argument index hidden_states is at and what it returns. Do you have any particular recommendation?
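For readers following along, a rough sketch of what helpers following this naming scheme presumably look like (reconstructed from the naming convention, not copied from the PR):

# "hidden_states is argument 0; return just hidden_states"
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    return hidden_states


# "hidden_states is argument 0, encoder_hidden_states is argument 1;
#  return (hidden_states, encoder_hidden_states)"
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states_encoder_hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    if encoder_hidden_states is None and len(args) > 1:
        encoder_hidden_states = args[1]
    return hidden_states, encoder_hidden_states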
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
@DN6 Addressed the review comments. Could you give it another look?
src/diffusers/hooks/_helpers.py
Outdated
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
    hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
Just a thought: could we refactor all the blocks to always use kwargs, and just enforce that? It would take a lot of guesswork out of building future features like this.
It would be no problem for us to enforce passing kwargs. But for outside implementations built on diffusers, if we want OOTB compatibility and want to make it easy to work with custom implementations (for example, ComfyUI-maintained original modeling implementations, or a custom research repo wanting to use some cache implementation for a demo), we should support args-index based identification. So, in the metadata class for transformer & intermediate blocks, I would say we should maintain this info.
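For example, a hypothetical helper showing what args-index based identification means here (the metadata class and field names are made up for illustration):

from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass
class ExampleBlockInputMetadata:
    # Where each tensor may appear when callers pass positional args
    hidden_states_index: int = 0
    encoder_hidden_states_index: Optional[int] = 1


def get_block_inputs(metadata: ExampleBlockInputMetadata, args: Tuple[Any, ...], kwargs: Dict[str, Any]):
    # Prefer kwargs, but fall back to positional lookup so custom
    # implementations that call blocks positionally still work out of the box
    hidden_states = kwargs.get("hidden_states")
    if hidden_states is None and len(args) > metadata.hidden_states_index:
        hidden_states = args[metadata.hidden_states_index]

    encoder_hidden_states = kwargs.get("encoder_hidden_states")
    idx = metadata.encoder_hidden_states_index
    if encoder_hidden_states is None and idx is not None and len(args) > idx:
        encoder_hidden_states = args[idx]

    return hidden_states, encoder_hidden_states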
Nicely done! 👍🏽 Good to merge once failing tests are addressed.
Failing tests should hopefully be fixed now. They were caused by a divergence in behaviour from the refactor: previously, a cache context was not really necessary because the default state object would have been used, which is no longer the case. The current implementation is the correct behaviour.
thanks @a-r-r-o-w
I left some comments. I just want to brainstorm whether we can find more flexible and extensible ways to make transformers work better with these techniques.
@maybe_allow_in_graph
@TransformerBlockRegistry.register(
ohhh can't we just make an xxTransformerBlockOutput? It could be a named tuple or something else, depending on what type of info you need now and could need in the future.
The reason for not changing what is returned is that it might be a big breaking change for people that import and use the blocks. Also, in general, as I've been implementing new methods, requirements come up that I cannot anticipate beforehand - being able to add some metadata information quickly would be much more convenient compared to committing to a dedicated approach imo. Once we support a large number of techniques, we could look into refactoring/better design (since these would all be internal-use private attributes that we don't have to maintain BC with), wdyt?
We can refactor this a bit for now though. Dhruv suggested replacing the centralized registry with attributes instead. So, similar to how we have _no_split_modules, etc. at the ModelMixin level, we can maintain properties at the blocks too.
I don't think it would break though; this is just an example. These two return the same thing when you accept the result as x, y = fun(...):
from typing import NamedTuple

import torch


class XXTransformerBlockOutput(NamedTuple):
    encoder_hidden_states: torch.Tensor
    hidden_states: torch.Tensor


def fun1(x, y):
    # Some processing
    return x, y


def fun2(x, y):
    # Same processing
    return XXTransformerBlockOutput(x, y)
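For instance, continuing the snippet above, tuple unpacking behaves identically for both, and the named tuple additionally allows attribute access:

a = torch.randn(2, 4)
b = torch.randn(2, 4)

x1, y1 = fun1(a, b)
x2, y2 = fun2(a, b)  # unpacks exactly like the plain tuple
assert torch.equal(x1, x2) and torch.equal(y1, y2)

out = fun2(a, b)
print(out.encoder_hidden_states.shape, out.hidden_states.shape)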
Okay I'll update.
How do you suggest we return the output from the cache hooks? If xxTransformerBlockOutput means maintaining an output class per transformer block type, the cache hook should also return the same output class. Do we instead register what the output class should be to the block metadata so that the cache hook can return the correct object? Or do we use a generic TransformerBlockOutput class for all models?
I just looked through the code, so my understanding could be completely wrong; let me know if that's the case (obviously you know better since you implemented it).
I think you would not need to register anything, no? If you use this new output, you will get the new output object in your hook, and you will know which class it is as well as all the other info you need (e.g. which value is which) - would this not be the case?
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
    output = self.fn_ref.original_forward(*args, **kwargs)
    encoder_hidden_states = fn_to_update(output.encoder_hidden_states)
    hidden_states = fn_to_update(output.hidden_states)
    new_output = output.__class__(
        encoder_hidden_states=encoder_hidden_states,
        hidden_states=hidden_states,
    )
    return new_output
@a-r-r-o-w overall the refactor looks good; still not too clear why we need an empty registry though? Do you mean that the current code needs it but we could refactor it out in the future?
@yiyixuxu On second thought, I was thinking: let's go forward with the current implementation as-is. These are internal details that we can change at any time because the user will never have to interact with any of it. The reason I think we should do this is that I want to first work on adding a couple of different techniques like SmoothCache, TaylorSeer, Radial Attention, etc. and see what kind of information we need to maintain in blocks/processors/elsewhere. It will help the iteration speed a lot IMO.
WDYT?
I really don't like adding a metadata registry about the return element index to all the model classes.
Maybe you could keep a hardcoded map with this info for now, so we can keep the changes in the hooks folder and get this merged now?
Okay thanks, I'll go with the hardcoded map approach that we previously had and make the relevant updates
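For context, a rough sketch of what such a hardcoded map inside the hooks folder could look like (block names, indices, and return orders below are placeholders, not taken from the PR):

from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class ExampleBlockIOMetadata:
    # Which positional argument carries each stream, and in what order the
    # block returns them
    hidden_states_arg_index: int
    encoder_hidden_states_arg_index: Optional[int]
    return_order: Tuple[str, ...]


# Keyed by block class name so the model files themselves stay untouched
_EXAMPLE_BLOCK_IO_METADATA = {
    "SomeDoubleStreamTransformerBlock": ExampleBlockIOMetadata(0, 1, ("encoder_hidden_states", "hidden_states")),
    "SomeSingleStreamTransformerBlock": ExampleBlockIOMetadata(0, None, ("hidden_states",)),
}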
@yiyixuxu Updated! Please take another look
This is a great branch, but I encountered this problem when testing it. The first method will fail, perhaps because some logic is not implemented. Detailed Error
@glide-the Thanks for testing and reporting the issue! After the latest refactor, it seems that applying cache with …
thanks @a-r-r-o-w :)
Failing tests seem to be unrelated. The VACE LoRA tests are very flaky, and the documentation tests are being investigated by Sayak. I think we're good to go here. Once we have a couple of the latest techniques, I will look into refactoring our test suite and the internal implementations/registration part, since we established we don't want centralized registries. I'll open follow-ups for more models supporting caching in the coming days after testing.
* [CI] Fix big GPU test marker (#11786)
* update
* update
* First Block Cache (#11180)
* update
* modify flux single blocks to make compatible with cache techniques (without too much model-specific intrusion code)
* remove debug logs
* update
* cache context for different batches of data
* fix hs residual bug for single return outputs; support ltx
* fix controlnet flux
* support flux, ltx i2v, ltx condition
* update
* update
* Update docs/source/en/api/cache.md
* Update src/diffusers/hooks/hooks.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
* address review comments pt. 1
* address review comments pt. 2
* cache context refactor; address review pt. 3
* address review comments
* metadata registration with decorators instead of centralized
* support cogvideox
* support mochi
* fix
* remove unused function
* remove central registry based on review
* update
---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
* fix
---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
FBC Reference: https://github.com/chengzeyi/ParaAttention
Minimal example
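A minimal sketch of enabling First Block Cache on Flux, assuming the FirstBlockCacheConfig / enable_cache interface added in this PR (exact import paths may differ):

import torch
from diffusers import FluxPipeline
from diffusers.hooks import FirstBlockCacheConfig  # assumed import location

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Values below ~0.2 generally preserve quality; higher thresholds cache more
# aggressively and can introduce blurring/artifacting (see benchmarks below)
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))

image = pipe("A cat holding a sign that says hello world", num_inference_steps=28).images[0]
image.save("output.png")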
Benchmark scripts
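A simplified sketch of the kind of threshold sweep used for the timings below (not the actual script; same assumed API as the minimal example above):

import time

import torch
from diffusers import FluxPipeline
from diffusers.hooks import FirstBlockCacheConfig  # assumed import location

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"

for threshold in [0.0, 0.05, 0.1, 0.2, 0.4]:
    if threshold > 0:
        pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=threshold))

    torch.cuda.synchronize()
    start = time.time()
    image = pipe(prompt, num_inference_steps=28, generator=torch.Generator().manual_seed(42)).images[0]
    torch.cuda.synchronize()
    print(f"threshold={threshold:.5f}: {time.time() - start:.2f}s")

    image.save(f"output_{threshold:.5f}.png")
    if threshold > 0:
        pipe.transformer.disable_cache()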
Threshold vs generation time (in seconds) for each model: CogView4, HunyuanVideo, LTX Video, Wan, Flux. In general, values below 0.2 work well for First Block Cache, depending on the model; higher values lead to blurring and artifacting.
Visual result comparison
- CogView4
- Hunyuan Video: output_0.00000.mp4, output_0.05000.mp4, output_0.10000.mp4, output_0.20000.mp4, output_0.40000.mp4, output_0.50000.mp4
- LTX Video: output_0.00000.mp4, output_0.03000.mp4, output_0.05000.mp4, output_0.10000.mp4, output_0.20000.mp4, output_0.40000.mp4
- Wan: output_0.00000.mp4, output_0.05000.mp4, output_0.10000.mp4, output_0.20000.mp4
- Flux
Using with torch.compile
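A sketch of combining the cache with torch.compile (same assumed API as above; the data-dependent skip in FBC may cause graph breaks, so fullgraph compilation is not assumed):

import torch
from diffusers import FluxPipeline
from diffusers.hooks import FirstBlockCacheConfig  # assumed import location

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Enable the cache first, then compile the transformer
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))
pipe.transformer = torch.compile(pipe.transformer, fullgraph=False)

image = pipe("A cat holding a sign that says hello world", num_inference_steps=28).images[0]
image.save("output_compiled.png")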