[core] FasterCache #10163
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Have some ideas around the design. LMK when would be a good time to pass those along.
@sayakpaul Feel free to mention them, but this is not finalized yet either. Also, could you please not merge main into the branch unless it is ready for reviews? It causes merge conflicts unnecessarily that I don't want to deal with because I have changes locally 😭
src/diffusers/models/hooks.py (outdated)

import torch

# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py
(nit) Since some of the utils are taken from accelerate and repurposed here, adding a note on why we're not directly importing them from accelerate would be nice.
# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/fastercache_sample_latte.py#L127C1-L143C39
@torch.no_grad()
def _fft(tensor):
Could go to torch_utils.py, similar to:

def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
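For context, here is a minimal sketch of what such an FFT helper does, loosely following the referenced FasterCache script (the radius heuristic and names are illustrative, not necessarily the exact code in this PR): it splits a feature map into low- and high-frequency components in the shifted Fourier domain.

```python
import torch


@torch.no_grad()
def _fft(tensor: torch.Tensor):
    # Split a (B, C, H, W) feature map into low- and high-frequency components.
    # A circular mask around the spectrum center selects the low frequencies.
    batch_size, num_channels, height, width = tensor.shape
    tensor_fft = torch.fft.fftshift(torch.fft.fft2(tensor))

    # Illustrative cutoff; the reference implementation uses a similar heuristic.
    radius = min(height, width) // 5
    y, x = torch.meshgrid(
        torch.arange(height, device=tensor.device),
        torch.arange(width, device=tensor.device),
        indexing="ij",
    )
    center_y, center_x = height // 2, width // 2
    low_freq_mask = ((y - center_y) ** 2 + (x - center_x) ** 2) <= radius**2
    low_freq_mask = low_freq_mask[None, None, :, :]

    low_freq_fft = tensor_fft * low_freq_mask
    high_freq_fft = tensor_fft * (~low_freq_mask)
    return low_freq_fft, high_freq_fft
```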
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi, thanks for the great work! I'm interested in this feature and looking forward to using it. I noticed that all checks have passed, but it's still awaiting review. Just wondering if there's any update on the review process? Thank you. @DN6
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What is the status of this PR?
@vladmandic This is ready to merge and there are no blockers. The only thing holding me back from merging is that, in light of some other caching techniques, I want to do it in a way that is agnostic to all our model implementations. The current implementation works agnostically to some extent (I haven't tried it on all models though), as can be seen from the benchmarks in the PR description. But it's built on a variety of assumptions about what each internal transformer block returns. In Diffusers, our transformer blocks mostly return hidden_states, or a tuple of hidden_states and encoder_hidden_states (the ordering differs across implementations).
Now, unique to each transformer block implementation, we sometimes decide to skip the execution path (as in Skip Layer Guidance, First Block Cache, etc.). We can do this in two ways: remove the block from the transformer stack, or return the input hidden_states/encoder_hidden_states as the output. The former is not possible for us and is not the route we want to take. The latter can easily be done with the "hooks" design introduced in this release cycle. However, we need to maintain some metadata about what each transformer block returns, in which order it returns them, whether the inputs can be directly returned as outputs if the block is to be skipped, and so on. This information will significantly simplify adding new cache methods, since a bunch of if-else statements to handle the different cases drastically reduces debuggability/maintainability (at least for me).

I have a plan for how to implement this better so that I can quickly add other cache techniques, but I want to hold off on introducing changes in this release that I know I might have to break or change in the next release. I'm currently looking into a few other things as well, so this might have to wait a bit longer. I'm about ~75% confident that any changes I make will not affect the user-facing API and will not be breaking in any way, but better safe than sorry. I'll discuss with Dhruv and see what he thinks, and merge if this is okay.
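A rough, hypothetical sketch of the kind of metadata-driven skipping described above (all names here are made up for illustration and are not the PR's API):

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch


@dataclass
class TransformerBlockMetadata:
    # Describes what a block's forward returns, so a cache hook can skip the
    # block and echo the inputs back in the correct order.
    returns_encoder_hidden_states: bool = True
    hidden_states_index: int = 0
    encoder_hidden_states_index: int = 1


class SkipBlockHook:
    """If `skip` is set, bypass the wrapped forward and return inputs as outputs."""

    def __init__(self, metadata: TransformerBlockMetadata):
        self.metadata = metadata
        self.skip = False

    def __call__(
        self,
        block_forward,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if not self.skip:
            return block_forward(hidden_states, encoder_hidden_states, **kwargs)
        if not self.metadata.returns_encoder_hidden_states:
            return hidden_states
        # Echo inputs back as outputs, in whatever order this block normally uses.
        outputs = [None, None]
        outputs[self.metadata.hidden_states_index] = hidden_states
        outputs[self.metadata.encoder_hidden_states_index] = encoder_hidden_states
        return tuple(outputs)
```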
@bot /style
Style fixes have been applied. View the workflow run here.
@a-r-r-o-w thanks for the detailed writeup!
Failing tests seem unrelated.
Fixes #10128.
Flux visual results
HunyuanVideo visual results
hunyuan_video---dtype-bf16---cache_method-none---compile-False.mp4
hunyuan_video---dtype-bf16---cache_method-fastercache---compile-False.mp4
Latte visual results
latte---dtype-fp16---cache_method-none---compile-False.mp4
latte---dtype-fp16---cache_method-fastercache---compile-False.mp4
CogVideoX visual results
cogvideox-1.0---dtype-bf16---cache_method-none---compile-False.mp4
cogvideox-1.0---dtype-bf16---cache_method-fastercache---compile-False.mp4
Mochi visual results
Note: I've yet to find the optimal inference parameters for Mochi to minimize the quality difference. I will try to write a blog post on how this can be done.
mochi---dtype-bf16---cache_method-none---compile-False.mp4
mochi---dtype-bf16---cache_method-fastercache---compile-False.mp4
Important
This implementation differs from the original implementation and, I believe, is true to what's described in the paper; the original implementation deviates from the paper in certain ways.
See Vchitect/FasterCache#13 for more details.
TL;DR: The original implementation approximates the conditional branch outputs with the inference results from the unconditional branch. This is the reverse of what's described in the "CFG Cache" section of the paper: we should use the conditional predictions to approximate the outputs of the unconditional branch.
I tested the hook-based implementation for both the original version and our current version separately. Visually, I think our current implementation produces better results that are more closely aligned with the video generated without applying FasterCache.
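To make the direction of the approximation concrete, here is a simplified, hypothetical sketch (names are illustrative; the actual method decomposes the residual into low/high-frequency components, which is omitted here):

```python
import torch


class CFGCache:
    # Cache the (uncond - cond) residual on steps where both branches are computed,
    # then on cached steps run only the conditional branch and derive the
    # unconditional prediction from it.
    def __init__(self):
        self.residual = None

    def update(self, noise_pred_uncond: torch.Tensor, noise_pred_cond: torch.Tensor) -> None:
        self.residual = noise_pred_uncond - noise_pred_cond

    def approximate_uncond(self, noise_pred_cond: torch.Tensor) -> torch.Tensor:
        # Conditional prediction -> approximate unconditional prediction,
        # i.e. the opposite direction from the original reference code.
        return noise_pred_cond + self.residual
```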
Note
The complete FasterCache recipe requires both an unconditional and a conditional batch to approximate video generation. So, models like Flux and HunyuanVideo are not fully compatible with it out of the box, as they are guidance-distilled. Broadly, there are two approximations at play:
- approximating attention outputs across nearby timesteps inside the transformer blocks, and
- approximating the unconditional branch outputs from the conditional ones (the "CFG Cache" described above).
For guidance-distilled models, only the attention approximation parts are used.
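For the attention side, here is a hypothetical sketch of what approximating attention outputs can look like, as a linear extrapolation from cached outputs (the exact weighting scheme in the paper and PR may differ):

```python
from typing import Optional

import torch


class AttentionOutputCache:
    # On steps where attention is skipped, extrapolate from the two most recently
    # computed outputs instead of recomputing attention. The weight is a placeholder
    # for the step-dependent weighting used by the method.
    def __init__(self):
        self.previous: Optional[torch.Tensor] = None
        self.current: Optional[torch.Tensor] = None

    def update(self, attn_output: torch.Tensor) -> None:
        self.previous, self.current = self.current, attn_output

    def approximate(self, weight: float = 1.0) -> torch.Tensor:
        # Linear extrapolation along the denoising trajectory.
        return self.current + weight * (self.current - self.previous)
```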
We currently need to add a `reset_stateful_hooks()` call to every pipeline for FasterCache to work correctly. This is not ideal. We should also support "pipeline hooks" as counterparts to "model hooks", which would allow users to pre/post-hook into all pipeline methods like `encode_prompt`, `prepare_latents`, and `__call__`. I have a prototype implementation ready to demonstrate how this can be done. Ideally, we want to be able to target `__call__` so that the hook can trigger the state reset.
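Purely as an illustration of the pipeline-hook idea above (hypothetical names; the wrapping mechanism shown here is an assumption, not the prototype mentioned):

```python
import functools


def register_post_call_hook(pipeline_cls, post_hook):
    # Wrap the pipeline class's __call__ so `post_hook` runs after every generation.
    # Patching the class (not the instance) is required because Python looks up
    # __call__ on the type when the pipeline object is called.
    original_call = pipeline_cls.__call__

    @functools.wraps(original_call)
    def wrapped_call(self, *args, **kwargs):
        try:
            return original_call(self, *args, **kwargs)
        finally:
            post_hook(self)

    pipeline_cls.__call__ = wrapped_call


# Usage sketch: trigger the state reset after each generation, e.g.
# register_post_call_hook(SomePipeline, lambda pipe: reset_stateful_hooks(...))
# (however the reset is actually invoked in the PR).
```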
Code
TODO:
cc @cszy98 @ChenyangSi