diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py
new file mode 100644
index 000000000000..53ceae971196
--- /dev/null
+++ b/src/diffusers/models/hooks.py
@@ -0,0 +1,227 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+from typing import Any, Dict, Tuple
+
+import torch
+
+
+# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py
+class ModelHook:
+    r"""
+    A hook that contains callbacks to be executed just before and after the forward method of a model.
+    """
+
+    _is_stateful = False
+
+    def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+        r"""
+        Hook that is executed when a model is initialized.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module attached to this hook.
+        """
+        return module
+
+    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
+        r"""
+        Hook that is executed just before the forward method of the model.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module whose forward pass will be executed just after this event.
+            args (`Tuple[Any]`):
+                The positional arguments passed to the module.
+            kwargs (`Dict[str, Any]`):
+                The keyword arguments passed to the module.
+        Returns:
+            `Tuple[Tuple[Any], Dict[str, Any]]`:
+                A tuple with the processed `args` and `kwargs`.
+        """
+        return args, kwargs
+
+    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
+        r"""
+        Hook that is executed just after the forward method of the model.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module whose forward pass has been executed just before this event.
+            output (`Any`):
+                The output of the module.
+        Returns:
+            `Any`: The processed `output`.
+        """
+        return output
+
+    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+        r"""
+        Hook that is executed when the hook is detached from a module.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module detached from this hook.
+ """ + return module + + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: + if self._is_stateful: + raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") + return module + + +class SequentialHook(ModelHook): + r"""A hook that can contain several hooks and iterates through them at each event.""" + + def __init__(self, *hooks): + self.hooks = hooks + + def init_hook(self, module): + for hook in self.hooks: + module = hook.init_hook(module) + return module + + def pre_forward(self, module, *args, **kwargs): + for hook in self.hooks: + args, kwargs = hook.pre_forward(module, *args, **kwargs) + return args, kwargs + + def post_forward(self, module, output): + for hook in self.hooks: + output = hook.post_forward(module, output) + return output + + def detach_hook(self, module): + for hook in self.hooks: + module = hook.detach_hook(module) + return module + + def reset_state(self, module): + for hook in self.hooks: + if hook._is_stateful: + hook.reset_state(module) + return module + + +def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False) -> torch.nn.Module: + r""" + Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove + this behavior and restore the original `forward` method, use `remove_hook_from_module`. + + + + If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks + together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class. + + + + Args: + module (`torch.nn.Module`): + The module to attach a hook to. + hook (`ModelHook`): + The hook to attach. + append (`bool`, *optional*, defaults to `False`): + Whether the hook should be chained with an existing one (if module already contains a hook) or not. + Returns: + `torch.nn.Module`: + The same module, with the hook attached (the module is modified in place, so the result can be discarded). + """ + original_hook = hook + + if append and getattr(module, "_diffusers_hook", None) is not None: + old_hook = module._diffusers_hook + remove_hook_from_module(module) + hook = SequentialHook(old_hook, hook) + + if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"): + # If we already put some hook on this module, we replace it with the new one. + old_forward = module._old_forward + else: + old_forward = module.forward + module._old_forward = old_forward + + module = hook.init_hook(module) + module._diffusers_hook = hook + + if hasattr(original_hook, "new_forward"): + new_forward = original_hook.new_forward + else: + + def new_forward(module, *args, **kwargs): + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + output = module._old_forward(*args, **kwargs) + return module._diffusers_hook.post_forward(module, output) + + # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. 
+    # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
+    if "GraphModuleImpl" in str(type(module)):
+        module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
+    else:
+        module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
+
+    return module
+
+
+def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module:
+    """
+    Removes any hook attached to a module via `add_hook_to_module`.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module to remove the hook from.
+        recurse (`bool`, defaults to `False`):
+            Whether to remove the hooks recursively.
+    Returns:
+        `torch.nn.Module`:
+            The same module, with the hook detached (the module is modified in place, so the result can be
+            discarded).
+    """
+
+    if hasattr(module, "_diffusers_hook"):
+        module._diffusers_hook.detach_hook(module)
+        delattr(module, "_diffusers_hook")
+
+    if hasattr(module, "_old_forward"):
+        # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
+        # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
+        if "GraphModuleImpl" in str(type(module)):
+            module.__class__.forward = module._old_forward
+        else:
+            module.forward = module._old_forward
+        delattr(module, "_old_forward")
+
+    if recurse:
+        for child in module.children():
+            remove_hook_from_module(child, recurse)
+
+    return module
+
+
+def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False):
+    """
+    Resets the state of all stateful hooks attached to a module.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module to reset the stateful hooks from.
+        recurse (`bool`, defaults to `False`):
+            Whether to reset the stateful hooks recursively.
+    """
+    if hasattr(module, "_diffusers_hook") and (
+        module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook)
+    ):
+        module._diffusers_hook.reset_state(module)
+
+    if recurse:
+        for child in module.children():
+            reset_stateful_hooks(child, recurse)
diff --git a/src/diffusers/pipelines/teacache_utils.py b/src/diffusers/pipelines/teacache_utils.py
new file mode 100644
index 000000000000..876e5aab1b53
--- /dev/null
+++ b/src/diffusers/pipelines/teacache_utils.py
@@ -0,0 +1,252 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ..models import (
+    FluxTransformer2DModel,
+    HunyuanVideoTransformer3DModel,
+    LTXVideoTransformer3DModel,
+    LuminaNextDiT2DModel,
+    MochiTransformer3DModel,
+)
+from ..models.hooks import ModelHook, add_hook_to_module
+from ..utils import logging
+from .pipeline_utils import DiffusionPipeline
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+# Source: https://github.com/ali-vilab/TeaCache
+# TODO(aryan): Implement methods to calibrate and compute polynomial coefficients on-the-fly, and export to file for re-use.
+# fmt: off
+_MODEL_TO_POLY_COEFFICIENTS = {
+    FluxTransformer2DModel: [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01],
+    HunyuanVideoTransformer3DModel: [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02],
+    LTXVideoTransformer3DModel: [2.14700694e01, -1.28016453e01, 2.31279151e00, 7.92487521e-01, 9.69274326e-03],
+    LuminaNextDiT2DModel: [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344],
+    MochiTransformer3DModel: [-3.51241319e03, 8.11675948e02, -6.09400215e01, 2.42429681e00, 3.05291719e-03],
+}
+# fmt: on
+
+_MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD = {
+    FluxTransformer2DModel: 0.25,
+    HunyuanVideoTransformer3DModel: 0.1,
+    LTXVideoTransformer3DModel: 0.05,
+    LuminaNextDiT2DModel: 0.2,
+    MochiTransformer3DModel: 0.06,
+}
+
+_MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER = {
+    FluxTransformer2DModel: "transformer_blocks.0.norm1",
+}
+
+_MODEL_TO_SKIP_END_LAYER_IDENTIFIER = {
+    FluxTransformer2DModel: "norm_out",
+}
+
+_DEFAULT_SKIP_LAYER_IDENTIFIERS = [
+    "blocks",
+    "transformer_blocks",
+    "single_transformer_blocks",
+    "temporal_transformer_blocks",
+]
+
+
+@dataclass
+class TeaCacheConfig:
+    l1_threshold: Optional[float] = None
+    timestep_modulated_layer_identifier: Optional[str] = None
+    # A mutable default would raise a ValueError at class-definition time, so use a default factory instead.
+    skip_layer_identifiers: List[str] = field(default_factory=lambda: list(_DEFAULT_SKIP_LAYER_IDENTIFIERS))
+
+    _polynomial_coefficients: Optional[List[float]] = None
+
+
+class TeaCacheDenoiserState:
+    def __init__(self):
+        self.iteration: int = 0
+        self.accumulated_l1_difference: float = 0.0
+        self.timestep_modulated_cache: torch.Tensor = None
+        self.residual_cache: torch.Tensor = None
+        self.should_skip_blocks: bool = False
+
+    def reset(self):
+        self.iteration = 0
+        self.accumulated_l1_difference = 0.0
+        self.timestep_modulated_cache = None
+        self.residual_cache = None
+        self.should_skip_blocks = False
+
+
+def apply_teacache(
+    pipeline: DiffusionPipeline, config: Optional[TeaCacheConfig] = None, denoiser: Optional[nn.Module] = None
+) -> None:
+    r"""Applies [TeaCache](https://huggingface.co/papers/2411.19108) to a given pipeline or denoiser module.
+
+    Args:
+        pipeline (`DiffusionPipeline`):
+            The pipeline containing the denoiser to which TeaCache should be applied.
+        config (`TeaCacheConfig`, *optional*):
+            The TeaCache configuration. If not provided, a default configuration is used.
+        denoiser (`nn.Module`, *optional*):
+            The denoiser module to apply TeaCache to. If not provided, the pipeline's `transformer` (or `unet`) is
+            used.
+    """
+
+    if config is None:
+        logger.warning("No TeaCacheConfig provided. Using default configuration.")
+        config = TeaCacheConfig()
+
+    if denoiser is None:
+        denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
+
+    if isinstance(denoiser, tuple(_MODEL_TO_POLY_COEFFICIENTS.keys())):
+        if config.l1_threshold is None:
+            logger.info(
+                f"No L1 threshold was provided for {type(denoiser)}. Using the default threshold from the TeaCache "
+                f"paper for a 1.5x speedup. For a higher speedup, increase the threshold."
+            )
+            config.l1_threshold = _MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD[type(denoiser)]
+        if config.timestep_modulated_layer_identifier is None:
+            logger.info(
+                f"No timestep modulated layer identifier was provided for {type(denoiser)}. Using the default "
+                f"identifier from the TeaCache paper."
+            )
+            config.timestep_modulated_layer_identifier = _MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER[type(denoiser)]
+        if config._polynomial_coefficients is None:
+            logger.info(
+                f"No polynomial coefficients were provided for {type(denoiser)}. Using the default coefficients from "
+                f"the TeaCache paper."
+            )
+            config._polynomial_coefficients = _MODEL_TO_POLY_COEFFICIENTS[type(denoiser)]
+    else:
+        if config.l1_threshold is None:
+            raise ValueError(
+                f"No default TeaCache L1 threshold is available for {type(denoiser)} in Diffusers. Please provide "
+                f"one in the config by setting the `l1_threshold` attribute."
+            )
+        if config.timestep_modulated_layer_identifier is None:
+            raise ValueError(
+                f"No default timestep modulated layer identifier is available for {type(denoiser)} in Diffusers. "
+                f"Please provide one in the config by setting the `timestep_modulated_layer_identifier` attribute."
+            )
+        if config._polynomial_coefficients is None:
+            raise ValueError(
+                f"No default polynomial coefficients are available for {type(denoiser)} in Diffusers. Please provide "
+                f"them in the config by setting the `_polynomial_coefficients` attribute. Automatic calibration will "
+                f"be implemented in the future."
+            )
+
+    timestep_modulated_layer_matches = list(
+        {
+            module
+            for name, module in denoiser.named_modules()
+            if re.match(config.timestep_modulated_layer_identifier, name)
+        }
+    )
+
+    if len(timestep_modulated_layer_matches) == 0:
+        raise ValueError(
+            f"No layer in the denoiser module matched the provided timestep modulated layer identifier: "
+            f"{config.timestep_modulated_layer_identifier}. Please provide a valid layer identifier."
+        )
+    if len(timestep_modulated_layer_matches) > 1:
+        logger.warning(
+            f"Multiple layers in the denoiser module matched the provided timestep modulated layer identifier: "
+            f"{config.timestep_modulated_layer_identifier}. Using the first match."
+        )
+
+    denoiser_state = TeaCacheDenoiserState()
+
+    timestep_modulated_layer = timestep_modulated_layer_matches[0]
+    hook = TimestepModulatedOutputCacheHook(denoiser_state, config.l1_threshold, config._polynomial_coefficients)
+    add_hook_to_module(timestep_modulated_layer, hook, append=True)
+
+    skip_layer_identifiers = config.skip_layer_identifiers
+    skip_layer_matches = list(
+        {
+            module
+            for name, module in denoiser.named_modules()
+            if any(re.match(identifier, name) for identifier in skip_layer_identifiers)
+        }
+    )
+
+    for skip_layer in skip_layer_matches:
+        hook = DenoiserStateBasedSkipLayerHook(denoiser_state)
+        add_hook_to_module(skip_layer, hook, append=True)
+
+
+class TimestepModulatedOutputCacheHook(ModelHook):
+    # The denoiser hook will reset its state, so we don't have to handle it here.
+    _is_stateful = False
+
+    def __init__(
+        self,
+        denoiser_state: TeaCacheDenoiserState,
+        l1_threshold: float,
+        polynomial_coefficients: List[float],
+    ) -> None:
+        self.denoiser_state = denoiser_state
+        self.l1_threshold = l1_threshold
+        # TODO(aryan): implement torch equivalent
+        self.rescale_fn = np.poly1d(polynomial_coefficients)
+
+    def post_forward(self, module, output):
+        if isinstance(output, tuple):
+            # This assumes that the first element of the output tuple is the timestep modulated noise output.
+            # For Diffusers models, this is true. For models outside Diffusers, users will have to ensure
+            # that the first element of the output tuple is the timestep modulated noise output (this seems to
+            # be the case for most research model implementations).
+            timestep_modulated_noise = output[0]
+        elif torch.is_tensor(output):
+            timestep_modulated_noise = output
+        else:
+            raise ValueError(
+                f"Expected output to be a tensor or a tuple whose first element is the timestep modulated noise. "
+                f"Got {type(output)} instead. Please ensure that the denoiser module returns the timestep "
+                f"modulated noise output as the first element."
+            )
+
+        if self.denoiser_state.timestep_modulated_cache is not None:
+            l1_diff = (timestep_modulated_noise - self.denoiser_state.timestep_modulated_cache).abs().mean()
+            normalized_l1_diff = l1_diff / self.denoiser_state.timestep_modulated_cache.abs().mean()
+            # np.poly1d cannot operate on (possibly CUDA) tensors, so convert to a Python float first.
+            rescaled_l1_diff = self.rescale_fn(normalized_l1_diff.cpu().item())
+            self.denoiser_state.accumulated_l1_difference += rescaled_l1_diff
+
+            # Following TeaCache, the blocks are recomputed once the accumulated difference crosses the threshold;
+            # otherwise they are skipped and the cached output from an earlier timestep is reused.
+            if self.denoiser_state.accumulated_l1_difference >= self.l1_threshold:
+                self.denoiser_state.should_skip_blocks = False
+                self.denoiser_state.accumulated_l1_difference = 0.0
+            else:
+                self.denoiser_state.should_skip_blocks = True
+
+        self.denoiser_state.timestep_modulated_cache = timestep_modulated_noise
+        return output
+
+
+class DenoiserStateBasedSkipLayerHook(ModelHook):
+    _is_stateful = False
+
+    def __init__(self, denoiser_state: TeaCacheDenoiserState) -> None:
+        self.denoiser_state = denoiser_state
+
+    def new_forward(self, module, *args, **kwargs):
+        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
+
+        if not self.denoiser_state.should_skip_blocks:
+            output = module._old_forward(*args, **kwargs)
+        else:
+            # Diffusers models either expect one output (hidden_states) or a tuple of two outputs (hidden_states,
+            # encoder_hidden_states). Returning a tuple of None values handles both cases. This is okay because the
+            # outputs are not used anywhere when self.denoiser_state.should_skip_blocks is True.
+            output = (None, None)
+
+        return module._diffusers_hook.post_forward(module, output)
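For reviewers, here is a minimal sketch of how the new hook primitives in `hooks.py` compose. The `ShapeLoggingHook` class and the `torch.nn.Linear` target are illustrative only and not part of this PR; the sketch assumes the file is importable as `diffusers.models.hooks` once the PR lands.

```python
import torch

from diffusers.models.hooks import ModelHook, add_hook_to_module, remove_hook_from_module


class ShapeLoggingHook(ModelHook):
    # Hypothetical hook: log tensor shapes around the wrapped forward call.
    def pre_forward(self, module, *args, **kwargs):
        if args and torch.is_tensor(args[0]):
            print(f"{module.__class__.__name__} input: {tuple(args[0].shape)}")
        return args, kwargs

    def post_forward(self, module, output):
        if torch.is_tensor(output):
            print(f"{module.__class__.__name__} output: {tuple(output.shape)}")
        return output


layer = torch.nn.Linear(8, 4)
add_hook_to_module(layer, ShapeLoggingHook())  # rewrites layer.forward to call the hook
_ = layer(torch.randn(2, 8))                   # prints input/output shapes
remove_hook_from_module(layer)                 # restores the original forward
```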
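And a sketch of the intended user-facing call, assuming the new file is importable as `diffusers.pipelines.teacache_utils` (whether it is re-exported at a higher level is not shown in this diff). The checkpoint id is only an example; the threshold mirrors the default 1.5x-speedup value for Flux from the table above.

```python
import torch

from diffusers import FluxPipeline
from diffusers.pipelines.teacache_utils import TeaCacheConfig, apply_teacache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Attach the TeaCache hooks to pipe.transformer; omitting the config falls back to the defaults above.
apply_teacache(pipe, TeaCacheConfig(l1_threshold=0.25))

image = pipe("A photo of a cat holding a sign that says hello world", num_inference_steps=28).images[0]
image.save("teacache_flux.png")
```

Because the hooks are attached with `append=True`, TeaCache can coexist with hooks that are already present on the matched blocks.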