Commit 844221a

[core] FasterCache (#10163)
* init
* update
* update
* update
* make style
* update
* fix
* make it work with guidance distilled models
* update
* make fix-copies
* add tests
* update
* apply_faster_cache -> apply_fastercache
* fix
* reorder
* update
* refactor
* update docs
* add fastercache to CacheMixin
* update tests
* Apply suggestions from code review
* make style
* try to fix partial import error
* Apply style fixes
* raise warning
* update

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 9b2c0a7 commit 844221a

16 files changed, +976 -25 lines changed

docs/source/en/api/cache.md

Lines changed: 33 additions & 0 deletions

@@ -38,6 +38,33 @@ config = PyramidAttentionBroadcastConfig(
 pipe.transformer.enable_cache(config)
 ```

+## FasterCache
+
+[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
+
+FasterCache is a method that speeds up inference in diffusion transformers by:
+- Reusing attention states between successive inference steps, due to high similarity between them
+- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output
+
+```python
+import torch
+from diffusers import CogVideoXPipeline, FasterCacheConfig
+
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+config = FasterCacheConfig(
+    spatial_attention_block_skip_range=2,
+    spatial_attention_timestep_skip_range=(-1, 681),
+    current_timestep_callback=lambda: pipe.current_timestep,
+    attention_weight_callback=lambda _: 0.3,
+    unconditional_batch_skip_range=5,
+    unconditional_batch_timestep_skip_range=(-1, 781),
+    tensor_format="BFCHW",
+)
+pipe.transformer.enable_cache(config)
+```
+
 ### CacheMixin

 [[autodoc]] CacheMixin
@@ -47,3 +74,9 @@ pipe.transformer.enable_cache(config)
 [[autodoc]] PyramidAttentionBroadcastConfig

 [[autodoc]] apply_pyramid_attention_broadcast
+
+### FasterCacheConfig
+
+[[autodoc]] FasterCacheConfig
+
+[[autodoc]] apply_faster_cache
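
The caching hook enabled in the docs example above can be removed again through the same `CacheMixin` interface. A minimal usage sketch, assuming the `pipe` and `config` objects from that example (the prompt and `num_frames` value are only illustrative):

```python
# Enable FasterCache for one generation, then restore the unmodified transformer.
pipe.transformer.enable_cache(config)
video = pipe("A panda playing a guitar", num_frames=49).frames[0]
pipe.transformer.disable_cache()
```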

src/diffusers/__init__.py

Lines changed: 9 additions & 1 deletion

@@ -131,8 +131,10 @@
 else:
     _import_structure["hooks"].extend(
         [
+            "FasterCacheConfig",
             "HookRegistry",
             "PyramidAttentionBroadcastConfig",
+            "apply_faster_cache",
             "apply_pyramid_attention_broadcast",
         ]
     )
@@ -703,7 +705,13 @@
     except OptionalDependencyNotAvailable:
         from .utils.dummy_pt_objects import *  # noqa F403
     else:
-        from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
+        from .hooks import (
+            FasterCacheConfig,
+            HookRegistry,
+            PyramidAttentionBroadcastConfig,
+            apply_faster_cache,
+            apply_pyramid_attention_broadcast,
+        )
         from .models import (
             AllegroTransformer3DModel,
             AsymmetricAutoencoderKL,
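
New public names are registered twice in this file: once in `_import_structure`, which drives lazy attribute resolution at runtime, and once under the `TYPE_CHECKING` branch for static type checkers. A trimmed sketch of how the lazy half resolves, assuming the `_LazyModule` helper diffusers already uses at the bottom of this file (the dictionary contents here are reduced to the two new names):

```python
import sys

from .utils import _LazyModule

# Names listed here are only imported from their submodule when first accessed
# on the top-level diffusers module, keeping `import diffusers` cheap.
_import_structure = {"hooks": ["FasterCacheConfig", "apply_faster_cache"]}

sys.modules[__name__] = _LazyModule(
    __name__, globals()["__file__"], _import_structure, module_spec=__spec__
)
```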

src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -2,6 +2,7 @@


 if is_torch_available():
+    from .faster_cache import FasterCacheConfig, apply_faster_cache
     from .group_offloading import apply_group_offloading
     from .hooks import HookRegistry, ModelHook
     from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook

src/diffusers/hooks/faster_cache.py

Lines changed: 653 additions & 0 deletions
Large diffs are not rendered by default.

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 4 additions & 7 deletions

@@ -26,8 +26,8 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
 _ATTENTION_CLASSES = (Attention, MochiAttention)
-
 _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
 _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
 _CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
@@ -87,7 +87,7 @@ class PyramidAttentionBroadcastConfig:

     def __repr__(self) -> str:
         return (
-            f"PyramidAttentionBroadcastConfig("
+            f"PyramidAttentionBroadcastConfig(\n"
             f"  spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
             f"  temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
             f"  cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
@@ -175,10 +175,7 @@ def reset_state(self, module: torch.nn.Module) -> None:
         return module


-def apply_pyramid_attention_broadcast(
-    module: torch.nn.Module,
-    config: PyramidAttentionBroadcastConfig,
-):
+def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
     r"""
     Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.

@@ -311,4 +308,4 @@ def _apply_pyramid_attention_broadcast_hook(
     """
     registry = HookRegistry.check_if_exists_or_initialize(module)
     hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
-    registry.register_hook(hook, "pyramid_attention_broadcast")
+    registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK)
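
With the hook name hoisted into the module-level `_PYRAMID_ATTENTION_BROADCAST_HOOK` constant, registration (above) and removal (in `cache_utils.py` below) share one identifier instead of repeating a string literal. A small sketch of that round trip, assuming a `transformer` module that already has the hook attached:

```python
from diffusers.hooks import HookRegistry
from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK

# Remove the hook under the same constant it was registered with, so the two
# call sites cannot drift apart.
registry = HookRegistry.check_if_exists_or_initialize(transformer)
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
```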

src/diffusers/models/cache_utils.py

Lines changed: 22 additions & 3 deletions

@@ -24,6 +24,7 @@ class CacheMixin:

     Supported caching techniques:
         - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
+        - [FasterCache](https://huggingface.co/papers/2410.19355)
     """

     _cache_config = None
@@ -59,25 +60,43 @@ def enable_cache(self, config) -> None:
         ```
         """

-        from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
+        from ..hooks import (
+            FasterCacheConfig,
+            PyramidAttentionBroadcastConfig,
+            apply_faster_cache,
+            apply_pyramid_attention_broadcast,
+        )
+
+        if self.is_cache_enabled:
+            raise ValueError(
+                f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
+            )

         if isinstance(config, PyramidAttentionBroadcastConfig):
             apply_pyramid_attention_broadcast(self, config)
+        elif isinstance(config, FasterCacheConfig):
+            apply_faster_cache(self, config)
         else:
             raise ValueError(f"Cache config {type(config)} is not supported.")

         self._cache_config = config

     def disable_cache(self) -> None:
-        from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig
+        from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
+        from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
+        from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK

         if self._cache_config is None:
             logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
             return

         if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
             registry = HookRegistry.check_if_exists_or_initialize(self)
-            registry.remove_hook("pyramid_attention_broadcast", recurse=True)
+            registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
+        elif isinstance(self._cache_config, FasterCacheConfig):
+            registry = HookRegistry.check_if_exists_or_initialize(self)
+            registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
+            registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
         else:
            raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
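
A short caller-side sketch of the updated `enable_cache`/`disable_cache` behavior, assuming `pipe` and the FasterCache `config` from the docs example above; the guard against stacking two caching techniques is the one added in this hunk, and the PAB arguments shown are only illustrative:

```python
from diffusers import PyramidAttentionBroadcastConfig

pipe.transformer.enable_cache(config)  # FasterCacheConfig from the docs example

# Enabling a second technique while one is active now raises a ValueError,
# so switch explicitly by disabling the current one first.
pipe.transformer.disable_cache()
pipe.transformer.enable_cache(
    PyramidAttentionBroadcastConfig(
        spatial_attention_block_skip_range=2,
        current_timestep_callback=lambda: pipe.current_timestep,
    )
)
```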

src/diffusers/models/embeddings.py

Lines changed: 1 addition & 1 deletion

@@ -336,7 +336,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
            " `from_numpy` is no longer required."
            " Pass `output_type='pt' to use the new version now."
        )
-        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
         return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")

src/diffusers/models/modeling_utils.py

Lines changed: 3 additions & 1 deletion

@@ -37,7 +37,6 @@
 from typing_extensions import Self

 from .. import __version__
-from ..hooks import apply_group_offloading, apply_layerwise_casting
 from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
 from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
@@ -504,6 +503,7 @@ def enable_layerwise_casting(
             non_blocking (`bool`, *optional*, defaults to `False`):
                 If `True`, the weight casting operations are non-blocking.
         """
+        from ..hooks import apply_layerwise_casting

         user_provided_patterns = True
         if skip_modules_pattern is None:
@@ -570,6 +570,8 @@ def enable_group_offload(
         ... )
         ```
         """
+        from ..hooks import apply_group_offloading
+
         if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
             msg = (
                 "Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "

src/diffusers/pipelines/latte/pipeline_latte.py

Lines changed: 1 addition & 1 deletion

@@ -817,7 +817,7 @@ def __call__(

                 # predict noise model_output
                 noise_pred = self.transformer(
-                    latent_model_input,
+                    hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
                     timestep=current_timestep,
                     enable_temporal_attentions=enable_temporal_attentions,

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 19 additions & 0 deletions

@@ -2,6 +2,21 @@
 from ..utils import DummyObject, requires_backends


+class FasterCacheConfig(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class HookRegistry(metaclass=DummyObject):
     _backends = ["torch"]

@@ -32,6 +47,10 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


+def apply_faster_cache(*args, **kwargs):
+    requires_backends(apply_faster_cache, ["torch"])
+
+
 def apply_pyramid_attention_broadcast(*args, **kwargs):
     requires_backends(apply_pyramid_attention_broadcast, ["torch"])
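
These dummies keep `from diffusers import FasterCacheConfig` importable in an environment without PyTorch and defer the failure to the point of use. A rough sketch of the resulting behavior (the exact error text is illustrative):

```python
# In an environment without PyTorch installed, the import itself succeeds:
from diffusers import FasterCacheConfig

# ...but instantiating the dummy raises via requires_backends, telling the
# user that the "torch" backend is required for this object.
config = FasterCacheConfig()
```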
