[refactor] some shared parts between hooks + docs #11968

Open · wants to merge 9 commits into base: main
Changes from 7 commits

17 changes: 15 additions & 2 deletions src/diffusers/hooks/_common.py
@@ -16,11 +16,11 @@

import torch

from ..models.attention import FeedForward, LuminaFeedForward
from ..models.attention import AttentionModuleMixin, FeedForward, LuminaFeedForward
from ..models.attention_processor import Attention, MochiAttention


_ATTENTION_CLASSES = (Attention, MochiAttention)
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)

_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
@@ -35,6 +35,19 @@
}
)

# Layers supported for group offloading and layerwise casting
_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d,
torch.nn.ConvTranspose1d,
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d,
torch.nn.Linear,
# TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
# because of double invocation of the same norm layer in CogVideoXLayerNorm
)


def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
for submodule_name, submodule in module.named_modules():
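For reviewers skimming the refactor: a minimal sketch of how the new shared tuple is meant to be consumed. The import path `diffusers.hooks._common` comes from this diff; the helper name and toy model below are purely illustrative.

```python
import torch

from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS


def list_offloadable_leaf_layers(module: torch.nn.Module) -> list:
    # Collect the names of submodules that group offloading and layerwise casting
    # treat as leaf-level targets (Linear, Conv*, ConvTranspose*).
    return [
        name
        for name, submodule in module.named_modules()
        if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS)
    ]


toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8), torch.nn.Conv1d(4, 4, 3))
print(list_offloadable_leaf_layers(toy))  # ['0', '2']: norm layers are intentionally excluded for now
```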
14 changes: 7 additions & 7 deletions src/diffusers/hooks/faster_cache.py
@@ -19,9 +19,9 @@
import torch

from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..models.modeling_outputs import Transformer2DModelOutput
from ..utils import logging
from ._common import _ATTENTION_CLASSES
from .hooks import HookRegistry, ModelHook


@@ -30,7 +30,6 @@

_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
"^blocks.*attn",
"^transformer_blocks.*attn",
@@ -489,9 +488,10 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> No
Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.

Args:
pipeline (`DiffusionPipeline`):
The diffusion pipeline to apply FasterCache to.
config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
module (`torch.nn.Module`):
The pytorch module to apply FasterCache to. Typically, this should be a transformer architecture supported
in Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
config (`FasterCacheConfig`):
The configuration to use for FasterCache.

Example:
@@ -568,7 +568,7 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float:
_apply_faster_cache_on_denoiser(module, config)

for name, submodule in module.named_modules():
if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
if not isinstance(submodule, _ATTENTION_CLASSES):
continue
if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
_apply_faster_cache_on_attention_class(name, submodule, config)
@@ -589,7 +589,7 @@ def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCache
registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)


def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
def _apply_faster_cache_on_attention_class(name: str, module: AttentionModuleMixin, config: FasterCacheConfig) -> None:
is_spatial_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
and config.spatial_attention_block_skip_range is not None
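The docstring fix above clarifies that `apply_faster_cache` takes the denoiser module, not the pipeline. A hedged usage sketch along those lines; the `diffusers.hooks` import path mirrors the FBCache example below, while the checkpoint and the specific `FasterCacheConfig` fields shown (`spatial_attention_block_skip_range`, `spatial_attention_timestep_skip_range`, `current_timestep_callback`) are assumptions based on how the config is referenced elsewhere, not copied from the collapsed example in this diff.

```python
import torch

from diffusers import CogVideoXPipeline
from diffusers.hooks import FasterCacheConfig, apply_faster_cache

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# FasterCache is applied to the transformer (the denoiser), not to the pipeline object.
config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    current_timestep_callback=lambda: pipe.current_timestep,
)
apply_faster_cache(pipe.transformer, config)

video = pipe("A cat wearing sunglasses at the pool", num_inference_steps=50).frames[0]
```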
32 changes: 32 additions & 0 deletions src/diffusers/hooks/first_block_cache.py
@@ -192,6 +192,38 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):


def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
"""
Applies [First Block
Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
to a given module.

First Block Cache builds on the ideas of [TeaCache](). It is much simpler to implement generically for a wide range
of models and has been integrated first for experimental purposes.

Args:
module (`torch.nn.Module`):
The pytorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
config (`FirstBlockCacheConfig`):
The configuration to use for applying the FBCache method.

Example:
```python
>>> import torch
>>> from diffusers import CogView4Pipeline
>>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig

>>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")

>>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))

>>> prompt = "A photo of an astronaut riding a horse on mars"
>>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
>>> image.save("output.png")
```
"""

state_manager = StateManager(FBCSharedBlockState, (), {})
remaining_blocks = []

10 changes: 2 additions & 8 deletions src/diffusers/hooks/group_offloading.py
@@ -23,6 +23,7 @@
import torch

from ..utils import get_logger, is_accelerate_available
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook


@@ -39,13 +40,6 @@
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
_GROUP_ID_LAZY_LEAF = "lazy_leafs"
_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
torch.nn.Linear,
# TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
# because of double invocation of the same norm layer in CogVideoXLayerNorm
)
# fmt: on


@@ -680,7 +674,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
for name, submodule in module.named_modules():
if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
group = ModuleGroup(
modules=[submodule],
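For context on the leaf-level path touched here: every submodule matching the shared tuple gets wrapped in its own `ModuleGroup`. The sketch below assumes the public entry point is `apply_group_offloading` with the keyword arguments shown; only the `isinstance`-based filtering is taken directly from this diff.

```python
import torch

from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_group_offloading

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# With offload_type="leaf_level", each Linear/Conv layer from
# _GO_LC_SUPPORTED_PYTORCH_LAYERS becomes its own offloading group and is moved
# between the offload and onload devices on demand.
apply_group_offloading(
    transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)
```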
9 changes: 2 additions & 7 deletions src/diffusers/hooks/layerwise_casting.py
@@ -18,6 +18,7 @@
import torch

from ..utils import get_logger, is_peft_available, is_peft_version
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook


@@ -27,12 +28,6 @@
# fmt: off
_LAYERWISE_CASTING_HOOK = "layerwise_casting"
_PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
torch.nn.Linear,
)

DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
# fmt: on

@@ -186,7 +181,7 @@ def _apply_layerwise_casting(
logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
return

if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
if isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
return
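Layerwise casting now filters on the same shared tuple. A short sketch of the public helper `apply_layerwise_casting` (imported from `diffusers.hooks.layerwise_casting` in the tests further down this diff); the checkpoint is illustrative.

```python
import torch

from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks.layerwise_casting import apply_layerwise_casting

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Weights of the supported Linear/Conv layers are stored in float8 and upcast to
# bfloat16 on the fly; layers matching DEFAULT_SKIP_MODULES_PATTERN (norms,
# embeddings, proj_in/proj_out) keep their original dtype.
apply_layerwise_casting(
    transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)
```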
22 changes: 12 additions & 10 deletions src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -21,17 +21,19 @@
from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from ._common import (
_ATTENTION_CLASSES,
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
)
from .hooks import HookRegistry, ModelHook


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")


@dataclass
@@ -61,11 +63,11 @@ class PyramidAttentionBroadcastConfig:
cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the cross-attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
spatial_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
temporal_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
cross_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
"""

@@ -77,9 +79,9 @@ class PyramidAttentionBroadcastConfig:
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)

spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS

current_timestep_callback: Callable[[], int] = None

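With the local identifier constants removed, the config defaults now come from `hooks/_common.py`. A small sketch to make the behavior change visible; the field values are illustrative and `current_timestep_callback` would normally be wired to a pipeline.

```python
from diffusers.hooks import PyramidAttentionBroadcastConfig
from diffusers.hooks._common import _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: 0,  # normally e.g. `lambda: pipe.current_timestep`
)

# The defaults are now the shared transformer-block identifiers, which also cover
# models whose blocks live under a plain "layers" attribute.
assert config.spatial_attention_block_identifiers == _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
```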
6 changes: 3 additions & 3 deletions src/diffusers/utils/testing_utils.py
@@ -1394,9 +1394,9 @@ def get_device_properties() -> DeviceProperties:
DevicePropertiesUserDict = UserDict

if is_torch_available():
from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.group_offloading import (
_GROUP_ID_LAZY_LEAF,
_SUPPORTED_PYTORCH_LAYERS,
_compute_group_hash,
_find_parent_module_in_module_dict,
_gather_buffers_with_no_group_offloading_parent,
@@ -1440,13 +1440,13 @@ def get_hashed_filename(group_id: str) -> str:
elif offload_type == "leaf_level":
# Handle leaf-level module groups
for name, submodule in module.named_modules():
if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
# These groups will always have parameters, so a file is expected
expected_files.add(get_hashed_filename(name))

# Handle groups for non-leaf parameters/buffers
modules_with_group_offloading = {
name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS)
name for name, sm in module.named_modules() if isinstance(sm, _GO_LC_SUPPORTED_PYTORCH_LAYERS)
}
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
9 changes: 5 additions & 4 deletions tests/lora/utils.py
@@ -2109,14 +2109,15 @@ def test_correct_lora_configs_with_different_ranks(self):
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))

def test_layerwise_casting_inference_denoiser(self):
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN

def check_linear_dtype(module, storage_dtype, compute_dtype):
patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
for name, submodule in module.named_modules():
if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
dtype_to_check = storage_dtype
if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check):
@@ -2167,10 +2168,10 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self):
See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
"""

from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.layerwise_casting import (
_PEFT_AUTOCAST_DISABLE_HOOK,
DEFAULT_SKIP_MODULES_PATTERN,
SUPPORTED_PYTORCH_LAYERS,
apply_layerwise_casting,
)

@@ -2180,7 +2181,7 @@ def check_module(denoiser):
def check_module(denoiser):
# This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
for name, module in denoiser.named_modules():
if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
if not isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
dtype_to_check = storage_dtype
if any(re.search(pattern, name) for pattern in patterns_to_check):
5 changes: 3 additions & 2 deletions tests/models/test_modeling_common.py
@@ -1530,7 +1530,8 @@ def test_fn(storage_dtype, compute_dtype):

@torch.no_grad()
def test_layerwise_casting_inference(self):
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN

torch.manual_seed(0)
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1544,7 +1545,7 @@ def check_linear_dtype(module, storage_dtype, compute_dtype):
if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
for name, submodule in module.named_modules():
if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
dtype_to_check = storage_dtype
if any(re.search(pattern, name) for pattern in patterns_to_check):