
[core] FasterCache #10163

Merged: 34 commits from fastercache into main on Mar 21, 2025

Conversation

@a-r-r-o-w (Member) commented Dec 9, 2024

Fixes #10128.

| model_id | cache_method | time | model_memory | model_max_memory_reserved | inference_memory | inference_max_memory_reserved |
|---|---|---|---|---|---|---|
| latte | none | 31.742 | 12.88 | 13.189 | 12.888 | 14.344 |
| latte | fastercache | 24.288 | 12.882 | 13.17 | 12.891 | 22.199 |
| latte | pyramid_attention_broadcast | 27.531 | 12.882 | 13.17 | 12.891 | 20.23 |
| cogvideox-1.0 | none | 245.939 | 19.66 | 19.678 | 19.671 | 24.426 |
| cogvideox-1.0 | fastercache | 159.031 | 19.661 | 19.678 | 19.672 | 40.721 |
| cogvideox-1.0 | pyramid_attention_broadcast | 184.013 | 19.661 | 19.678 | 19.672 | 32.57 |
| mochi | none | 437.327 | 28.411 | 28.65 | 28.421 | 36.062 |
| mochi | fastercache | 358.871 | 28.411 | 28.648 | 28.422 | 36.062 |
| mochi | pyramid_attention_broadcast | 324.051 | 28.411 | 28.65 | 28.421 | 52.088 |
| hunyuan_video | none | 72.628 | 38.577 | 38.672 | 38.587 | 41.141 |
| hunyuan_video | fastercache | 47.35 | 38.579 | 38.658 | 38.588 | 49.102 |
| hunyuan_video | pyramid_attention_broadcast | 63.892 | 38.578 | 38.672 | 38.587 | 44.785 |
| flux | none | 16.802 | 31.44 | 31.451 | 31.448 | 32.023 |
| flux | fastercache | 12.259 | 31.44 | 31.451 | 31.448 | 33.818 |
| flux | pyramid_attention_broadcast | 13.719 | 31.439 | 31.451 | 31.447 | 32.832 |

Time is in seconds; memory columns are in GiB (values rounded to three decimals by the benchmark script below).

Flux visual results (Normal vs. FasterCache)

HunyuanVideo visual results (Normal vs. FasterCache)

  • hunyuan_video---dtype-bf16---cache_method-none---compile-False.mp4
  • hunyuan_video---dtype-bf16---cache_method-fastercache---compile-False.mp4

Latte visual results (Normal vs. FasterCache)

  • latte---dtype-fp16---cache_method-none---compile-False.mp4
  • latte---dtype-fp16---cache_method-fastercache---compile-False.mp4

CogVideoX visual results (Normal vs. FasterCache)

  • cogvideox-1.0---dtype-bf16---cache_method-none---compile-False.mp4
  • cogvideox-1.0---dtype-bf16---cache_method-fastercache---compile-False.mp4
Mochi visual results (Normal vs. FasterCache)

Note: I have yet to find the optimal inference parameters for Mochi that minimize the quality difference. I'll try to write a blog post on how this can be done.

  • mochi---dtype-bf16---cache_method-none---compile-False.mp4
  • mochi---dtype-bf16---cache_method-fastercache---compile-False.mp4

Important

This implementation differs from the original one and, I believe, is truer to what's described in the paper; the original codebase deviates from the paper in a few places.

See Vchitect/FasterCache#13 for more details.

TL;DR: The original implementation approximates the conditional branch outputs with the inference results from the unconditional branch. This is the opposite of what's described in the "CFG Cache" section of the paper: the conditional predictions should be used to approximate the outputs of the unconditional branch (a minimal sketch of this direction follows).
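
A minimal sketch of just this approximation direction, assuming a generic `run_branch` callable and a plain cached delta (the paper additionally splits the cached delta into low- and high-frequency components); the names here are illustrative, not the diffusers API:

```python
def cfg_cache_step(run_branch, cond_in, uncond_in, cache, compute_both):
    # The conditional branch is always computed.
    cond_out = run_branch(cond_in)
    if compute_both:
        # Full step: also run the unconditional branch and cache the difference.
        uncond_out = run_branch(uncond_in)
        cache["delta"] = (uncond_out - cond_out).detach()
    else:
        # Cached step: approximate the *unconditional* output from the *conditional* one,
        # which is the direction described in the paper's "CFG Cache" section.
        uncond_out = cond_out + cache["delta"]
    return cond_out, uncond_out
```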

I tested hook-based implementations of both the original approach and our current version separately. Visually, I think our current implementation produces better results that are more closely aligned with the video generated without FasterCache applied.

Note

FasterCache, in its complete form, requires both an unconditional and a conditional batch to approximate video generation. Models like Flux and HunyuanVideo are therefore not fully compatible with it out of the box, since they are guidance-distilled. Broadly, there are two approximations at play (a rough sketch of the first follows the list):

  • Approximating attention outputs from previous computations
  • Approximating unconditional-batch outputs from an inference run on the conditional batch

For guidance-distilled models, only the attention-approximation part is used.
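
As a rough sketch of the first approximation (an illustrative class, not the hook implementation added in this PR), the attention outputs cached at the two most recently computed steps can be linearly extrapolated to stand in for the current step:

```python
import torch

class AttentionOutputCache:
    """Illustrative cache for FasterCache-style attention-output reuse (not the PR's API)."""

    def __init__(self, weight: float = 0.5):
        self.prev = None       # attention output at the most recent computed step
        self.prev_prev = None  # attention output at the computed step before that
        self.weight = weight   # how strongly to trust the step-to-step trend

    def store(self, out: torch.Tensor) -> None:
        # Call on steps where the attention block actually runs.
        self.prev_prev, self.prev = self.prev, out.detach()

    def approximate(self) -> torch.Tensor:
        # Call on skipped steps: out_t ~= out_{t-1} + w * (out_{t-1} - out_{t-2})
        return self.prev + self.weight * (self.prev - self.prev_prev)
```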

We currently need to add a reset_stateful_hooks() call to every pipeline for FasterCache to work correctly, which is not ideal. We should also support "pipeline hooks" as counterparts to "model hooks", which would allow users to pre/post-hook into all pipeline methods like encode_prompt, prepare_latents, and __call__. I have a prototype implementation ready to demonstrate how this can be done (a rough sketch of the idea is below). Ideally, we want to be able to target __call__ so that the hook can trigger the state reset.
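
A rough sketch of the pipeline-hook idea, with purely illustrative names (this is not the prototype mentioned above nor a merged API):

```python
import functools

class PipelineHook:
    """Counterpart of a model hook, but attached to pipeline-level methods."""

    def pre_call(self, pipeline, *args, **kwargs):
        return args, kwargs

    def post_call(self, pipeline, output):
        return output

def register_pipeline_hook(pipeline, method_name: str, hook: PipelineHook) -> None:
    # Caveat: dunder methods such as __call__ are looked up on the type, not the instance,
    # so targeting __call__ would need class-level (or per-instance subclass) patching.
    original = getattr(pipeline, method_name)

    @functools.wraps(original)
    def wrapped(*args, **kwargs):
        args, kwargs = hook.pre_call(pipeline, *args, **kwargs)
        output = original(*args, **kwargs)
        return hook.post_call(pipeline, output)

    setattr(pipeline, method_name, wrapped)

class ResetCacheStateHook(PipelineHook):
    def post_call(self, pipeline, output):
        # Walk the denoiser's modules and reset any stateful cache here, instead of
        # calling reset_stateful_hooks() inside every pipeline's __call__.
        return output
```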

Code
import argparse
import gc
import pathlib
import traceback

import git
import pandas as pd
import torch
from diffusers import (
    AllegroPipeline,
    CogVideoXPipeline,
    FluxPipeline,
    HunyuanVideoPipeline,
    LattePipeline,
    MochiPipeline,
)
from diffusers.models import HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_info, set_verbosity_debug
from tabulate import tabulate


repo = git.Repo(path="/home/aryan/work/diffusers")
branch = repo.active_branch

if branch.name in ["pyramid-attention-broadcast", "pyramid-attention-rewrite-2"]:
    from diffusers.pipelines.pyramid_attention_broadcast_utils import (
        apply_pyramid_attention_broadcast,
        PyramidAttentionBroadcastConfig,
    )
elif branch.name in ["fastercache"]:
    from diffusers.pipelines.fastercache_utils import apply_fastercache, FasterCacheConfig


def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output


def prepare_allegro(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "rhymes-ai/Allegro"
    cache_dir = None

    pipe = AllegroPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this location might be a popular spot for docking fishing boats.",
        "height": 720,
        "width": 1280,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_cogvideox_1_0(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "THUDM/CogVideoX-5b"
    cache_dir = None

    pipe = CogVideoXPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": (
            "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
            "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
            "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
            "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
            "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
            "atmosphere of this unique musical performance."
        ),
        "height": 480,
        "width": 720,
        "num_frames": 49,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_flux(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "black-forest-labs/Flux.1-Dev"
    cache_dir = "/raid/.cache/huggingface"

    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 768,
        "width": 768,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_hunyuan_video(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "hunyuanvideo-community/HunyuanVideo"
    cache_dir = None

    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=torch.bfloat16
    )
    pipe = HunyuanVideoPipeline.from_pretrained(
        model_id, transformer=transformer, torch_dtype=torch.float16, cache_dir=cache_dir
    )
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A cat wearing sunglasses and working as a lifeguard at pool.",
        "height": 320,
        "width": 512,
        "num_frames": 61,
        "num_inference_steps": 30,
    }

    return pipe, generation_kwargs


def prepare_latte(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "maxin-cn/Latte-1"
    cache_dir = None

    pipe = LattePipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "a cat wearing sunglasses and working as a lifeguard at pool.",
        "height": 512,
        "width": 512,
        "video_length": 16,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_mochi(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "genmo/mochi-1-preview"
    cache_dir = None

    pipe = MochiPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
        "height": 480,
        "width": 848,
        "num_frames": 85,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_allegro_config(cache_method: str):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            cross_attention_block_skip_range=6,
            spatial_attention_timestep_skip_range=(100, 700),
            cross_attention_timestep_skip_range=(100, 800),
            spatial_attention_block_identifiers=["transformer_blocks"],
            cross_attention_block_identifiers=["transformer_blocks"],
        )
    elif cache_method == "fastercache":
        return FasterCacheConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(-1, 681),
            low_frequency_weight_update_timestep_range=(99, 641),
            high_frequency_weight_update_timestep_range=(-1, 301),
            spatial_attention_block_identifiers=["transformer_blocks"],
        )
    elif cache_method == "none":
        return None


def prepare_cogvideox_1_0_config(cache_method: str):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(100, 800),
            spatial_attention_block_identifiers=["transformer_blocks"],
        )
    elif cache_method == "fastercache":
        return FasterCacheConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(-1, 681),
            low_frequency_weight_update_timestep_range=(99, 641),
            high_frequency_weight_update_timestep_range=(-1, 301),
            spatial_attention_block_identifiers=["transformer_blocks"],
            attention_weight_callback=lambda _: 0.3,
            tensor_format="BFCHW",
        )
    elif cache_method == "none":
        return None


def prepare_flux_config(cache_method: str):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(100, 950),
            spatial_attention_block_identifiers=["transformer_blocks", "single_transformer_blocks"],
        )
    elif cache_method == "fastercache":
        return FasterCacheConfig(
            spatial_attention_block_skip_range=4,
            spatial_attention_timestep_skip_range=(-1, 961),
            spatial_attention_block_identifiers=["transformer_blocks", "single_transformer_blocks"],
            tensor_format="BCHW",
        )
    elif cache_method == "none":
        return None


def prepare_hunyuan_video_config(cache_method: str):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(100, 800),
            spatial_attention_block_identifiers=["transformer_blocks", "single_transformer_blocks"],
        )
    elif cache_method == "fastercache":
        return FasterCacheConfig(
            spatial_attention_block_skip_range=4,
            spatial_attention_timestep_skip_range=(99, 941),
            spatial_attention_block_identifiers=["transformer_blocks", "single_transformer_blocks"],
            tensor_format="BCFHW",
        )
    elif cache_method == "none":
        return None


def prepare_latte_config(cache_method: str):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            temporal_attention_block_skip_range=3,
            cross_attention_block_skip_range=6,
            spatial_attention_timestep_skip_range=(100, 700),
            temporal_attention_timestep_skip_range=(100, 800),
            cross_attention_timestep_skip_range=(100, 800),
            spatial_attention_block_identifiers=["transformer_blocks"],
            temporal_attention_block_identifiers=["temporal_transformer_blocks"],
            cross_attention_block_identifiers=["transformer_blocks"],
        )
    elif cache_method == "fastercache":
        return FasterCacheConfig(
            spatial_attention_block_skip_range=2,
            temporal_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(-1, 681),
            temporal_attention_timestep_skip_range=(-1, 681),
            low_frequency_weight_update_timestep_range=(99, 641),
            high_frequency_weight_update_timestep_range=(-1, 301),
            spatial_attention_block_identifiers=["transformer_blocks.*attn1"],
            temporal_attention_block_identifiers=["temporal_transformer_blocks"],
        )
    elif cache_method == "none":
        return None


def prepare_mochi_config(cache_method: str):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(400, 987),
            spatial_attention_block_identifiers=["transformer_blocks"],
        )
    elif cache_method == "fastercache":
        return FasterCacheConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(-1, 981),
            low_frequency_weight_update_timestep_range=(301, 961),
            high_frequency_weight_update_timestep_range=(-1, 851),
            unconditional_batch_skip_range=4,
            unconditional_batch_timestep_skip_range=(-1, 975),
            spatial_attention_block_identifiers=["transformer_blocks"],
            attention_weight_callback=lambda _: 0.6,
        )
    elif cache_method == "none":
        return None


def decode_allegro(pipe: AllegroPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_cogvideox_1_0(pipe: CogVideoXPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_flux(pipe: FluxPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    height = kwargs["height"]
    width = kwargs["width"]
    filename = f"{filename.as_posix()}.png"
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(filename)
    return filename


def decode_hunyuan_video(pipe: HunyuanVideoPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_latte(pipe: LattePipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents, video_length=kwargs["video_length"])
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_mochi(pipe: MochiPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
    latents = latents * latents_std / pipe.vae.config.scaling_factor + latents_mean
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


MODEL_MAPPING = {
    "allegro": {
        "prepare": prepare_allegro,
        "config": prepare_allegro_config,
        "decode": decode_allegro,
    },
    "cogvideox-1.0": {
        "prepare": prepare_cogvideox_1_0,
        "config": prepare_cogvideox_1_0_config,
        "decode": decode_cogvideox_1_0,
    },
    "flux": {
        "prepare": prepare_flux,
        "config": prepare_flux_config,
        "decode": decode_flux,
    },
    "hunyuan_video": {
        "prepare": prepare_hunyuan_video,
        "config": prepare_hunyuan_video_config,
        "decode": decode_hunyuan_video,
    },
    "latte": {
        "prepare": prepare_latte,
        "config": prepare_latte_config,
        "decode": decode_latte,
    },
    "mochi": {
        "prepare": prepare_mochi,
        "config": prepare_mochi_config,
        "decode": decode_mochi,
    },
}

STR_TO_COMPUTE_DTYPE = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def run_inference(pipe, generation_kwargs):
    generator = torch.Generator("cuda").manual_seed(181201)
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output


@torch.no_grad()
def main(model_id: str, cache_method: str, output_dir: str, dtype: str, compile: bool = False):
    if model_id not in MODEL_MAPPING.keys():
        raise ValueError("Unsupported `model_id` specified.")

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_filename = output_dir / f"{model_id}.csv"

    compute_dtype = STR_TO_COMPUTE_DTYPE[dtype]
    model = MODEL_MAPPING[model_id]

    try:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

        # 1. Prepare inputs and generation kwargs
        pipe, generation_kwargs = model["prepare"](dtype=compute_dtype, compile=compile)

        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        model_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 2. Apply attention approximation technique
        config = model["config"](cache_method)
        if cache_method == "pyramid_attention_broadcast":
            apply_pyramid_attention_broadcast(pipe, config)
        elif cache_method == "fastercache":
            apply_fastercache(pipe, config)
        elif cache_method == "none":
            pass
        else:
            raise ValueError(f"Invalid {cache_method=} provided.")

        # 3. Warmup
        num_warmups = 1
        original_num_inference_steps = generation_kwargs["num_inference_steps"]
        generation_kwargs["num_inference_steps"] = 2
        for _ in range(num_warmups):
            run_inference(pipe, generation_kwargs)
        generation_kwargs["num_inference_steps"] = original_num_inference_steps

        # 4. Benchmark
        time, latents = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        inference_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 5. Decode latents
        filename = output_dir / f"{model_id}---dtype-{dtype}---cache_method-{cache_method}---compile-{compile}"
        filename = model["decode"](
            pipe,
            latents,
            filename,
            height=generation_kwargs["height"],
            width=generation_kwargs["width"],
            video_length=generation_kwargs.get("video_length", None),
        )

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "cache_method": cache_method,
            "compute_dtype": dtype,
            "compile": compile,
            "time": time,
            "model_memory": model_memory,
            "model_max_memory_reserved": model_max_memory_reserved,
            "inference_memory": inference_memory,
            "inference_max_memory_reserved": inference_max_memory_reserved,
            "branch": branch,
            "filename": filename,
            "exception": None,
        }

    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "cache_method": cache_method,
            "compute_dtype": dtype,
            "compile": compile,
            "time": None,
            "model_memory": None,
            "model_max_memory_reserved": None,
            "inference_memory": None,
            "inference_max_memory_reserved": None,
            "branch": branch,
            "filename": None,
            "exception": str(e),
        }

    pretty_print_results(info, precision=3)

    df = pd.DataFrame([info])
    df.to_csv(csv_filename.as_posix(), mode="a", index=False, header=not csv_filename.is_file())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="flux",
        choices=["flux", "cogvideox-1.0", "latte", "allegro", "hunyuan_video", "mochi"],
        help="Model to run benchmark for.",
    )
    parser.add_argument(
        "--cache_method",
        type=str,
        default="pyramid_attention_broadcast",
        choices=["pyramid_attention_broadcast", "fastercache", "none"],
        help="Cache method to use.",
    )
    parser.add_argument(
        "--output_dir", type=str, help="Path where the benchmark artifacts and outputs are the be saved."
    )
    parser.add_argument("--dtype", type=str, help="torch.dtype to use for inference")
    parser.add_argument(
        "--compile",
        action="store_true",
        default=False,
        help="Whether to torch.compile the denoiser.",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging.")
    args = parser.parse_args()

    if args.verbose:
        set_verbosity_debug()
    else:
        set_verbosity_info()

    main(args.model_id, args.cache_method, args.output_dir, args.dtype, args.compile)

TODO:

  • Docs
  • Tests

cc @cszy98 @ChenyangSi

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w added the performance label Dec 9, 2024
@a-r-r-o-w mentioned this pull request Dec 23, 2024
@sayakpaul (Member)

Have some ideas around the design. LMK when is a good time to pass those off.

@a-r-r-o-w (Member, Author)

@sayakpaul Feel free to mention them, but this is not finalized yet either.

Also, could you please not merge main into the branch unless it is ready for reviews? It causes merge conflicts unnecessarily that I don't want to deal with because I have changes locally 😭

import torch


# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py
Review comment (Member):

(nit) Since some of the utils are taken from accelerate and repurposed here, adding a note on why we're not directly importing them from accelerate would be nice.


# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/fastercache_sample_latte.py#L127C1-L143C39
@torch.no_grad()
def _fft(tensor):
Review comment (Member):

Could go to torch_utils.py similar to:

def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
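
For reference, a hedged sketch of what such a shared torch_utils helper might look like (assumed name and signature, not what ships in this PR), splitting a tensor into low- and high-frequency components with a centered mask in 2D FFT space:

```python
import torch

def split_frequency_components(x: torch.Tensor, radius: int = 5):
    # Center the 2D FFT so low frequencies sit in the middle, then keep a small
    # centered square as the low-frequency part and the rest as the high-frequency part.
    freq = torch.fft.fftshift(torch.fft.fft2(x.float()), dim=(-2, -1))
    height, width = x.shape[-2:]
    mask = torch.zeros((height, width), device=freq.device)
    ch, cw = height // 2, width // 2
    mask[ch - radius : ch + radius, cw - radius : cw + radius] = 1.0
    low_freq = freq * mask
    high_freq = freq * (1.0 - mask)
    return low_freq, high_freq
```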

@a-r-r-o-w added the roadmap label Jan 2, 2025
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@a-r-r-o-w marked this pull request as ready for review January 28, 2025 22:14
@a-r-r-o-w requested a review from DN6 January 28, 2025 22:14
@AC27MJ commented Feb 11, 2025

Hi, thanks for the great work! I’m interested in this feature and looking forward to using it. I noticed that all checks have passed, but it’s still awaiting review. Just wondering if there’s any update on the review process? Thank you. @DN6

Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label Mar 17, 2025
@vladmandic (Contributor)

What is the status of this PR?

@a-r-r-o-w (Member, Author)

@vladmandic This is ready to merge and there are no blockers. The only thing holding me back from merging is that, in light of some other caching techniques, I want to do it in a way that is agnostic to all our model implementations. The current implementation already works agnostically to some extent (I haven't tried it on all models though), as can be seen from the benchmarks in the PR description, but it relies on a variety of assumptions about what each internal transformer block returns. In Diffusers, we almost always return one of the following:

  • hidden_states
  • hidden_states, encoder_hidden_states
  • encoder_hidden_states, hidden_states

Now, depending on the transformer block implementation, we might sometimes decide to skip its execution path (as in Skip Layer Guidance, First Block Cache, etc.). We can do this in two ways: remove the block from the transformer stack, or return the input hidden_states/encoder_hidden_states as the output. The former is not possible for us and is not the route we want to take. The latter can easily be done with the "hooks" design introduced in this release schedule (a rough illustration follows below).
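
As a rough illustration of the latter approach (the exact hook interface may differ from this sketch), a hook's replacement forward can simply return the block's inputs in the order that block expects:

```python
class SkipTransformerBlockHook:
    """Illustrative "skip this block" hook: return the inputs untouched instead of running it."""

    def new_forward(self, module, hidden_states, encoder_hidden_states=None, **kwargs):
        if encoder_hidden_states is None:
            return hidden_states
        # Whether this should be (hidden_states, encoder_hidden_states) or the reverse
        # depends on the specific block implementation -- exactly the per-block metadata
        # problem described in the next paragraph.
        return hidden_states, encoder_hidden_states
```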

However, we need to maintain some metadata about what each transformer block returns, in which order it returns them, whether the inputs can be directly returned as outputs if the block is to be skipped, and so on (one possible shape for this metadata is sketched below). This information will significantly simplify adding new cache methods, since handling the different cases with a bunch of if-else statements drastically reduces debuggability/maintainability (at least for me).
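
One possible shape for that metadata, purely as an illustration (not a merged API):

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class TransformerBlockMetadata:
    # Names of the block's return values, in order: e.g. ("hidden_states",) or
    # ("encoder_hidden_states", "hidden_states").
    return_names: Tuple[str, ...]
    # Whether the inputs can be returned unchanged when the block is skipped.
    skippable_with_identity: bool = True
```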

I have a plan for how to implement this better so that I can quickly add other cache techniques, but I want to hold off on introducing changes in this release that I know I might have to break/change in the next release. I'm currently looking into a few other things as well, so this might have to wait a bit longer. I'm about ~75% confident that any changes I make will not affect the user-facing API and will not be breaking in any way, but better safe than sorry. I'll discuss with Dhruv and see what he thinks, and merge if this is okay.

@a-r-r-o-w (Member, Author)

@bot /style

Contributor

Style fixes have been applied. View the workflow run here.

@vladmandic (Contributor)

@a-r-r-o-w thanks for the detailed writeup!

github-actions bot removed the stale label Mar 19, 2025
@a-r-r-o-w (Member, Author)

Failing tests seem unrelated

@a-r-r-o-w merged commit 844221a into main Mar 21, 2025
28 of 32 checks passed
github-project-automation bot moved this from In Progress to Done in Diffusers Roadmap 0.35 Mar 21, 2025
@a-r-r-o-w deleted the fastercache branch March 21, 2025 04:05