From dabb12f23d476fcf8ccf8cb5e02cbe9663eba56e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 13 Jun 2025 22:09:37 +0300 Subject: [PATCH 1/8] first draft --- docs/source/en/_toctree.yml | 6 + .../en/api/models/autoencoder_kl_magi.md | 34 + .../en/api/models/magi_transformer_3d.md | 32 + docs/source/en/api/pipelines/magi.md | 309 ++++++++ scripts/convert_magi_to_diffusers.py | 657 ++++++++++++++++ src/diffusers/models/__init__.py | 4 + src/diffusers/models/autoencoders/__init__.py | 1 + .../autoencoders/autoencoder_kl_magi.py | 719 ++++++++++++++++++ .../models/transformers/transformer_magi.py | 668 ++++++++++++++++ src/diffusers/pipelines/__init__.py | 3 + src/diffusers/pipelines/magi/__init__.py | 51 ++ src/diffusers/pipelines/magi/pipeline_magi.py | 641 ++++++++++++++++ .../pipelines/magi/pipeline_magi_i2v.py | 546 +++++++++++++ .../pipelines/magi/pipeline_magi_v2v.py | 552 ++++++++++++++ .../pipelines/magi/pipeline_output.py | 34 + .../test_models_autoencoder_kl_magi.py | 155 ++++ .../test_models_transformer_magi.py | 91 +++ tests/pipelines/magi/test_magi.py | 158 ++++ .../magi/test_magi_image_to_video.py | 215 ++++++ .../magi/test_magi_video_to_video.py | 148 ++++ ...test_model_magi_autoencoder_single_file.py | 64 ++ ...st_model_magi_transformer3d_single_file.py | 84 ++ 22 files changed, 5172 insertions(+) create mode 100644 docs/source/en/api/models/autoencoder_kl_magi.md create mode 100644 docs/source/en/api/models/magi_transformer_3d.md create mode 100644 docs/source/en/api/pipelines/magi.md create mode 100644 scripts/convert_magi_to_diffusers.py create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_magi.py create mode 100644 src/diffusers/models/transformers/transformer_magi.py create mode 100644 src/diffusers/pipelines/magi/__init__.py create mode 100644 src/diffusers/pipelines/magi/pipeline_magi.py create mode 100644 src/diffusers/pipelines/magi/pipeline_magi_i2v.py create mode 100644 src/diffusers/pipelines/magi/pipeline_magi_v2v.py create mode 100644 src/diffusers/pipelines/magi/pipeline_output.py create mode 100644 tests/models/autoencoders/test_models_autoencoder_kl_magi.py create mode 100644 tests/models/transformers/test_models_transformer_magi.py create mode 100644 tests/pipelines/magi/test_magi.py create mode 100644 tests/pipelines/magi/test_magi_image_to_video.py create mode 100644 tests/pipelines/magi/test_magi_video_to_video.py create mode 100644 tests/single_file/test_model_magi_autoencoder_single_file.py create mode 100644 tests/single_file/test_model_magi_transformer3d_single_file.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f13b7d54aec4..1a386615f6ad 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -313,6 +313,8 @@ title: Lumina2Transformer2DModel - local: api/models/lumina_nextdit2d title: LuminaNextDiT2DModel + - local: api/models/magi_transformer_3d + title: MagiTransformer3DModel - local: api/models/mochi_transformer3d title: MochiTransformer3DModel - local: api/models/omnigen_transformer @@ -367,6 +369,8 @@ title: AutoencoderKLHunyuanVideo - local: api/models/autoencoderkl_ltx_video title: AutoencoderKLLTXVideo + - local: api/models/autoencoder_kl_magi + title: AutoencoderKLMagi - local: api/models/autoencoderkl_magvit title: AutoencoderKLMagvit - local: api/models/autoencoderkl_mochi @@ -487,6 +491,8 @@ title: Lumina 2.0 - local: api/pipelines/lumina title: Lumina-T2X + - local: api/pipelines/magi + title: MAGI-1 - local: api/pipelines/marigold title: 
Marigold - local: api/pipelines/mochi diff --git a/docs/source/en/api/models/autoencoder_kl_magi.md b/docs/source/en/api/models/autoencoder_kl_magi.md new file mode 100644 index 000000000000..cc5f16a4e713 --- /dev/null +++ b/docs/source/en/api/models/autoencoder_kl_magi.md @@ -0,0 +1,34 @@ + + +# AutoencoderKLMagi + +The 3D variational autoencoder (VAE) model with KL loss used in [MAGI-1: Autoregressive Video Generation at Scale](https://arxiv.org/abs/2505.13211) by Sand.ai. + +MAGI-1 uses a transformer-based VAE with 8x spatial and 4x temporal compression, providing fast average decoding time and highly competitive reconstruction quality. + +The model can be loaded with the following code snippet. + +```python +from diffusers import AutoencoderKLMagi + +vae = AutoencoderKLMagi.from_pretrained("sand-ai/MAGI-1", subfolder="vae", torch_dtype=torch.float32) +``` + +## AutoencoderKLMagi + +[[autodoc]] AutoencoderKLMagi + - decode + - all + +## DecoderOutput + +[[autodoc]] models.autoencoders.vae.DecoderOutput \ No newline at end of file diff --git a/docs/source/en/api/models/magi_transformer_3d.md b/docs/source/en/api/models/magi_transformer_3d.md new file mode 100644 index 000000000000..d737dfc956a7 --- /dev/null +++ b/docs/source/en/api/models/magi_transformer_3d.md @@ -0,0 +1,32 @@ + + +# MagiTransformer3DModel + +A Diffusion Transformer model for 3D video-like data was introduced in [MAGI-1: Autoregressive Video Generation at Scale](https://arxiv.org/abs/2505.13211) by Sand.ai. + +MAGI-1 is an autoregressive denoising video generation model that generates videos chunk-by-chunk instead of as a whole. Each chunk (24 frames) is denoised holistically, and the generation of the next chunk begins as soon as the current one reaches a certain level of denoising. + +The model can be loaded with the following code snippet. + +```python +from diffusers import MagiTransformer3DModel + +transformer = MagiTransformer3DModel.from_pretrained("sand-ai/MAGI-1", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## MagiTransformer3DModel + +[[autodoc]] MagiTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput \ No newline at end of file diff --git a/docs/source/en/api/pipelines/magi.md b/docs/source/en/api/pipelines/magi.md new file mode 100644 index 000000000000..3a8a90cdb3a5 --- /dev/null +++ b/docs/source/en/api/pipelines/magi.md @@ -0,0 +1,309 @@ + + +
+<div class="flex flex-wrap space-x-1">
+  <img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
+</div>
+ +# MAGI-1 + +[MAGI-1: Autoregressive Video Generation at Scale](https://arxiv.org/abs/2505.13211) by Sand.ai. + +*MAGI-1 is an autoregressive video generation model that generates videos chunk-by-chunk instead of as a whole. Each chunk (24 frames) is denoised holistically, and the generation of the next chunk begins as soon as the current one reaches a certain level of denoising. This pipeline design enables concurrent processing of up to four chunks for efficient video generation. The model leverages a specialized architecture with a transformer-based VAE with 8x spatial and 4x temporal compression, and a diffusion transformer with several key innovations including Block-Causal Attention, Parallel Attention Block, QK-Norm and GQA, Sandwich Normalization in FFN, SwiGLU, and Softcap Modulation.* + +You can find the MAGI-1 checkpoints under the [sand-ai](https://huggingface.co/sand-ai) organization. + +The following MAGI models are supported in Diffusers: +- [MAGI-1 24B](https://huggingface.co/sand-ai/MAGI-1) +- [MAGI-1 4.5B](https://huggingface.co/sand-ai/MAGI-1-4.5B) + +> [!TIP] +> Click on the MAGI-1 models in the right sidebar for more examples of video generation. + +### Text-to-Video Generation + +The example below demonstrates how to generate a video from text optimized for memory or inference speed. + + + + +Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques. + +The MAGI-1 text-to-video model below requires ~13GB of VRAM. + +```py +import torch +import numpy as np +from diffusers import AutoModel, MagiPipeline +from diffusers.hooks.group_offloading import apply_group_offloading +from diffusers.utils import export_to_video +from transformers import T5EncoderModel + +text_encoder = T5EncoderModel.from_pretrained("sand-ai/MAGI-1", subfolder="text_encoder", torch_dtype=torch.bfloat16) +vae = AutoModel.from_pretrained("sand-ai/MAGI-1", subfolder="vae", torch_dtype=torch.float32) +transformer = AutoModel.from_pretrained("sand-ai/MAGI-1", subfolder="transformer", torch_dtype=torch.bfloat16) + +# group-offloading +onload_device = torch.device("cuda") +offload_device = torch.device("cpu") +apply_group_offloading(text_encoder, + onload_device=onload_device, + offload_device=offload_device, + offload_type="block_level", + num_blocks_per_group=4 +) +transformer.enable_group_offload( + onload_device=onload_device, + offload_device=offload_device, + offload_type="leaf_level", + use_stream=True +) + +pipeline = MagiPipeline.from_pretrained( + "sand-ai/MAGI-1", + vae=vae, + transformer=transformer, + text_encoder=text_encoder, + torch_dtype=torch.bfloat16 +) +pipeline.to("cuda") + +prompt = """ +A majestic eagle soaring over a mountain landscape. The eagle's wings are spread wide, +catching the golden sunlight as it glides through the clear blue sky. Below, snow-capped +mountains stretch to the horizon, with pine forests and a winding river visible in the valley. +""" +negative_prompt = """ +Poor quality, blurry, pixelated, low resolution, distorted proportions, unnatural colors, +watermark, text overlay, incomplete rendering, glitches, artifacts, unrealistic lighting +""" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=24, + guidance_scale=7.0, +).frames[0] +export_to_video(output, "output.mp4", fps=8) +``` + + + + +[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. 
+ +```py +import torch +import numpy as np +from diffusers import AutoModel, MagiPipeline +from diffusers.utils import export_to_video +from transformers import T5EncoderModel + +text_encoder = T5EncoderModel.from_pretrained("sand-ai/MAGI-1", subfolder="text_encoder", torch_dtype=torch.bfloat16) +vae = AutoModel.from_pretrained("sand-ai/MAGI-1", subfolder="vae", torch_dtype=torch.float32) +transformer = AutoModel.from_pretrained("sand-ai/MAGI-1", subfolder="transformer", torch_dtype=torch.bfloat16) + +pipeline = MagiPipeline.from_pretrained( + "sand-ai/MAGI-1", + vae=vae, + transformer=transformer, + text_encoder=text_encoder, + torch_dtype=torch.bfloat16 +) +pipeline.to("cuda") + +# torch.compile +pipeline.transformer.to(memory_format=torch.channels_last) +pipeline.transformer = torch.compile( + pipeline.transformer, mode="max-autotune", fullgraph=True +) + +prompt = """ +A majestic eagle soaring over a mountain landscape. The eagle's wings are spread wide, +catching the golden sunlight as it glides through the clear blue sky. Below, snow-capped +mountains stretch to the horizon, with pine forests and a winding river visible in the valley. +""" +negative_prompt = """ +Poor quality, blurry, pixelated, low resolution, distorted proportions, unnatural colors, +watermark, text overlay, incomplete rendering, glitches, artifacts, unrealistic lighting +""" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=24, + guidance_scale=7.0, +).frames[0] +export_to_video(output, "output.mp4", fps=8) +``` + + + + +### Image-to-Video Generation + +The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description and a starting frame. + + + + +```python +import numpy as np +import torch +import torchvision.transforms.functional as TF +from diffusers import AutoencoderKLMagi, MagiImageToVideoPipeline +from diffusers.utils import export_to_video, load_image +from transformers import CLIPVisionModel + +model_id = "sand-ai/MAGI-1" +image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32) +vae = AutoencoderKLMagi.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) +pipe = MagiImageToVideoPipeline.from_pretrained( + model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image.png") + +def aspect_ratio_resize(image, pipe, max_area=720 * 1280): + aspect_ratio = image.height / image.width + mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + image = image.resize((width, height)) + return image, height, width + +image, height, width = aspect_ratio_resize(image, pipe) + +prompt = "A beautiful landscape with mountains and a lake. The camera slowly pans from left to right, revealing more of the landscape." + +output = pipe( + image=image, prompt=prompt, height=height, width=width, guidance_scale=7.5, num_frames=24 +).frames[0] +export_to_video(output, "output.mp4", fps=8) +``` + + + + +### First-Last-Frame-to-Video Generation + +The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description, a starting frame, and an ending frame. 
+ + + + +```python +import numpy as np +import torch +import torchvision.transforms.functional as TF +from diffusers import AutoencoderKLMagi, MagiImageToVideoPipeline +from diffusers.utils import export_to_video, load_image +from transformers import CLIPVisionModel + +model_id = "sand-ai/MAGI-1" +image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32) +vae = AutoencoderKLMagi.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) +pipe = MagiImageToVideoPipeline.from_pretrained( + model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + +first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/first_frame.png") +last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/last_frame.png") + +def aspect_ratio_resize(image, pipe, max_area=720 * 1280): + aspect_ratio = image.height / image.width + mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + image = image.resize((width, height)) + return image, height, width + +def center_crop_resize(image, height, width): + # Calculate resize ratio to match first frame dimensions + resize_ratio = max(width / image.width, height / image.height) + + # Resize the image + width = round(image.width * resize_ratio) + height = round(image.height * resize_ratio) + size = [width, height] + image = TF.center_crop(image, size) + + return image, height, width + +first_frame, height, width = aspect_ratio_resize(first_frame, pipe) +if last_frame.size != first_frame.size: + last_frame, _, _ = center_crop_resize(last_frame, height, width) + +prompt = "A car driving down a winding mountain road. The camera follows the car as it navigates the curves, revealing beautiful mountain scenery in the background." + +output = pipe( + image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=7.5, num_frames=24 +).frames[0] +export_to_video(output, "output.mp4", fps=8) +``` + + + + +### Video-to-Video Generation + +The example below demonstrates how to use the video-to-video pipeline to generate a video based on an existing video and text prompt. 
+ + + + +```python +import torch +import numpy as np +from diffusers import AutoencoderKLMagi, MagiVideoToVideoPipeline +from diffusers.utils import export_to_video, load_video +from transformers import T5EncoderModel + +model_id = "sand-ai/MAGI-1" +text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16) +vae = AutoencoderKLMagi.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) +pipe = MagiVideoToVideoPipeline.from_pretrained( + model_id, vae=vae, text_encoder=text_encoder, torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + +# Load input video +video_path = "input_video.mp4" +video = load_video(video_path) + +prompt = "Convert this video to an anime style with vibrant colors and exaggerated features" +negative_prompt = "Poor quality, blurry, distorted, unrealistic lighting, bad composition" + +output = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + video=video, + strength=0.7, # Controls how much to preserve from original video + guidance_scale=7.5, +).frames[0] +export_to_video(output, "output.mp4", fps=8) +``` + + + + +## Notes + +- MAGI-1 supports LoRAs with [`~loaders.MagiLoraLoaderMixin.load_lora_weights`]. \ No newline at end of file diff --git a/scripts/convert_magi_to_diffusers.py b/scripts/convert_magi_to_diffusers.py new file mode 100644 index 000000000000..38c50271f632 --- /dev/null +++ b/scripts/convert_magi_to_diffusers.py @@ -0,0 +1,657 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert MAGI-1 checkpoints to diffusers format.""" + +import argparse +import json +import os +from pathlib import Path + +import torch +from huggingface_hub import hf_hub_download +from safetensors import safe_open +from safetensors.torch import load_file +from transformers import AutoTokenizer, UMT5EncoderModel + +from diffusers import ( + AutoencoderKLMagi, + MagiPipeline, + MagiTransformer3DModel, + FlowMatchEulerDiscreteScheduler, +) + + +# Mapping dictionary for transformer weights +TRANSFORMER_KEYS_RENAME_DICT = { + "t_embedder.mlp.0": "time_embedding.0", + "t_embedder.mlp.2": "time_embedding.2", + "y_embedder.y_proj_adaln.0": "text_embedding.0", + "y_embedder.y_proj_xattn.0": "cross_attention_proj", + "y_embedder.null_caption_embedding": "null_caption_embedding", + "rope.bands": "rotary_emb.bands", + "videodit_blocks.final_layernorm": "transformer_blocks.norm_final", + "final_linear.linear": "proj_out", +} + +# Layer-specific mappings +LAYER_KEYS_RENAME_DICT = { + "ada_modulate_layer.proj.0": "ff_norm", + "self_attention.linear_kv_xattn": "attn1.to_kv", + "self_attention.linear_proj": "attn1.to_out", + "mlp.linear_fc1": "ff.net.0.proj", +} + + +def convert_magi_vae_checkpoint(checkpoint_path, vae_config_file=None, dtype=None): + """ + Convert a MAGI-1 VAE checkpoint to a diffusers AutoencoderKLMagi. + + Args: + checkpoint_path: Path to the MAGI-1 VAE checkpoint. + vae_config_file: Optional path to a VAE config file. 
+ dtype: Optional dtype for the model. + + Returns: + A diffusers AutoencoderKLMagi model. + """ + if vae_config_file is not None: + with open(vae_config_file, "r") as f: + config = json.load(f) + else: + # Default config for MAGI-1 VAE based on the checkpoint structure + config = { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 16, # Based on encoder.last_layer.weight shape [32, 1024] -> 16 channels (32/2) + "block_out_channels": [1024], # Hidden dimension in transformer blocks + "layers_per_block": 24, # 24 transformer blocks in encoder/decoder + "act_fn": "gelu", + "latent_channels": 16, + "norm_num_groups": 32, + "scaling_factor": 0.18215, + "sample_size": 256, # Typical image size + } + + # Create the diffusers VAE model + vae = AutoencoderKLMagi( + in_channels=config["in_channels"], + out_channels=config["out_channels"], + latent_channels=config["latent_channels"], + layers_per_block=config["layers_per_block"], + block_out_channels=config["block_out_channels"], + act_fn=config["act_fn"], + norm_num_groups=config["norm_num_groups"], + scaling_factor=config["scaling_factor"], + sample_size=config["sample_size"], + ) + + # Load the checkpoint + if checkpoint_path.endswith(".safetensors"): + # Load safetensors file + checkpoint = load_file(checkpoint_path) + else: + # Load PyTorch checkpoint + checkpoint = torch.load(checkpoint_path, map_location="cpu") + + # Convert and load the state dict + converted_state_dict = convert_vae_state_dict(checkpoint) + + # Load the state dict + missing_keys, unexpected_keys = vae.load_state_dict(converted_state_dict, strict=False) + + if missing_keys: + print(f"Missing keys in VAE: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys in VAE: {unexpected_keys}") + + if dtype is not None: + vae = vae.to(dtype=dtype) + + return vae + + +def convert_vae_state_dict(checkpoint): + """ + Convert MAGI-1 VAE state dict to diffusers format. + + Maps the keys from the MAGI-1 VAE state dict to the diffusers VAE state dict. 
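+
+    Example mappings (as applied below):
+        encoder.patch_embed.proj.weight    -> encoder.conv_in.weight
+        encoder.blocks.{i}.attn.qkv.weight -> encoder.transformer_blocks.{i}.attn1.to_qkv.weight
+        decoder.last_layer.weight          -> decoder.conv_out.weight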
+ """ + state_dict = {} + + # Encoder mappings + # Patch embedding + if "encoder.patch_embed.proj.weight" in checkpoint: + state_dict["encoder.conv_in.weight"] = checkpoint["encoder.patch_embed.proj.weight"] + state_dict["encoder.conv_in.bias"] = checkpoint["encoder.patch_embed.proj.bias"] + + # Position embeddings + if "encoder.pos_embed" in checkpoint: + state_dict["encoder.pos_embed"] = checkpoint["encoder.pos_embed"] + + # Class token + if "encoder.cls_token" in checkpoint: + state_dict["encoder.class_embedding"] = checkpoint["encoder.cls_token"] + + # Encoder blocks + for i in range(24): # Assuming 24 blocks in the encoder + # Check if this block exists + if f"encoder.blocks.{i}.attn.qkv.weight" not in checkpoint: + continue + + # Attention components + state_dict[f"encoder.transformer_blocks.{i}.attn1.to_qkv.weight"] = checkpoint[f"encoder.blocks.{i}.attn.qkv.weight"] + state_dict[f"encoder.transformer_blocks.{i}.attn1.to_qkv.bias"] = checkpoint[f"encoder.blocks.{i}.attn.qkv.bias"] + state_dict[f"encoder.transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint[f"encoder.blocks.{i}.attn.proj.weight"] + state_dict[f"encoder.transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint[f"encoder.blocks.{i}.attn.proj.bias"] + + # Normalization + state_dict[f"encoder.transformer_blocks.{i}.norm2.weight"] = checkpoint[f"encoder.blocks.{i}.norm2.weight"] + state_dict[f"encoder.transformer_blocks.{i}.norm2.bias"] = checkpoint[f"encoder.blocks.{i}.norm2.bias"] + + # MLP components + state_dict[f"encoder.transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint[f"encoder.blocks.{i}.mlp.fc1.weight"] + state_dict[f"encoder.transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint[f"encoder.blocks.{i}.mlp.fc1.bias"] + state_dict[f"encoder.transformer_blocks.{i}.ff.net.2.weight"] = checkpoint[f"encoder.blocks.{i}.mlp.fc2.weight"] + state_dict[f"encoder.transformer_blocks.{i}.ff.net.2.bias"] = checkpoint[f"encoder.blocks.{i}.mlp.fc2.bias"] + + # Encoder norm + if "encoder.norm.weight" in checkpoint: + state_dict["encoder.norm_out.weight"] = checkpoint["encoder.norm.weight"] + state_dict["encoder.norm_out.bias"] = checkpoint["encoder.norm.bias"] + + # Encoder last layer (projection to latent space) + if "encoder.last_layer.weight" in checkpoint: + state_dict["encoder.conv_out.weight"] = checkpoint["encoder.last_layer.weight"] + state_dict["encoder.conv_out.bias"] = checkpoint["encoder.last_layer.bias"] + + # Decoder mappings + # Projection from latent space + if "decoder.proj_in.weight" in checkpoint: + state_dict["decoder.conv_in.weight"] = checkpoint["decoder.proj_in.weight"] + state_dict["decoder.conv_in.bias"] = checkpoint["decoder.proj_in.bias"] + + # Position embeddings + if "decoder.pos_embed" in checkpoint: + state_dict["decoder.pos_embed"] = checkpoint["decoder.pos_embed"] + + # Class token + if "decoder.cls_token" in checkpoint: + state_dict["decoder.class_embedding"] = checkpoint["decoder.cls_token"] + + # Decoder blocks + for i in range(24): # Assuming 24 blocks in the decoder + # Check if this block exists + if f"decoder.blocks.{i}.attn.qkv.weight" not in checkpoint: + continue + + # Attention components + state_dict[f"decoder.transformer_blocks.{i}.attn1.to_qkv.weight"] = checkpoint[f"decoder.blocks.{i}.attn.qkv.weight"] + state_dict[f"decoder.transformer_blocks.{i}.attn1.to_qkv.bias"] = checkpoint[f"decoder.blocks.{i}.attn.qkv.bias"] + state_dict[f"decoder.transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint[f"decoder.blocks.{i}.attn.proj.weight"] + 
state_dict[f"decoder.transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint[f"decoder.blocks.{i}.attn.proj.bias"] + + # Normalization + state_dict[f"decoder.transformer_blocks.{i}.norm2.weight"] = checkpoint[f"decoder.blocks.{i}.norm2.weight"] + state_dict[f"decoder.transformer_blocks.{i}.norm2.bias"] = checkpoint[f"decoder.blocks.{i}.norm2.bias"] + + # MLP components + state_dict[f"decoder.transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint[f"decoder.blocks.{i}.mlp.fc1.weight"] + state_dict[f"decoder.transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint[f"decoder.blocks.{i}.mlp.fc1.bias"] + state_dict[f"decoder.transformer_blocks.{i}.ff.net.2.weight"] = checkpoint[f"decoder.blocks.{i}.mlp.fc2.weight"] + state_dict[f"decoder.transformer_blocks.{i}.ff.net.2.bias"] = checkpoint[f"decoder.blocks.{i}.mlp.fc2.bias"] + + # Decoder norm + if "decoder.norm.weight" in checkpoint: + state_dict["decoder.norm_out.weight"] = checkpoint["decoder.norm.weight"] + state_dict["decoder.norm_out.bias"] = checkpoint["decoder.norm.bias"] + + # Decoder last layer (projection to pixel space) + if "decoder.last_layer.weight" in checkpoint: + state_dict["decoder.conv_out.weight"] = checkpoint["decoder.last_layer.weight"] + state_dict["decoder.conv_out.bias"] = checkpoint["decoder.last_layer.bias"] + + # Quant conv (encoder output to latent distribution) + if "quant_conv.weight" in checkpoint: + state_dict["quant_conv.weight"] = checkpoint["quant_conv.weight"] + state_dict["quant_conv.bias"] = checkpoint["quant_conv.bias"] + + # Post quant conv (latent to decoder input) + if "post_quant_conv.weight" in checkpoint: + state_dict["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"] + state_dict["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"] + + return state_dict + + +def load_magi_transformer_checkpoint(checkpoint_path): + """ + Load a MAGI-1 transformer checkpoint. + + Args: + checkpoint_path: Path to the MAGI-1 transformer checkpoint. + + Returns: + The loaded checkpoint state dict. + """ + if checkpoint_path.endswith(".safetensors"): + # Load safetensors file directly + state_dict = load_file(checkpoint_path) + elif os.path.isdir(checkpoint_path): + # Check for sharded safetensors files + safetensors_files = [f for f in os.listdir(checkpoint_path) if f.endswith(".safetensors")] + if safetensors_files: + # Load and merge sharded safetensors files + state_dict = {} + for safetensors_file in safetensors_files: + file_path = os.path.join(checkpoint_path, safetensors_file) + with safe_open(file_path, framework="pt", device="cpu") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key) + else: + # Try loading PyTorch checkpoint + checkpoint_files = [f for f in os.listdir(checkpoint_path) if f.endswith(".pt") or f.endswith(".pth")] + if not checkpoint_files: + raise ValueError(f"No checkpoint files found in {checkpoint_path}") + + checkpoint_file = os.path.join(checkpoint_path, checkpoint_files[0]) + state_dict = torch.load(checkpoint_file, map_location="cpu") + else: + # Try loading PyTorch checkpoint + state_dict = torch.load(checkpoint_path, map_location="cpu") + + return state_dict + + +def convert_magi_transformer_checkpoint(checkpoint_path, transformer_config_file=None, dtype=None): + """ + Convert a MAGI-1 transformer checkpoint to a diffusers MagiTransformer3DModel. + + Args: + checkpoint_path: Path to the MAGI-1 transformer checkpoint. + transformer_config_file: Optional path to a transformer config file. + dtype: Optional dtype for the model. 
+ + Returns: + A diffusers MagiTransformer3DModel model. + """ + if transformer_config_file is not None: + with open(transformer_config_file, "r") as f: + config = json.load(f) + else: + # Default config for MAGI-1 transformer based on the full parameter list + config = { + "in_channels": 16, # Must match VAE latent channels + "out_channels": 16, # Must match VAE latent channels + "num_layers": 34, # Based on the full parameter list (0-33) + "num_attention_heads": 16, + "attention_head_dim": 64, + "cross_attention_dim": 4096, # T5 hidden size + "patch_size": [1, 2, 2], + "use_linear_projection": True, + "upcast_attention": False, + } + + # Create the diffusers transformer model + transformer = MagiTransformer3DModel( + in_channels=config["in_channels"], + out_channels=config["out_channels"], + num_layers=config["num_layers"], + num_attention_heads=config["num_attention_heads"], + attention_head_dim=config["attention_head_dim"], + cross_attention_dim=config["cross_attention_dim"], + patch_size=config["patch_size"], + use_linear_projection=config["use_linear_projection"], + upcast_attention=config["upcast_attention"], + ) + + # Load the checkpoint + checkpoint = load_magi_transformer_checkpoint(checkpoint_path) + + # Convert and load the state dict + converted_state_dict = convert_transformer_state_dict(checkpoint) + + # Load the state dict + missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False) + + if missing_keys: + print(f"Missing keys in transformer: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys in transformer: {unexpected_keys}") + + if dtype is not None: + transformer = transformer.to(dtype=dtype) + + return transformer + + +def convert_transformer_state_dict(checkpoint): + """ + Convert MAGI-1 transformer state dict to diffusers format. + + Maps the keys from the MAGI-1 transformer state dict to the diffusers transformer state dict. 
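+
+    Example mappings (as applied below):
+        t_embedder.mlp.0.weight                                        -> time_embedding.0.weight
+        videodit_blocks.layers.{i}.self_attention.linear_qkv.q.weight  -> transformer_blocks.{i}.attn1.to_q.weight
+        final_linear.linear.weight                                     -> proj_out.weight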
+ """ + state_dict = {} + + # Process input projection + if "x_embedder.weight" in checkpoint: + state_dict["input_proj.weight"] = checkpoint["x_embedder.weight"] + + # Process time embedding + if "t_embedder.mlp.0.weight" in checkpoint: + state_dict["time_embedding.0.weight"] = checkpoint["t_embedder.mlp.0.weight"] + state_dict["time_embedding.0.bias"] = checkpoint["t_embedder.mlp.0.bias"] + state_dict["time_embedding.2.weight"] = checkpoint["t_embedder.mlp.2.weight"] + state_dict["time_embedding.2.bias"] = checkpoint["t_embedder.mlp.2.bias"] + + # Process text embedding + if "y_embedder.y_proj_adaln.0.weight" in checkpoint: + state_dict["text_embedding.0.weight"] = checkpoint["y_embedder.y_proj_adaln.0.weight"] + state_dict["text_embedding.0.bias"] = checkpoint["y_embedder.y_proj_adaln.0.bias"] + + if "y_embedder.y_proj_xattn.0.weight" in checkpoint: + state_dict["cross_attention_proj.weight"] = checkpoint["y_embedder.y_proj_xattn.0.weight"] + state_dict["cross_attention_proj.bias"] = checkpoint["y_embedder.y_proj_xattn.0.bias"] + + # Process null caption embedding + if "y_embedder.null_caption_embedding" in checkpoint: + state_dict["null_caption_embedding"] = checkpoint["y_embedder.null_caption_embedding"] + + # Process rotary embedding + if "rope.bands" in checkpoint: + state_dict["rotary_emb.bands"] = checkpoint["rope.bands"] + + # Process final layer norm + if "videodit_blocks.final_layernorm.weight" in checkpoint: + state_dict["transformer_blocks.norm_final.weight"] = checkpoint["videodit_blocks.final_layernorm.weight"] + state_dict["transformer_blocks.norm_final.bias"] = checkpoint["videodit_blocks.final_layernorm.bias"] + + # Process final linear projection + if "final_linear.linear.weight" in checkpoint: + state_dict["proj_out.weight"] = checkpoint["final_linear.linear.weight"] + + # Process transformer blocks + # Based on the full parameter list, there are 34 layers (0-33) + num_layers = 34 + for i in range(num_layers): + # Check if this layer exists in the checkpoint + layer_prefix = f"videodit_blocks.layers.{i}" + if f"{layer_prefix}.ada_modulate_layer.proj.0.weight" not in checkpoint: + continue + + # FF norm (AdaLN projection) + state_dict[f"transformer_blocks.{i}.ff_norm.weight"] = checkpoint[f"{layer_prefix}.ada_modulate_layer.proj.0.weight"] + state_dict[f"transformer_blocks.{i}.ff_norm.bias"] = checkpoint[f"{layer_prefix}.ada_modulate_layer.proj.0.bias"] + + # Self-attention components + + # Query normalization + if f"{layer_prefix}.self_attention.q_layernorm.weight" in checkpoint: + state_dict[f"transformer_blocks.{i}.attn1.norm_q.weight"] = checkpoint[f"{layer_prefix}.self_attention.q_layernorm.weight"] + state_dict[f"transformer_blocks.{i}.attn1.norm_q.bias"] = checkpoint[f"{layer_prefix}.self_attention.q_layernorm.bias"] + + # Key normalization + if f"{layer_prefix}.self_attention.k_layernorm.weight" in checkpoint: + state_dict[f"transformer_blocks.{i}.attn1.norm_k.weight"] = checkpoint[f"{layer_prefix}.self_attention.k_layernorm.weight"] + state_dict[f"transformer_blocks.{i}.attn1.norm_k.bias"] = checkpoint[f"{layer_prefix}.self_attention.k_layernorm.bias"] + + # Cross-attention key normalization + if f"{layer_prefix}.self_attention.k_layernorm_xattn.weight" in checkpoint: + state_dict[f"transformer_blocks.{i}.attn1.norm_k_xattn.weight"] = checkpoint[f"{layer_prefix}.self_attention.k_layernorm_xattn.weight"] + state_dict[f"transformer_blocks.{i}.attn1.norm_k_xattn.bias"] = checkpoint[f"{layer_prefix}.self_attention.k_layernorm_xattn.bias"] + + # Cross-attention 
query normalization + if f"{layer_prefix}.self_attention.q_layernorm_xattn.weight" in checkpoint: + state_dict[f"transformer_blocks.{i}.attn1.norm_q_xattn.weight"] = checkpoint[f"{layer_prefix}.self_attention.q_layernorm_xattn.weight"] + state_dict[f"transformer_blocks.{i}.attn1.norm_q_xattn.bias"] = checkpoint[f"{layer_prefix}.self_attention.q_layernorm_xattn.bias"] + + # QKV linear projections + if f"{layer_prefix}.self_attention.linear_qkv.q.weight" in checkpoint: + state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = checkpoint[f"{layer_prefix}.self_attention.linear_qkv.q.weight"] + + if f"{layer_prefix}.self_attention.linear_qkv.k.weight" in checkpoint: + state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = checkpoint[f"{layer_prefix}.self_attention.linear_qkv.k.weight"] + + if f"{layer_prefix}.self_attention.linear_qkv.v.weight" in checkpoint: + state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = checkpoint[f"{layer_prefix}.self_attention.linear_qkv.v.weight"] + + if f"{layer_prefix}.self_attention.linear_qkv.qx.weight" in checkpoint: + state_dict[f"transformer_blocks.{i}.attn1.to_q_xattn.weight"] = checkpoint[f"{layer_prefix}.self_attention.linear_qkv.qx.weight"] + + # QKV layer norm + if f"{layer_prefix}.self_attention.linear_qkv.layer_norm.weight" in checkpoint: + state_dict[f"transformer_blocks.{i}.attn1.qkv_norm.weight"] = checkpoint[f"{layer_prefix}.self_attention.linear_qkv.layer_norm.weight"] + state_dict[f"transformer_blocks.{i}.attn1.qkv_norm.bias"] = checkpoint[f"{layer_prefix}.self_attention.linear_qkv.layer_norm.bias"] + + # KV cross-attention + if f"{layer_prefix}.self_attention.linear_kv_xattn.weight" in checkpoint: + state_dict[f"transformer_blocks.{i}.attn1.to_kv_xattn.weight"] = checkpoint[f"{layer_prefix}.self_attention.linear_kv_xattn.weight"] + + # Output projection + if f"{layer_prefix}.self_attention.linear_proj.weight" in checkpoint: + state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint[f"{layer_prefix}.self_attention.linear_proj.weight"] + + # Self-attention post normalization + if f"{layer_prefix}.self_attn_post_norm.weight" in checkpoint: + state_dict[f"transformer_blocks.{i}.norm1.weight"] = checkpoint[f"{layer_prefix}.self_attn_post_norm.weight"] + state_dict[f"transformer_blocks.{i}.norm1.bias"] = checkpoint[f"{layer_prefix}.self_attn_post_norm.bias"] + + # MLP components + # MLP layer norm + if f"{layer_prefix}.mlp.layer_norm.weight" in checkpoint: + state_dict[f"transformer_blocks.{i}.ff.norm.weight"] = checkpoint[f"{layer_prefix}.mlp.layer_norm.weight"] + state_dict[f"transformer_blocks.{i}.ff.norm.bias"] = checkpoint[f"{layer_prefix}.mlp.layer_norm.bias"] + + # MLP FC1 (projection) + if f"{layer_prefix}.mlp.linear_fc1.weight" in checkpoint: + state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint[f"{layer_prefix}.mlp.linear_fc1.weight"] + + # MLP FC2 (projection) + if f"{layer_prefix}.mlp.linear_fc2.weight" in checkpoint: + state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint[f"{layer_prefix}.mlp.linear_fc2.weight"] + + # MLP post normalization + if f"{layer_prefix}.mlp_post_norm.weight" in checkpoint: + state_dict[f"transformer_blocks.{i}.norm2.weight"] = checkpoint[f"{layer_prefix}.mlp_post_norm.weight"] + state_dict[f"transformer_blocks.{i}.norm2.bias"] = checkpoint[f"{layer_prefix}.mlp_post_norm.bias"] + + return state_dict + + +def convert_magi_checkpoint( + magi_checkpoint_path, + vae_checkpoint_path=None, + transformer_checkpoint_path=None, + t5_model_name="google/umt5-xxl", + 
output_path=None, + dtype=None, +): + """ + Convert MAGI-1 checkpoints to a diffusers pipeline. + + Args: + magi_checkpoint_path: Path to the MAGI-1 checkpoint directory. + vae_checkpoint_path: Optional path to the VAE checkpoint. + transformer_checkpoint_path: Optional path to the transformer checkpoint. + t5_model_name: Name of the T5 model to use. + output_path: Path to save the converted pipeline. + dtype: Optional dtype for the models. + + Returns: + A diffusers MagiPipeline. + """ + # Load or convert the VAE + if vae_checkpoint_path is None: + vae_checkpoint_path = os.path.join(magi_checkpoint_path, "ckpt/vae") + + vae = convert_magi_vae_checkpoint(vae_checkpoint_path, dtype=dtype) + + # Load or convert the transformer + if transformer_checkpoint_path is None: + transformer_checkpoint_path = os.path.join(magi_checkpoint_path, "ckpt/magi/4.5B_base/inference_weight") + + transformer = convert_magi_transformer_checkpoint(transformer_checkpoint_path, dtype=dtype) + + # Load the text encoder and tokenizer + tokenizer = AutoTokenizer.from_pretrained(t5_model_name) + text_encoder = UMT5EncoderModel.from_pretrained(t5_model_name) + + if dtype is not None: + text_encoder = text_encoder.to(dtype=dtype) + + # Create the scheduler + scheduler = FlowMatchEulerDiscreteScheduler() + + # Create the pipeline + pipeline = MagiPipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + # Save the pipeline if output_path is provided + if output_path is not None: + pipeline.save_pretrained(output_path) + + return pipeline + + +def parse_args(): + parser = argparse.ArgumentParser(description="Convert MAGI-1 checkpoints to diffusers format.") + parser.add_argument( + "--magi_checkpoint_path", + type=str, + required=True, + help="Path to the MAGI-1 checkpoint directory.", + ) + parser.add_argument( + "--vae_checkpoint_path", + type=str, + help="Path to the VAE checkpoint. If not provided, will look in magi_checkpoint_path/ckpt/vae.", + ) + parser.add_argument( + "--transformer_checkpoint_path", + type=str, + help="Path to the transformer checkpoint. 
If not provided, will look in magi_checkpoint_path/ckpt/magi/4.5B_base.", + ) + parser.add_argument( + "--t5_model_name", + type=str, + default="google/umt5-xxl", + help="Name of the T5 model to use.", + ) + parser.add_argument( + "--output_path", + type=str, + required=True, + help="Path to save the converted pipeline.", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["float32", "float16", "bfloat16"], + default="float32", + help="Data type for the models.", + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + + # Set the dtype + if args.dtype == "float16": + dtype = torch.float16 + elif args.dtype == "bfloat16": + dtype = torch.bfloat16 + else: + dtype = torch.float32 + + print(f"Starting MAGI-1 conversion to diffusers format...") + print(f"Output will be saved to: {args.output_path}") + print(f"Using dtype: {args.dtype}") + + try: + # Convert the VAE + print(f"Converting VAE checkpoint...") + if args.vae_checkpoint_path: + vae_path = args.vae_checkpoint_path + else: + vae_path = os.path.join(args.magi_checkpoint_path, "ckpt/vae/diffusion_pytorch_model.safetensors") + if not os.path.exists(vae_path): + vae_path = os.path.join(args.magi_checkpoint_path, "ckpt/vae") + + print(f"VAE checkpoint path: {vae_path}") + vae = convert_magi_vae_checkpoint(vae_path, dtype=dtype) + print(f"VAE conversion complete.") + + # Convert the transformer + print(f"Converting transformer checkpoint...") + if args.transformer_checkpoint_path: + transformer_path = args.transformer_checkpoint_path + else: + transformer_path = os.path.join(args.magi_checkpoint_path, "ckpt/magi/4.5B_base/inference_weight") + + print(f"Transformer checkpoint path: {transformer_path}") + transformer = convert_magi_transformer_checkpoint(transformer_path, dtype=dtype) + print(f"Transformer conversion complete.") + + # Load the text encoder and tokenizer + print(f"Loading text encoder and tokenizer from {args.t5_model_name}...") + tokenizer = AutoTokenizer.from_pretrained(args.t5_model_name) + text_encoder = UMT5EncoderModel.from_pretrained(args.t5_model_name) + + if dtype is not None: + text_encoder = text_encoder.to(dtype=dtype) + print(f"Text encoder and tokenizer loaded successfully.") + + # Create the scheduler + print(f"Creating scheduler...") + scheduler = FlowMatchEulerDiscreteScheduler() + print(f"Scheduler created successfully.") + + # Create the pipeline + print(f"Creating MAGI pipeline...") + pipeline = MagiPipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + print(f"MAGI pipeline created successfully.") + + # Save the pipeline + print(f"Saving pipeline to {args.output_path}...") + pipeline.save_pretrained(args.output_path) + print(f"Pipeline saved successfully.") + + print(f"Conversion complete! 
MAGI-1 pipeline saved to {args.output_path}") + + except Exception as e: + print(f"Error during conversion: {str(e)}") + import traceback + traceback.print_exc() + return 1 + + return 0 + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 8723fbca2187..7d70bc5476fa 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -35,6 +35,7 @@ _import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"] _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] + _import_structure["autoencoders.autoencoder_kl_magi"] = ["AutoencoderKLMagi"] _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] @@ -84,6 +85,7 @@ _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] + _import_structure["transformers.transformer_magi"] = ["MagiTransformer3DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] @@ -120,6 +122,7 @@ AutoencoderKLCosmos, AutoencoderKLHunyuanVideo, AutoencoderKLLTXVideo, + AutoencoderKLMagi, AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLTemporalDecoder, @@ -168,6 +171,7 @@ LTXVideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, + MagiTransformer3DModel, MochiTransformer3DModel, OmniGenTransformer2DModel, PixArtTransformer2DModel, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 742d747ae25e..1f4d412f3bbf 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -6,6 +6,7 @@ from .autoencoder_kl_cosmos import AutoencoderKLCosmos from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo from .autoencoder_kl_ltx import AutoencoderKLLTXVideo +from .autoencoder_kl_magi import AutoencoderKLMagi from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_magi.py b/src/diffusers/models/autoencoders/autoencoder_kl_magi.py new file mode 100644 index 000000000000..5b15298cd238 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_magi.py @@ -0,0 +1,719 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput, logging +from ...utils.accelerate_utils import apply_forward_hook +from ..modeling_utils import ModelMixin + + +logger = logging.get_logger(__name__) + + +@dataclass +class AutoencoderKLMagiOutput(BaseOutput): + """ + Output of AutoencoderKLMagi encoding method. + + Args: + latent_dist (`DiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. + `DiagonalGaussianDistribution` allows for sampling from the encoded latent vector. + """ + + latent_dist: "DiagonalGaussianDistribution" + + +class DiagonalGaussianDistribution(object): + """ + Diagonal Gaussian distribution with mean and logvar. + """ + + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean) + + def sample(self, generator=None): + x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device, generator=generator) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3, 4]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3, 4], + ) + + def nll(self, sample, dims=[1, 2, 3, 4]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = torch.log(torch.tensor(2.0 * 3.141592653589793)) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) + + def mode(self): + return self.mean + + +class ManualLayerNorm(nn.Module): + """ + Manual implementation of LayerNorm for better compatibility. + """ + def __init__(self, normalized_shape, eps=1e-5): + super().__init__() + self.normalized_shape = normalized_shape + self.eps = eps + + def forward(self, x): + mean = x.mean(dim=-1, keepdim=True) + std = x.std(dim=-1, keepdim=True, unbiased=False) + x_normalized = (x - mean) / (std + self.eps) + return x_normalized + + +class Mlp(nn.Module): + """ + MLP module used in the transformer architecture. + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + """ + Multi-head attention module. 
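+
+    Operates on token sequences of shape (batch, tokens, dim) and returns a tensor of the same shape.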
+ """ + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + """ + Transformer block with attention and MLP. + """ + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.drop_path = nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding for 3D data. + """ + def __init__(self, video_size=224, video_length=16, patch_size=16, patch_length=1, in_chans=3, embed_dim=768): + super().__init__() + self.video_size = video_size + self.video_length = video_length + self.patch_size = patch_size + self.patch_length = patch_length + + self.grid_size = video_size // patch_size + self.grid_length = video_length // patch_length + self.num_patches = self.grid_length * self.grid_size * self.grid_size + + self.proj = nn.Conv3d( + in_chans, embed_dim, + kernel_size=(patch_length, patch_size, patch_size), + stride=(patch_length, patch_size, patch_size) + ) + + def forward(self, x): + B, C, T, H, W = x.shape + assert H == self.video_size and W == self.video_size, \ + f"Input image size ({H}*{W}) doesn't match model ({self.video_size}*{self.video_size})." + assert T == self.video_length, \ + f"Input video length ({T}) doesn't match model ({self.video_length})." + + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class ViTEncoder(nn.Module): + """ + Vision Transformer Encoder for MAGI-1 VAE. 
+ """ + def __init__( + self, + video_size=256, + video_length=16, + patch_size=8, + patch_length=4, + in_chans=3, + z_chans=4, + double_z=True, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + norm_layer=nn.LayerNorm, + with_cls_token=True, + ): + super().__init__() + self.video_size = video_size + self.video_length = video_length + self.patch_size = patch_size + self.patch_length = patch_length + self.z_chans = z_chans + self.double_z = double_z + self.with_cls_token = with_cls_token + + # Patch embedding + self.patch_embed = PatchEmbed( + video_size=video_size, + video_length=video_length, + patch_size=patch_size, + patch_length=patch_length, + in_chans=in_chans, + embed_dim=embed_dim, + ) + + num_patches = self.patch_embed.num_patches + + # Class token and position embedding + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.num_tokens = 1 + else: + self.num_tokens = 0 + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + # Transformer blocks + dpr = [x.item() for x in torch.linspace(0, 0.0, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + ) + for i in range(depth) + ]) + + self.norm = norm_layer(embed_dim) + + # Projection to latent space + self.proj = nn.Linear(embed_dim, z_chans * 2 if double_z else z_chans) + + # Initialize weights + self._init_weights() + + def _init_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize position embedding + nn.init.normal_(self.pos_embed, std=0.02) + + # Initialize cls token if used + if self.with_cls_token: + nn.init.normal_(self.cls_token, std=0.02) + + def forward(self, x): + # Patch embedding + x = self.patch_embed(x) + + # Add class token if used + if self.with_cls_token: + cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # Add position embedding and apply dropout + x = x + self.pos_embed + x = self.pos_drop(x) + + # Apply transformer blocks + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + # Use class token for output if available, otherwise use patch tokens + if self.with_cls_token: + x = x[:, 0] + else: + x = x.mean(dim=1) + + # Project to latent space + x = self.proj(x) + + # Reshape to [B, C, T, H, W] + B = x.shape[0] + T = self.video_length // self.patch_length + H = self.video_size // self.patch_size + W = self.video_size // self.patch_size + C = self.z_chans * 2 if self.double_z else self.z_chans + + x = x.view(B, C, T, H, W) + + return x + + +class ViTDecoder(nn.Module): + """ + Vision Transformer Decoder for MAGI-1 VAE. 
+ """ + def __init__( + self, + video_size=256, + video_length=16, + patch_size=8, + patch_length=4, + in_chans=3, + z_chans=4, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + norm_layer=nn.LayerNorm, + with_cls_token=True, + ): + super().__init__() + self.video_size = video_size + self.video_length = video_length + self.patch_size = patch_size + self.patch_length = patch_length + self.z_chans = z_chans + self.with_cls_token = with_cls_token + + # Calculate patch dimensions + self.grid_size = video_size // patch_size + self.grid_length = video_length // patch_length + num_patches = self.grid_length * self.grid_size * self.grid_size + + # Input projection from latent space to embedding dimension + self.proj_in = nn.Linear(z_chans, embed_dim) + + # Class token and position embedding + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.num_tokens = 1 + else: + self.num_tokens = 0 + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + # Transformer blocks + dpr = [x.item() for x in torch.linspace(0, 0.0, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + ) + for i in range(depth) + ]) + + self.norm = norm_layer(embed_dim) + + # Output projection to image space + self.proj_out = nn.ConvTranspose3d( + embed_dim, + in_chans, + kernel_size=(patch_length, patch_size, patch_size), + stride=(patch_length, patch_size, patch_size) + ) + + # Initialize weights + self._init_weights() + + def _init_weights(self): + # Initialize position embedding + nn.init.normal_(self.pos_embed, std=0.02) + + # Initialize cls token if used + if self.with_cls_token: + nn.init.normal_(self.cls_token, std=0.02) + + # Initialize output projection + w = self.proj_out.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + def forward(self, z): + # Get dimensions + B, C, T, H, W = z.shape + + # Flatten spatial dimensions and transpose to [B, T*H*W, C] + z = z.flatten(2).transpose(1, 2) + + # Project to embedding dimension + x = self.proj_in(z) + + # Add class token if used + if self.with_cls_token: + cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # Add position embedding and apply dropout + x = x + self.pos_embed + x = self.pos_drop(x) + + # Apply transformer blocks + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + # Remove class token if used + if self.with_cls_token: + x = x[:, 1:] + + # Reshape to [B, T, H, W, C] + x = x.reshape(B, T, H, W, -1) + + # Transpose to [B, C, T, H, W] + x = x.permute(0, 4, 1, 2, 3) + + # Project to image space + x = self.proj_out(x) + + return x + + +class AutoencoderKLMagi(ModelMixin, ConfigMixin): + """ + Variational Autoencoder (VAE) model with KL loss for MAGI-1. + + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic methods + implemented for all models (downloading, saving, loading, etc.) + + Parameters: + in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. + out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. 
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock3D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock3D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to 1): Number of layers per block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 8): Number of channels in the latent space. + norm_num_groups (`int`, *optional*, defaults to 32): Number of groups for the normalization. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: + `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution + Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + temporal_downsample_factor (`Tuple[int]`, *optional*, defaults to (1, 2, 1, 1)): + Tuple of temporal downsampling factors for each block. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock3D",), + up_block_types: Tuple[str] = ("UpDecoderBlock3D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 8, + norm_num_groups: int = 32, + scaling_factor: float = 0.18215, + temporal_downsample_factor: Tuple[int] = (1, 2, 1, 1), + video_size: int = 256, + video_length: int = 16, + patch_size: int = 8, + patch_length: int = 4, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + with_cls_token: bool = True, + ): + super().__init__() + + # Save important parameters + self.latent_channels = latent_channels + self.scaling_factor = scaling_factor + self.temperal_downsample = temporal_downsample_factor + + # Create encoder and decoder + self.encoder = ViTEncoder( + video_size=video_size, + video_length=video_length, + patch_size=patch_size, + patch_length=patch_length, + in_chans=in_channels, + z_chans=latent_channels, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + norm_layer=nn.LayerNorm, + with_cls_token=with_cls_token, + double_z=True, + ) + + self.decoder = ViTDecoder( + video_size=video_size, + video_length=video_length, + patch_size=patch_size, + patch_length=patch_length, + in_chans=out_channels, + z_chans=latent_channels, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + norm_layer=nn.LayerNorm, + with_cls_token=with_cls_token, + ) + + # Enable tiling + self._enable_tiling = False + self._tile_sample_min_size = None + self._tile_sample_stride = None + + @property 
+ def spatial_downsample_factor(self) -> int: + """ + Returns the spatial downsample factor for the VAE. + """ + return self.encoder.patch_size # MAGI-1 uses patch_size as spatial downsampling + + @property + def temporal_downsample_factor(self) -> int: + """ + Returns the temporal downsample factor for the VAE. + """ + return self.encoder.patch_length # MAGI-1 uses patch_length as temporal downsampling + + @apply_forward_hook + def encode( + self, x: torch.FloatTensor, return_dict: bool = True + ) -> Union[AutoencoderKLMagiOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of videos. + + Args: + x (`torch.FloatTensor`): Input batch of videos. + return_dict (`bool`, *optional*, defaults to `True`): Whether to return a dictionary or tuple. + + Returns: + `AutoencoderKLMagiOutput` or `tuple`: + If return_dict is True, returns an `AutoencoderKLMagiOutput` object, otherwise returns a tuple. + """ + h = self.encoder(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + + return AutoencoderKLMagiOutput(latent_dist=posterior) + + @apply_forward_hook + def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[torch.FloatTensor, BaseOutput]: + """ + Decode a batch of latent vectors. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): Whether to return a dictionary or tuple. + + Returns: + `BaseOutput` or `torch.FloatTensor`: + If return_dict is True, returns a `BaseOutput` object, otherwise returns the decoded tensor. + """ + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return BaseOutput(sample=dec) + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + ) -> None: + """ + Enable tiled processing for large videos. + + Args: + tile_sample_min_height (`int`, *optional*): Minimum tile height. + tile_sample_min_width (`int`, *optional*): Minimum tile width. + """ + self._enable_tiling = True + self._tile_sample_min_size = (tile_sample_min_height, tile_sample_min_width) + + def disable_tiling(self) -> None: + """ + Disable tiled processing. + """ + self._enable_tiling = False + self._tile_sample_min_size = None + self._tile_sample_stride = None + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[BaseOutput, Tuple]: + """ + Forward pass of the model. + + Args: + sample (`torch.FloatTensor`): Input batch of videos. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior distribution. + return_dict (`bool`, *optional*, defaults to `True`): Whether to return a dictionary or tuple. + generator (`torch.Generator`, *optional*): Generator for random sampling. + + Returns: + `BaseOutput` or `tuple`: + If return_dict is True, returns a `BaseOutput` object, otherwise returns a tuple. 
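+
+        When `sample_posterior` is `False` (the default), the posterior mode is decoded, giving a
+        deterministic reconstruction; when `True`, a latent is sampled from the posterior using `generator`.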
+        """
+        posterior = self.encode(sample, return_dict=True).latent_dist
+
+        if sample_posterior:
+            z = posterior.sample(generator=generator)
+        else:
+            z = posterior.mode()
+
+        # Decode the latents. No scaling is applied here because `encode` does not apply the inverse;
+        # pipelines are responsible for applying `scaling_factor` around the diffusion model.
+        dec = self.decode(z, return_dict=True).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return BaseOutput(sample=dec)
\ No newline at end of file
diff --git a/src/diffusers/models/transformers/transformer_magi.py b/src/diffusers/models/transformers/transformer_magi.py
new file mode 100644
index 000000000000..4e527cecf228
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_magi.py
@@ -0,0 +1,668 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import BaseOutput, logging
+from ..embeddings import TimestepEmbedding, Timesteps
+from ..modeling_utils import ModelMixin
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class MagiTransformerOutput(BaseOutput):
+    """
+    The output of [`MagiTransformer3DModel`].
+
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch_size, num_channels, frames, height, width)`):
+            The hidden states output from the last layer of the model.
+    """
+
+    sample: torch.FloatTensor
+
+
+class MagiAttention(nn.Module):
+    """
+    A cross attention layer for MAGI-1.
+
+    This implements the specialized attention mechanism from the MAGI-1 model,
+    including query/key normalization and proper handling of rotary embeddings.
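+
+    Queries, keys, and values are produced by bias-free linear projections by default and reshaped into
+    `heads` heads of size `dim_head`. The query and key tensors are normalized per head with a `LayerNorm`
+    computed in float32 for numerical stability, optional rotary position embeddings are applied in
+    complex space, and attention falls back to a manual softmax implementation when
+    `F.scaled_dot_product_attention` is unavailable.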
+ """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + self.heads = heads + self.dim_head = dim_head + + # Projection layers + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + # Normalization layers for query and key - important part of MAGI-1's attention mechanism + self.norm_q = nn.LayerNorm(dim_head, eps=1e-5) + self.norm_k = nn.LayerNorm(dim_head, eps=1e-5) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + rotary_pos_emb=None, + **cross_attention_kwargs, + ): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + + # Project to query, key, value + query = self.to_q(hidden_states) + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + # Reshape for multi-head attention + query = query.reshape(batch_size, sequence_length, self.heads, self.dim_head) + key = key.reshape(batch_size, -1, self.heads, self.dim_head) + value = value.reshape(batch_size, -1, self.heads, self.dim_head) + + # Apply layer normalization to query and key (as in MAGI-1) + # Convert to float32 for better numerical stability during normalization + orig_dtype = query.dtype + query = self.norm_q(query.float()).to(orig_dtype) + key = self.norm_k(key.float()).to(orig_dtype) + + # Transpose for attention + # [batch_size, seq_len, heads, dim_head] -> [batch_size, heads, seq_len, dim_head] + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Apply rotary position embeddings if provided + if rotary_pos_emb is not None: + # Apply rotary embeddings using the same method as in MAGI-1 + def apply_rotary_emb(hidden_states, freqs): + dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64 + # Convert to complex numbers + x_complex = torch.view_as_complex(hidden_states.to(dtype).unflatten(-1, (-1, 2))) + # Apply rotation in complex space + x_rotated = x_complex * freqs + # Convert back to real + x_out = torch.view_as_real(x_rotated).flatten(-2) + return x_out.type_as(hidden_states) + + # Apply rotary embeddings + query = 
apply_rotary_emb(query, rotary_pos_emb) + key = apply_rotary_emb(key, rotary_pos_emb) + + # Use scaled_dot_product_attention if available (PyTorch 2.0+) + if hasattr(F, "scaled_dot_product_attention"): + # Apply scaled dot product attention + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + # [batch_size, heads, seq_len, dim_head] -> [batch_size, seq_len, heads*dim_head] + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, sequence_length, -1) + else: + # Manual implementation of attention + # Reshape for bmm + query = query.reshape(batch_size * self.heads, sequence_length, self.dim_head) + key = key.reshape(batch_size * self.heads, -1, self.dim_head) + value = value.reshape(batch_size * self.heads, -1, self.dim_head) + + # Compute attention scores + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.bmm(query, key.transpose(-1, -2)) * self.scale + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.to(value.dtype) + + # Compute output + hidden_states = torch.bmm(attention_probs, value) + hidden_states = hidden_states.reshape(batch_size, self.heads, sequence_length, self.dim_head) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, sequence_length, -1) + + # Project to output + hidden_states = self.to_out(hidden_states) + + return hidden_states + + +class MagiTransformerBlock(nn.Module): + """ + A transformer block for MAGI-1. + + This is a simplified version of the MAGI-1 transformer block. 
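+
+    Each block applies pre-LayerNorm self-attention, optional pre-LayerNorm cross-attention (when
+    `cross_attention_dim` is provided), and a feed-forward network with a 4x hidden expansion and a GELU,
+    approximate-GELU, or SiLU activation. Every sub-layer is wrapped in a residual connection.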
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "gelu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + norm_eps: float = 1e-5, + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + # Self-attention + self.norm1 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.attn1 = MagiAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + + # Cross-attention + if cross_attention_dim is not None: + self.norm2 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.attn2 = MagiAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.norm2 = None + self.attn2 = None + + # Feed-forward + self.norm3 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + + # Choose activation function + if activation_fn == "gelu": + self.ff = nn.Sequential( + nn.Linear(dim, dim * 4), + nn.GELU(), + nn.Dropout(dropout) if final_dropout else nn.Identity(), + nn.Linear(dim * 4, dim), + ) + elif activation_fn == "gelu-approximate": + self.ff = nn.Sequential( + nn.Linear(dim, dim * 4), + nn.GELU(approximate="tanh"), + nn.Dropout(dropout) if final_dropout else nn.Identity(), + nn.Linear(dim * 4, dim), + ) + elif activation_fn == "silu": + self.ff = nn.Sequential( + nn.Linear(dim, dim * 4), + nn.SiLU(), + nn.Dropout(dropout) if final_dropout else nn.Identity(), + nn.Linear(dim * 4, dim), + ) + else: + raise ValueError(f"Unsupported activation function: {activation_fn}") + + self.final_dropout = nn.Dropout(dropout) if final_dropout else nn.Identity() + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + attention_mask=None, + cross_attention_kwargs=None, + rotary_pos_emb=None, + **kwargs, + ): + # Self-attention + norm_hidden_states = self.norm1(hidden_states) + + if self.only_cross_attention: + hidden_states = hidden_states + self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + ) + else: + hidden_states = hidden_states + self.attn1( + norm_hidden_states, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + ) + + # Cross-attention + if self.attn2 is not None: + norm_hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states + self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + **(cross_attention_kwargs if cross_attention_kwargs is not None else {}), + ) + + # Feed-forward + norm_hidden_states = self.norm3(hidden_states) + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + self.final_dropout(ff_output) + + return hidden_states + + +class LearnableRotaryEmbedding(nn.Module): + """ + Learnable rotary position embeddings similar to the one used in MAGI-1. 
+ + This implementation is based on MAGI-1's LearnableRotaryEmbeddingCat class, + which creates rotary embeddings for 3D data (frames, height, width). + """ + def __init__( + self, + dim: int, + max_seq_len: int = 1024, + temperature: float = 10000.0, + in_pixels: bool = True, + linear_bands: bool = False, + ): + super().__init__() + self.dim = dim + self.max_seq_len = max_seq_len + self.temperature = temperature + self.in_pixels = in_pixels + self.linear_bands = linear_bands + + # Initialize frequency bands + self.register_buffer("freqs", self._get_default_bands()) + + def _get_default_bands(self): + """Generate default frequency bands""" + if self.linear_bands: + # Linear spacing + bands = torch.linspace(1.0, self.max_seq_len / 2, self.dim // 2, dtype=torch.float32) + else: + # Log spacing (as in original RoPE) + bands = 1.0 / (self.temperature ** (torch.arange(0, self.dim // 2, dtype=torch.float32) / (self.dim // 2))) + + return bands * torch.pi + + def get_embed(self, shape: List[int]) -> torch.Tensor: + """ + Generate rotary position embeddings for the given shape. + + Args: + shape: List of dimensions [frames, height, width] + + Returns: + Rotary position embeddings (sin and cos components) + """ + frames, height, width = shape + seq_len = frames * height * width + + # Generate position indices + if self.in_pixels: + # Normalize positions to [-1, 1] + t = torch.linspace(-1.0, 1.0, steps=frames, device=self.freqs.device) + h = torch.linspace(-1.0, 1.0, steps=height, device=self.freqs.device) + w = torch.linspace(-1.0, 1.0, steps=width, device=self.freqs.device) + else: + # Use integer positions + t = torch.arange(frames, device=self.freqs.device, dtype=torch.float32) + h = torch.arange(height, device=self.freqs.device, dtype=torch.float32) + w = torch.arange(width, device=self.freqs.device, dtype=torch.float32) + + # Center spatial dimensions (as in MAGI-1) + h = h - (height - 1) / 2 + w = w - (width - 1) / 2 + + # Create position grid + grid = torch.stack(torch.meshgrid(t, h, w, indexing="ij"), dim=-1) + grid = grid.reshape(-1, 3) # [seq_len, 3] + + # Get frequency bands + freqs = self.freqs.to(grid.device) + + # Compute embeddings for each dimension + # Temporal dimension + t_emb = torch.outer(grid[:, 0], freqs[:self.dim//6]) + t_sin = torch.sin(t_emb) + t_cos = torch.cos(t_emb) + + # Height dimension + h_emb = torch.outer(grid[:, 1], freqs[:self.dim//6]) + h_sin = torch.sin(h_emb) + h_cos = torch.cos(h_emb) + + # Width dimension + w_emb = torch.outer(grid[:, 2], freqs[:self.dim//6]) + w_sin = torch.sin(w_emb) + w_cos = torch.cos(w_emb) + + # Concatenate all embeddings + sin_emb = torch.cat([t_sin, h_sin, w_sin], dim=-1) + cos_emb = torch.cat([t_cos, h_cos, w_cos], dim=-1) + + # Pad or trim to match expected dimension + target_dim = self.dim // 2 + if sin_emb.shape[1] < target_dim: + pad_size = target_dim - sin_emb.shape[1] + sin_emb = F.pad(sin_emb, (0, pad_size)) + cos_emb = F.pad(cos_emb, (0, pad_size)) + elif sin_emb.shape[1] > target_dim: + sin_emb = sin_emb[:, :target_dim] + cos_emb = cos_emb[:, :target_dim] + + # Combine sin and cos for rotary embeddings + return torch.cat([cos_emb.unsqueeze(-1), sin_emb.unsqueeze(-1)], dim=-1).reshape(seq_len, target_dim, 2) + + +class MagiTransformer3DModel(ModelMixin, ConfigMixin): + """ + Transformer model for MAGI-1. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods implemented + for all models (downloading, saving, loading, etc.) 
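+
+    Input latents are patchified with a `Conv3d` (kernel and stride equal to `patch_size`) into tokens of
+    width `num_attention_heads * attention_head_dim`, combined with a learned timestep embedding, processed
+    by `num_layers` transformer blocks with 3D rotary position embeddings applied inside attention,
+    projected back to `out_channels` with a `Conv3d`, and added to the input through a residual connection.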
+ + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + num_layers (`int`, *optional*, defaults to 24): Number of transformer blocks. + num_attention_heads (`int`, *optional*, defaults to 16): Number of attention heads. + attention_head_dim (`int`, *optional*, defaults to 64): Dimension of attention heads. + cross_attention_dim (`int`, *optional*, defaults to 1280): Dimension of cross-attention conditioning. + activation_fn (`str`, *optional*, defaults to `"gelu"`): Activation function. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): Type of normalization. + norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon for normalization. + attention_bias (`bool`, *optional*, defaults to `False`): Whether to use bias in attention. + num_embeds_ada_norm (`int`, *optional*, defaults to `None`): Number of embeddings for AdaLayerNorm. + only_cross_attention (`bool`, *optional*, defaults to `False`): Whether to only use cross-attention. + upcast_attention (`bool`, *optional*, defaults to `False`): Whether to upcast attention operations. + dropout (`float`, *optional*, defaults to 0.0): Dropout probability. + """ + + @register_to_config + def __init__( + self, + sample_size: Optional[Union[int, Tuple[int, int]]] = None, + in_channels: int = 4, + out_channels: int = 4, + num_layers: int = 24, + num_attention_heads: int = 16, + attention_head_dim: int = 64, + cross_attention_dim: int = 1280, + activation_fn: str = "gelu", + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + norm_eps: float = 1e-5, + attention_bias: bool = False, + num_embeds_ada_norm: Optional[int] = None, + only_cross_attention: bool = False, + upcast_attention: bool = False, + dropout: float = 0.0, + patch_size: Tuple[int, int, int] = (1, 1, 1), + max_seq_len: int = 1024, + ): + super().__init__() + + self.sample_size = sample_size + self.patch_size = patch_size + self.max_seq_len = max_seq_len + + # Input embedding + self.in_channels = in_channels + time_embed_dim = attention_head_dim * num_attention_heads + self.time_proj = Timesteps(time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedding = TimestepEmbedding(time_embed_dim, time_embed_dim) + + # Input projection + self.input_proj = nn.Conv3d( + in_channels, + time_embed_dim, + kernel_size=patch_size, + stride=patch_size + ) + + # Rotary position embeddings + self.rotary_embedding = LearnableRotaryEmbedding( + dim=attention_head_dim, + max_seq_len=max_seq_len, + temperature=10000.0, + ) + + # Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + MagiTransformerBlock( + dim=time_embed_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_elementwise_affine=norm_elementwise_affine, + norm_type=norm_type, + norm_eps=norm_eps, + ) + for _ in range(num_layers) + ] + ) + + # Output projection + self.out_channels = out_channels + 
self.output_proj = nn.Conv3d( + time_embed_dim, + out_channels, + kernel_size=1 + ) + + self.gradient_checkpointing = False + + def set_attention_slice(self, slice_size): + """ + Enable sliced attention computation. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. + If `"max"`, maximum amount of memory is saved by running only one slice at a time. + If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + logger.warning( + "Calling `set_attention_slice` is deprecated and will be removed in a future version. Use" + " `set_attention_processor` instead." + ) + + # Not implemented for MAGI-1 yet + pass + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, MagiTransformerBlock): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + timesteps: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[MagiTransformerOutput, Tuple]: + """ + Forward pass of the model. + + Args: + hidden_states (`torch.Tensor`): + Input tensor of shape `(batch_size, in_channels, frames, height, width)`. + timesteps (`torch.Tensor`, *optional*): + Timesteps tensor of shape `(batch_size,)`. + encoder_hidden_states (`torch.Tensor`, *optional*): + Encoder hidden states for cross-attention. + attention_mask (`torch.Tensor`, *optional*): + Attention mask for cross-attention. + cross_attention_kwargs (`dict`, *optional*): + Additional arguments for cross-attention. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a dictionary. + + Returns: + `MagiTransformerOutput` or `tuple`: + If `return_dict` is `True`, a `MagiTransformerOutput` is returned, otherwise a tuple + `(sample,)` is returned where `sample` is the output tensor. + """ + # 1. Input processing + batch_size, channels, frames, height, width = hidden_states.shape + residual = hidden_states + + # 2. Time embedding + if timesteps is not None: + timesteps = timesteps.to(hidden_states.device) + time_embeds = self.time_proj(timesteps) + time_embeds = self.time_embedding(time_embeds) + else: + time_embeds = None + + # 3. Project input + hidden_states = self.input_proj(hidden_states) + + # Get patched dimensions + p_t, p_h, p_w = self.patch_size + patched_frames = frames // p_t + patched_height = height // p_h + patched_width = width // p_w + + # 4. Reshape for transformer blocks + hidden_states = hidden_states.permute(0, 2, 3, 4, 1) # [B, C, F, H, W] -> [B, F, H, W, C] + hidden_states = hidden_states.reshape(batch_size, patched_frames * patched_height * patched_width, -1) # [B, F*H*W, C] + + # 5. Add time embeddings if provided + if time_embeds is not None: + time_embeds = time_embeds.unsqueeze(1) # [B, 1, C] + hidden_states = hidden_states + time_embeds + + # 6. 
Generate rotary position embeddings + rotary_pos_emb = self.rotary_embedding.get_embed([patched_frames, patched_height, patched_width]) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + + # Convert to complex representation for the attention mechanism + # This matches MAGI-1's approach to applying rotary embeddings + cos_emb = rotary_pos_emb[..., 0] + sin_emb = rotary_pos_emb[..., 1] + rotary_pos_emb = torch.complex(cos_emb, sin_emb).unsqueeze(0) # [1, seq_len, dim//2] + + # 7. Process with transformer blocks + for block in self.transformer_blocks: + if self.gradient_checkpointing and self.training: + hidden_states = torch.utils.checkpoint.checkpoint( + block, + hidden_states, + encoder_hidden_states, + timesteps, + attention_mask, + None, # cross_attention_kwargs + rotary_pos_emb, + ) + else: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timesteps, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + rotary_pos_emb=rotary_pos_emb, + ) + + # 8. Reshape back to video format + hidden_states = hidden_states.reshape(batch_size, patched_frames, patched_height, patched_width, -1) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) # [B, F, H, W, C] -> [B, C, F, H, W] + + # 9. Project output + hidden_states = self.output_proj(hidden_states) + + # 10. Add residual connection + hidden_states = hidden_states + residual + + if not return_dict: + return (hidden_states,) + + return MagiTransformerOutput(sample=hidden_states) \ No newline at end of file diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 268e5c2a8c39..c7d9a2f29a1f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -28,6 +28,7 @@ "deprecated": [], "latent_diffusion": [], "ledits_pp": [], + "magi": [], "marigold": [], "pag": [], "stable_diffusion": [], @@ -283,6 +284,7 @@ "MarigoldNormalsPipeline", ] ) + _import_structure["magi"] = ["MagiPipeline", "MagiImageToVideoPipeline", "MagiVideoToVideoPipeline"] _import_structure["mochi"] = ["MochiPipeline"] _import_structure["musicldm"] = ["MusicLDMPipeline"] _import_structure["omnigen"] = ["OmniGenPipeline"] @@ -655,6 +657,7 @@ MarigoldIntrinsicsPipeline, MarigoldNormalsPipeline, ) + from .magi import MagiPipeline, MagiImageToVideoPipeline, MagiVideoToVideoPipeline from .mochi import MochiPipeline from .musicldm import MusicLDMPipeline from .omnigen import OmniGenPipeline diff --git a/src/diffusers/pipelines/magi/__init__.py b/src/diffusers/pipelines/magi/__init__.py new file mode 100644 index 000000000000..4fb6cc376987 --- /dev/null +++ b/src/diffusers/pipelines/magi/__init__.py @@ -0,0 +1,51 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING + +from ....utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, +) + + +_import_structure = {} + +if is_torch_available() and is_transformers_available(): + _import_structure["pipeline_magi"] = ["MagiPipeline"] + _import_structure["pipeline_magi_i2v"] = ["MagiImageToVideoPipeline"] + _import_structure["pipeline_magi_v2v"] = ["MagiVideoToVideoPipeline"] + _import_structure["pipeline_output"] = ["MagiPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if is_torch_available() and is_transformers_available(): + from .pipeline_magi import MagiPipeline + from .pipeline_magi_i2v import MagiImageToVideoPipeline + from .pipeline_magi_v2v import MagiVideoToVideoPipeline + from .pipeline_output import MagiPipelineOutput + except OptionalDependencyNotAvailable: + pass +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) \ No newline at end of file diff --git a/src/diffusers/pipelines/magi/pipeline_magi.py b/src/diffusers/pipelines/magi/pipeline_magi.py new file mode 100644 index 000000000000..92805fec11e0 --- /dev/null +++ b/src/diffusers/pipelines/magi/pipeline_magi.py @@ -0,0 +1,641 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import UMT5EncoderModel, AutoTokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AutoencoderKLMagi, MagiTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + is_torch_xla_available, + logging, + replace_example_docstring, + randn_tensor, +) +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import MagiPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import MagiPipeline + >>> from diffusers.utils import export_to_video + + >>> # Text-to-video generation + >>> pipeline = MagiPipeline.from_pretrained("sand-ai/MAGI-1-4.5B", torch_dtype=torch.float16) + >>> pipeline = pipeline.to("cuda") + >>> prompt = "A cat and a dog playing in a garden. The cat is chasing a butterfly while the dog is digging a hole." + >>> output = pipeline( + ... prompt=prompt, + ... num_frames=24, + ... height=720, + ... width=720, + ... t_schedule_func="sd3", + ... 
).frames[0] + >>> export_to_video(output, "magi_output.mp4", fps=8) + ``` +""" + + +class MagiPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using MAGI-1. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer (`AutoTokenizer`): + Tokenizer for the text encoder. + text_encoder (`UMT5EncoderModel`): + Text encoder for conditioning. + transformer (`MagiTransformer3DModel`): + Conditional Transformer to denoise the latent video. + vae (`AutoencoderKLMagi`): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + scheduler (`FlowMatchEulerDiscreteScheduler`): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: MagiTransformer3DModel, + vae: AutoencoderKLMagi, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = vae.temporal_downsample_factor + self.vae_scale_factor_spatial = vae.spatial_downsample_factor + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: torch.device, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, List[str]]] = None, + max_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): The prompt or prompts to guide the video generation. + device: The device to place the encoded prompt on. + num_videos_per_prompt (`int`): The number of videos that should be generated per prompt. + do_classifier_free_guidance (`bool`): Whether to use classifier-free guidance or not. + negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the video generation. + max_length (`int`, *optional*): The maximum length of the prompt to be encoded. + + Returns: + `torch.Tensor`: A tensor containing the encoded text embeddings. 
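+
+            When `do_classifier_free_guidance` is enabled, the negative and positive prompt embeddings are
+            concatenated along the batch dimension (negative prompts first), so the returned tensor has
+            twice the effective batch size.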
+ """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # Default to 77 if not specified + if max_length is None: + max_length = self.tokenizer.model_max_length + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids, attention_mask=attention_mask)[0] + + # Process special tokens if present (following MAGI-1's approach) + # In diffusers style, we don't need to explicitly handle special tokens as they're part of the tokenizer + # But we can ensure proper mask handling similar to MAGI-1 + seq_len = prompt_embeds.shape[1] + + # Duplicate text embeddings for each generation per prompt + prompt_embeds = prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + attention_mask = attention_mask.repeat_interleave(num_videos_per_prompt, dim=0) + + # Get unconditional embeddings for classifier-free guidance + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_input_ids = uncond_input.input_ids.to(device) + uncond_attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_embeds = self.text_encoder(uncond_input_ids, attention_mask=uncond_attention_mask)[0] + + # Duplicate unconditional embeddings for each generation per prompt + negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_videos_per_prompt, dim=0) + + # Ensure null embeddings have proper attention mask handling (similar to MAGI-1's null_emb_masks) + # In MAGI-1, they set attention to first 50 tokens and zero for the rest + if uncond_attention_mask.shape[1] > 50: + uncond_attention_mask[:, :50] = 1 + uncond_attention_mask[:, 50:] = 0 + + # Concatenate unconditional and text embeddings for classifier-free guidance + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + attention_mask = torch.cat([uncond_attention_mask, attention_mask]) + + return prompt_embeds + + def _prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + num_frames: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Prepare latents for diffusion. + + Args: + batch_size (`int`): The batch size. + num_channels_latents (`int`): The number of channels in the latent space. + num_frames (`int`): The number of frames to generate. + height (`int`): The height of the video. + width (`int`): The width of the video. + dtype (`torch.dtype`): The data type of the latents. + device (`torch.device`): The device to place the latents on. + generator (`torch.Generator`, *optional*): A generator to use for random number generation. + latents (`torch.Tensor`, *optional*): Pre-generated latent vectors. If not provided, latents will be generated randomly. + + Returns: + `torch.Tensor`: The prepared latent vectors. 
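+
+            For example, with the pipeline defaults (`height=720`, `width=720`, `num_frames=24`) and the
+            default VAE configuration (8x spatial, 4x temporal compression), the prepared latents have shape
+            `(batch_size, num_channels_latents, 6, 90, 90)`.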
+ """ + shape = ( + batch_size, + num_channels_latents, + num_frames // self.vae_scale_factor_temporal, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # Scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + return latents + + def _get_chunk_indices(self, num_frames: int, chunk_size: int = 24) -> List[Tuple[int, int]]: + """ + Get the indices for processing video in chunks. + + Args: + num_frames (`int`): Total number of frames. + chunk_size (`int`, *optional*, defaults to 24): Size of each chunk. + + Returns: + `List[Tuple[int, int]]`: List of (start_idx, end_idx) tuples for each chunk. + """ + chunks = [] + for i in range(0, num_frames, chunk_size): + chunks.append((i, min(i + chunk_size, num_frames))) + return chunks + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + """ + Validate the inputs for the pipeline. + + Args: + prompt (`str` or `List[str]`): The prompt or prompts to guide generation. + height (`int`): The height in pixels of the generated video. + width (`int`): The width in pixels of the generated video. + callback_steps (`int`): The frequency at which the callback function is called. + negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide generation. + prompt_embeds (`torch.Tensor`, *optional*): Pre-computed text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-computed negative text embeddings. + """ + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + f"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + def _prepare_timesteps( + self, + num_inference_steps: int, + device: torch.device, + t_schedule_func: str = "sd3", + shift: float = 3.0, + ) -> torch.Tensor: + """ + Prepare timesteps for diffusion process, with scheduling options similar to MAGI-1. + + Args: + num_inference_steps (`int`): Number of diffusion steps. + device (`torch.device`): Device to place timesteps on. + t_schedule_func (`str`, optional, defaults to "sd3"): + Timestep scheduling function. Options: "sd3", "square", "piecewise", "linear". + shift (`float`, optional, defaults to 3.0): Shift parameter for sd3 scheduler. + + Returns: + `torch.Tensor`: Prepared timesteps. + """ + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Apply custom scheduling similar to MAGI-1 if needed + if t_schedule_func == "sd3": + # Apply quadratic transformation + t = torch.linspace(0, 1, num_inference_steps + 1, device=device) + t = t ** 2 + + # Apply SD3-style transformation + def t_resolution_transform(x, shift_value=shift): + assert shift_value >= 1.0, "shift should >=1" + shift_inv = 1.0 / shift_value + return shift_inv * x / (1 + (shift_inv - 1) * x) + + t = t_resolution_transform(t, shift) + + # Map to scheduler timesteps + # Note: This is a simplified approach - in a full implementation, + # we would need to properly map these values to the scheduler's timesteps + return self.scheduler.timesteps + + elif t_schedule_func == "square": + # Simple quadratic scheduling + t = torch.linspace(0, 1, num_inference_steps + 1, device=device) + t = t ** 2 + return self.scheduler.timesteps + + elif t_schedule_func == "piecewise": + # Piecewise scheduling as in MAGI-1 + t = torch.linspace(0, 1, num_inference_steps + 1, device=device) + + # Apply piecewise transformation + mask = t < 0.875 + t_transformed = torch.zeros_like(t) + t_transformed[mask] = t[mask] * (0.5 / 0.875) + t_transformed[~mask] = 0.5 + (t[~mask] - 0.875) * (0.5 / (1 - 0.875)) + + # Map to scheduler timesteps (simplified) + return self.scheduler.timesteps + + # Default: use scheduler's default timesteps + return timesteps + + def denoise_latents( + self, + latents: torch.Tensor, + prompt_embeds: torch.Tensor, + timesteps: List[int], + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + guidance_scale: float = 7.5, + ) -> torch.Tensor: + """ + Denoise the latents using the transformer model. + + Args: + latents (`torch.Tensor`): The initial noisy latents. + prompt_embeds (`torch.Tensor`): The text embeddings for conditioning. + timesteps (`List[int]`): The timesteps for the diffusion process. + callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps. + callback_steps (`int`, *optional*, defaults to 1): The frequency at which the callback is called. + guidance_scale (`float`, *optional*, defaults to 7.5): The scale for classifier-free guidance. + + Returns: + `torch.Tensor`: The denoised latents. 
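+
+            When `guidance_scale > 1`, the latents are duplicated along the batch dimension at each step so
+            the transformer produces unconditional and text-conditioned predictions in a single pass; the two
+            are combined as `noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)`
+            before `scheduler.step` advances the sample.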
+ """ + do_classifier_free_guidance = guidance_scale > 1.0 + batch_size = latents.shape[0] // (2 if do_classifier_free_guidance else 1) + + for i, t in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Predict the noise residual + noise_pred = self.transformer( + latent_model_input, + timesteps=torch.tensor([t], device=latents.device), + encoder_hidden_states=prompt_embeds, + ).sample + + # Perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # Compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + # Call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = 720, + width: Optional[int] = 720, + num_frames: Optional[int] = 24, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + chunk_size: int = 24, + t_schedule_func: str = "sd3", + t_schedule_shift: float = 3.0, + ) -> Union[MagiPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the video generation. + height (`int`, *optional*, defaults to 720): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to 720): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 24): + The number of video frames to generate. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, usually at the expense of lower video quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. 
+ eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generate video. Choose between `np` for `numpy.array`, `pt` for `torch.Tensor` or `latent` to get the latent space output. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.magi.MagiPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [`diffusers.cross_attention`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + chunk_size (`int`, *optional*, defaults to 24): + The chunk size to use for autoregressive generation. Measured in frames. + t_schedule_func (`str`, *optional*, defaults to "sd3"): + Timestep scheduling function. Options: "sd3", "square", "piecewise", "linear". + t_schedule_shift (`float`, *optional*, defaults to 3.0): + Shift parameter for sd3 scheduler. + + Examples: + + Returns: + [`~pipelines.magi.MagiPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.magi.MagiPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.transformer.config.sample_size + width = width or self.transformer.config.sample_size + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + negative_prompt, + ) + + # 4. Prepare timesteps + timesteps = self._prepare_timesteps( + num_inference_steps, + device, + t_schedule_func=t_schedule_func, + shift=t_schedule_shift + ) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.in_channels + + # Regular text-to-video case + latents = self._prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Process in chunks for autoregressive generation + chunk_indices = self._get_chunk_indices(num_frames // self.vae_scale_factor_temporal, chunk_size // self.vae_scale_factor_temporal) + all_latents = [] + + # 7. Denoise the latents + for chunk_idx, (start_idx, end_idx) in enumerate(chunk_indices): + # Extract the current chunk + chunk_frames = end_idx - start_idx + if chunk_idx == 0: + # For the first chunk, use the initial latents + chunk_latents = latents[:, :, start_idx:end_idx, :, :] + else: + # For subsequent chunks, implement proper autoregressive conditioning + # In MAGI-1, each chunk conditions the next in an autoregressive manner + # We use the previous chunk's output as conditioning for the current chunk + prev_chunk_end = chunk_indices[chunk_idx - 1][1] + overlap_start = max(0, start_idx - self.vae_scale_factor_temporal) # Add overlap for conditioning + + # Initialize with noise + chunk_latents = randn_tensor( + (batch_size * num_videos_per_prompt, num_channels_latents, chunk_frames, + height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial), + generator=generator, + device=device, + dtype=prompt_embeds.dtype, + ) + chunk_latents = chunk_latents * self.scheduler.init_noise_sigma + + # Use previous chunk output as conditioning by copying overlapping frames + if start_idx > 0 and chunk_idx > 0: + overlap_frames = min(self.vae_scale_factor_temporal, start_idx) + if overlap_frames > 0: + # Copy overlapping frames from previous chunk's output + chunk_latents[:, :, :overlap_frames, :, :] = all_latents[chunk_idx - 1][:, :, -overlap_frames:, :, :] + + # Denoise this chunk + chunk_latents = self.denoise_latents( + chunk_latents, + prompt_embeds, + timesteps, + callback=callback if chunk_idx == 0 else None, # Only use callback for first chunk + callback_steps=callback_steps, + guidance_scale=guidance_scale, + ) + + all_latents.append(chunk_latents) + + # 8. Concatenate all chunks + latents = torch.cat(all_latents, dim=2) + + # 9. 
Post-processing + if output_type == "latent": + video = latents + else: + # Decode the latents + latents = 1 / self.vae.scaling_factor * latents + video = self.vae.decode(latents).sample + video = (video / 2 + 0.5).clamp(0, 1) + + # Convert to the desired output format + if output_type == "pt": + video = video + else: + video = video.cpu().permute(0, 2, 3, 4, 1).float().numpy() + + # 10. Return output + if not return_dict: + return (video,) + + return MagiPipelineOutput(frames=video) \ No newline at end of file diff --git a/src/diffusers/pipelines/magi/pipeline_magi_i2v.py b/src/diffusers/pipelines/magi/pipeline_magi_i2v.py new file mode 100644 index 000000000000..662a3940987e --- /dev/null +++ b/src/diffusers/pipelines/magi/pipeline_magi_i2v.py @@ -0,0 +1,546 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from transformers import UMT5EncoderModel, AutoTokenizer + +from ...models import AutoencoderKLMagi, MagiTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + logging, + replace_example_docstring, + randn_tensor, +) +from ...image_processor import PipelineImageInput +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import MagiPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import MagiImageToVideoPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> pipeline = MagiImageToVideoPipeline.from_pretrained("sand-ai/MAGI-1-4.5B", torch_dtype=torch.float16) + >>> pipeline = pipeline.to("cuda") + + >>> image = load_image("path/to/image.jpg") + >>> prompt = "A cat playing in a garden. The cat is chasing a butterfly." + >>> output = pipeline(prompt=prompt, image=image, num_frames=24, height=720, width=720).frames[0] + >>> export_to_video(output, "magi_i2v_output.mp4", fps=8) + ``` +""" + + +class MagiImageToVideoPipeline(DiffusionPipeline): + r""" + Pipeline for image-to-video generation using MAGI-1. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer (`AutoTokenizer`): + Tokenizer for the text encoder. + text_encoder (`UMT5EncoderModel`): + Text encoder for conditioning. + transformer (`MagiTransformer3DModel`): + Conditional Transformer to denoise the latent video. + vae (`AutoencoderKLMagi`): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + scheduler (`FlowMatchEulerDiscreteScheduler`): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. 
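+
+    Generation proceeds chunk-by-chunk: the first chunk starts from latents whose first latent frame is
+    replaced by the encoded input image, and each subsequent chunk is denoised separately and concatenated
+    along the temporal dimension to form the final video.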
+ """ + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: MagiTransformer3DModel, + vae: AutoencoderKLMagi, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + vae=vae, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = 2 ** (1 if hasattr(self.vae, "temporal_downsample") else 0) + self.vae_scale_factor_spatial = 2 ** (3 if hasattr(self.vae, "config") else 8) # Default to 8 for 3 downsamples + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: torch.device, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, List[str]]] = None, + max_sequence_length: int = 512, + ) -> torch.Tensor: + """ + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + device (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of videos that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + max_sequence_length (`int`, *optional*, defaults to 512): + The maximum length of the sequence to be processed by the text encoder. + + Returns: + `torch.Tensor`: text embeddings + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + + prompt_embeds = self.text_encoder(text_input_ids).last_hidden_state + + # duplicate text embeddings for each generation per prompt + prompt_embeds = prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + + # get unconditional embeddings for classifier-free guidance + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + max_length = text_inputs.input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_input_ids = uncond_input.input_ids.to(device) + negative_prompt_embeds = self.text_encoder(uncond_input_ids).last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt + negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + + # For classifier-free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + f"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_image_latents( + self, + image: Union[torch.Tensor, PIL.Image.Image], + batch_size: int, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + device: torch.device, + ) -> torch.Tensor: + """ + Encode an input image to latent space. + + Args: + image (`torch.Tensor` or `PIL.Image.Image`): + Input image to be encoded. + batch_size (`int`): + Batch size. + num_videos_per_prompt (`int`): + Number of videos per prompt. + do_classifier_free_guidance (`bool`): + Whether to use classifier-free guidance. + device (`torch.device`): + Device to place the latents on. + + Returns: + `torch.Tensor`: Encoded image latents. 
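+
+        Example:
+            For a single input image with `batch_size=2`, `num_videos_per_prompt=1`, and classifier-free
+            guidance enabled, the encoded latent is repeated to 2 copies and then doubled to 4 along the
+            batch dimension (unconditional + conditional).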
+ """ + # Convert PIL image to tensor + if isinstance(image, PIL.Image.Image): + image = self.video_processor.preprocess_image(image) + image = image.to(device=device) + + # Encode image + image_latents = self.vae.encode(image).latent_dist.sample() + image_latents = image_latents * self.vae.scaling_factor + + # Expand for batch size and classifier-free guidance + image_latents = image_latents.repeat(batch_size * num_videos_per_prompt, 1, 1, 1) + if do_classifier_free_guidance: + image_latents = torch.cat([image_latents, image_latents], dim=0) + + return image_latents + + def _prepare_image_based_latents( + self, + batch_size: int, + num_channels_latents: int, + num_frames: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + image_latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Prepare latents for diffusion with image conditioning. + + Args: + batch_size (`int`): The batch size. + num_channels_latents (`int`): The number of channels in the latent space. + num_frames (`int`): The number of frames to generate. + height (`int`): The height of the video. + width (`int`): The width of the video. + dtype (`torch.dtype`): The data type of the latents. + device (`torch.device`): The device to place the latents on. + generator (`torch.Generator`, *optional*): A generator to use for random number generation. + latents (`torch.Tensor`, *optional*): Pre-generated latent vectors. If not provided, latents will be generated randomly. + image_latents (`torch.Tensor`, *optional*): Image latents for conditioning. + + Returns: + `torch.Tensor`: The prepared latent vectors. + """ + shape = ( + batch_size, + num_channels_latents, + num_frames // self.vae_scale_factor_temporal, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # Scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # If we have image latents, use them to condition the first frame + if image_latents is not None: + # Expand image latents to match the temporal dimension of the first frame + image_latents = image_latents.unsqueeze(2) # [B, C, 1, H, W] + + # Only replace the first frame with the image latents + latents[:, :, 0, :, :] = image_latents.squeeze(2) + + return latents + + def _get_chunk_indices(self, num_frames: int, chunk_size: int) -> List[Tuple[int, int]]: + """ + Get the start and end indices for each chunk. + + Args: + num_frames (`int`): Total number of frames. + chunk_size (`int`): Size of each chunk. + + Returns: + `List[Tuple[int, int]]`: List of (start_idx, end_idx) tuples for each chunk. 
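+
+        Example:
+            `_get_chunk_indices(num_frames=7, chunk_size=3)` returns `[(0, 3), (3, 6), (6, 7)]`; the trailing
+            partial chunk is kept rather than dropped.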
+ """ + return [(i, min(i + chunk_size, num_frames)) for i in range(0, num_frames, chunk_size)] + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.Tensor, PIL.Image.Image], + height: Optional[int] = 720, + width: Optional[int] = 720, + num_frames: Optional[int] = 24, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + chunk_size: int = 24, + ) -> Union[MagiPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the video generation. + image (`torch.Tensor` or `PIL.Image.Image`): + The input image to guide the video generation. + height (`int`, *optional*, defaults to 720): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to 720): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 24): + The number of video frames to generate. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, usually at the expense of lower video quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generate video. Choose between `np` for `numpy.array`, `pt` for `torch.Tensor` or `latent` to get the latent space output. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.magi.MagiPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [`diffusers.cross_attention`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + chunk_size (`int`, *optional*, defaults to 24): + The chunk size to use for autoregressive generation. Measured in frames. + + Examples: + + Returns: + [`~pipelines.magi.MagiPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.magi.MagiPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.transformer.config.sample_size + width = width or self.transformer.config.sample_size + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + negative_prompt, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare image latents + image_latents = self.prepare_image_latents( + image, + batch_size, + num_videos_per_prompt, + do_classifier_free_guidance, + device, + ) + + # 6. Prepare latent variables with image conditioning + num_channels_latents = self.transformer.in_channels + latents = self._prepare_image_based_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image_latents, + ) + + # 7. 
Process in chunks for autoregressive generation + chunk_indices = self._get_chunk_indices(num_frames // self.vae_scale_factor_temporal, chunk_size // self.vae_scale_factor_temporal) + all_latents = [] + + # 8. Process each chunk + for chunk_idx, (start_idx, end_idx) in enumerate(chunk_indices): + # Extract the current chunk + chunk_frames = end_idx - start_idx + if chunk_idx == 0: + # For the first chunk, use the initial latents + chunk_latents = latents[:, :, start_idx:end_idx, :, :] + else: + # For subsequent chunks, use the previous chunk as conditioning + # This is a simplified version - in a real implementation, we would need to handle + # the autoregressive conditioning properly + chunk_latents = randn_tensor( + (batch_size * num_videos_per_prompt, num_channels_latents, chunk_frames, + height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial), + generator=generator, + device=device, + dtype=prompt_embeds.dtype, + ) + chunk_latents = chunk_latents * self.scheduler.init_noise_sigma + + # 9. Denoising loop for this chunk + with self.progress_bar(total=len(timesteps)) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([chunk_latents] * 2) if do_classifier_free_guidance else chunk_latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + chunk_latents = self.scheduler.step(noise_pred, t, chunk_latents).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, chunk_latents) + + progress_bar.update() + + all_latents.append(chunk_latents) + + # 10. Concatenate all chunks + latents = torch.cat(all_latents, dim=2) + + # 11. Post-processing + if output_type == "latent": + video = latents + else: + # Decode the latents + latents = 1 / self.vae.scaling_factor * latents + video = self.vae.decode(latents).sample + video = (video / 2 + 0.5).clamp(0, 1) + + # Convert to the desired output format + if output_type == "pt": + video = video + else: + video = video.cpu().permute(0, 2, 3, 4, 1).float().numpy() + + # 12. Return output + if not return_dict: + return (video,) + + return MagiPipelineOutput(frames=video) \ No newline at end of file diff --git a/src/diffusers/pipelines/magi/pipeline_magi_v2v.py b/src/diffusers/pipelines/magi/pipeline_magi_v2v.py new file mode 100644 index 000000000000..d0ee5d67dc94 --- /dev/null +++ b/src/diffusers/pipelines/magi/pipeline_magi_v2v.py @@ -0,0 +1,552 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import UMT5EncoderModel, AutoTokenizer + +from ...models import AutoencoderKLMagi, MagiTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + logging, + replace_example_docstring, + randn_tensor, +) +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import MagiPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import MagiVideoToVideoPipeline + >>> from diffusers.utils import export_to_video, load_video + + >>> pipeline = MagiVideoToVideoPipeline.from_pretrained("sand-ai/MAGI-1-4.5B", torch_dtype=torch.float16) + >>> pipeline = pipeline.to("cuda") + + >>> input_video = load_video("path/to/video.mp4") + >>> prompt = "A cat playing in a garden. The cat is chasing a butterfly." + >>> output = pipeline(prompt=prompt, video=input_video, num_frames=24, height=720, width=720).frames[0] + >>> export_to_video(output, "magi_v2v_output.mp4", fps=8) + ``` +""" + + +class MagiVideoToVideoPipeline(DiffusionPipeline): + r""" + Pipeline for video-to-video generation using MAGI-1. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer (`AutoTokenizer`): + Tokenizer for the text encoder. + text_encoder (`UMT5EncoderModel`): + Text encoder for conditioning. + transformer (`MagiTransformer3DModel`): + Conditional Transformer to denoise the latent video. + vae (`AutoencoderKLMagi`): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + scheduler (`FlowMatchEulerDiscreteScheduler`): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: MagiTransformer3DModel, + vae: AutoencoderKLMagi, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + vae=vae, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = 2 ** (1 if hasattr(self.vae, "temporal_downsample") else 0) + self.vae_scale_factor_spatial = 2 ** (3 if hasattr(self.vae, "config") else 8) # Default to 8 for 3 downsamples + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: torch.device, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, List[str]]] = None, + max_sequence_length: int = 512, + ) -> torch.Tensor: + """ + Encodes the prompt into text encoder hidden states. 
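+
+        When classifier-free guidance is enabled, the unconditional (negative) embeddings and the prompt
+        embeddings are concatenated along the batch dimension so that both branches can be evaluated in a
+        single transformer forward pass.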
+ + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + device (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of videos that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + max_sequence_length (`int`, *optional*, defaults to 512): + The maximum length of the sequence to be processed by the text encoder. + + Returns: + `torch.Tensor`: text embeddings + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + + prompt_embeds = self.text_encoder(text_input_ids).last_hidden_state + + # duplicate text embeddings for each generation per prompt + prompt_embeds = prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + + # get unconditional embeddings for classifier-free guidance + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + max_length = text_inputs.input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_input_ids = uncond_input.input_ids.to(device) + negative_prompt_embeds = self.text_encoder(uncond_input_ids).last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt + negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + + # For classifier-free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + video=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + f"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if video is None: + raise ValueError("`video` input cannot be undefined.") + + def prepare_video_latents( + self, + video: torch.Tensor, + batch_size: int, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + device: torch.device, + ) -> torch.Tensor: + """ + Encode an input video to latent space. + + Args: + video (`torch.Tensor`): + Input video to be encoded. + batch_size (`int`): + Batch size. + num_videos_per_prompt (`int`): + Number of videos per prompt. + do_classifier_free_guidance (`bool`): + Whether to use classifier-free guidance. + device (`torch.device`): + Device to place the latents on. + + Returns: + `torch.Tensor`: Encoded video latents. + """ + # Ensure video is on the correct device + video = video.to(device=device) + + # Encode video + video_latents = self.vae.encode(video).latent_dist.sample() + video_latents = video_latents * self.vae.scaling_factor + + # Expand for batch size and classifier-free guidance + video_latents = video_latents.repeat(batch_size * num_videos_per_prompt, 1, 1, 1, 1) + if do_classifier_free_guidance: + video_latents = torch.cat([video_latents, video_latents], dim=0) + + return video_latents + + def _prepare_video_based_latents( + self, + batch_size: int, + num_channels_latents: int, + num_frames: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + video_latents: Optional[torch.Tensor] = None, + num_frames_to_condition: Optional[int] = None, + ) -> torch.Tensor: + """ + Prepare latents for diffusion with video conditioning. + + Args: + batch_size (`int`): The batch size. + num_channels_latents (`int`): The number of channels in the latent space. + num_frames (`int`): The number of frames to generate. + height (`int`): The height of the video. + width (`int`): The width of the video. + dtype (`torch.dtype`): The data type of the latents. + device (`torch.device`): The device to place the latents on. + generator (`torch.Generator`, *optional*): A generator to use for random number generation. + latents (`torch.Tensor`, *optional*): Pre-generated latent vectors. If not provided, latents will be generated randomly. + video_latents (`torch.Tensor`, *optional*): Video latents for conditioning. + num_frames_to_condition (`int`, *optional*): Number of frames to use for conditioning. + + Returns: + `torch.Tensor`: The prepared latent vectors. 
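+
+        Example:
+            With `num_frames_to_condition=2`, the first 2 latent frames of the returned tensor are copied from
+            `video_latents`, while the remaining frames keep their scaled Gaussian noise.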
+ """ + shape = ( + batch_size, + num_channels_latents, + num_frames // self.vae_scale_factor_temporal, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # Scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # If we have video latents, use them to condition the first frames + if video_latents is not None: + if num_frames_to_condition is None: + num_frames_to_condition = video_latents.shape[2] + + # Only replace the first N frames with the video latents + latents[:, :, :num_frames_to_condition, :, :] = video_latents[:, :, :num_frames_to_condition, :, :] + + return latents + + def _get_chunk_indices(self, num_frames: int, chunk_size: int) -> List[Tuple[int, int]]: + """ + Get the start and end indices for each chunk. + + Args: + num_frames (`int`): Total number of frames. + chunk_size (`int`): Size of each chunk. + + Returns: + `List[Tuple[int, int]]`: List of (start_idx, end_idx) tuples for each chunk. + """ + return [(i, min(i + chunk_size, num_frames)) for i in range(0, num_frames, chunk_size)] + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + video: torch.Tensor, + height: Optional[int] = 720, + width: Optional[int] = 720, + num_frames: Optional[int] = 24, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + chunk_size: int = 24, + num_frames_to_condition: Optional[int] = None, + ) -> Union[MagiPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the video generation. + video (`torch.Tensor`): + The input video to guide the video generation. Should be a tensor of shape (B, C, F, H, W). + height (`int`, *optional*, defaults to 720): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to 720): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 24): + The number of video frames to generate. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. 
Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, usually at the expense of lower video quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generate video. Choose between `np` for `numpy.array`, `pt` for `torch.Tensor` or `latent` to get the latent space output. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.magi.MagiPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [`diffusers.cross_attention`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + chunk_size (`int`, *optional*, defaults to 24): + The chunk size to use for autoregressive generation. Measured in frames. + num_frames_to_condition (`int`, *optional*): + Number of frames from the input video to use for conditioning. If not provided, all frames from the input video will be used. + + Examples: + + Returns: + [`~pipelines.magi.MagiPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.magi.MagiPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.transformer.config.sample_size + width = width or self.transformer.config.sample_size + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, video + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + negative_prompt, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare video latents + video_latents = self.prepare_video_latents( + video, + batch_size, + num_videos_per_prompt, + do_classifier_free_guidance, + device, + ) + + # 6. Prepare latent variables with video conditioning + num_channels_latents = self.transformer.in_channels + latents = self._prepare_video_based_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + video_latents, + num_frames_to_condition, + ) + + # 7. Process in chunks for autoregressive generation + chunk_indices = self._get_chunk_indices(num_frames // self.vae_scale_factor_temporal, chunk_size // self.vae_scale_factor_temporal) + all_latents = [] + + # 8. Process each chunk + for chunk_idx, (start_idx, end_idx) in enumerate(chunk_indices): + # Extract the current chunk + chunk_frames = end_idx - start_idx + if chunk_idx == 0: + # For the first chunk, use the initial latents + chunk_latents = latents[:, :, start_idx:end_idx, :, :] + else: + # For subsequent chunks, use the previous chunk as conditioning + # This is a simplified version - in a real implementation, we would need to handle + # the autoregressive conditioning properly + chunk_latents = randn_tensor( + (batch_size * num_videos_per_prompt, num_channels_latents, chunk_frames, + height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial), + generator=generator, + device=device, + dtype=prompt_embeds.dtype, + ) + chunk_latents = chunk_latents * self.scheduler.init_noise_sigma + + # 9. 
Denoising loop for this chunk + with self.progress_bar(total=len(timesteps)) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([chunk_latents] * 2) if do_classifier_free_guidance else chunk_latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + chunk_latents = self.scheduler.step(noise_pred, t, chunk_latents).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, chunk_latents) + + progress_bar.update() + + all_latents.append(chunk_latents) + + # 10. Concatenate all chunks + latents = torch.cat(all_latents, dim=2) + + # 11. Post-processing + if output_type == "latent": + video = latents + else: + # Decode the latents + latents = 1 / self.vae.scaling_factor * latents + video = self.vae.decode(latents).sample + video = (video / 2 + 0.5).clamp(0, 1) + + # Convert to the desired output format + if output_type == "pt": + video = video + else: + video = video.cpu().permute(0, 2, 3, 4, 1).float().numpy() + + # 12. Return output + if not return_dict: + return (video,) + + return MagiPipelineOutput(frames=video) \ No newline at end of file diff --git a/src/diffusers/pipelines/magi/pipeline_output.py b/src/diffusers/pipelines/magi/pipeline_output.py new file mode 100644 index 000000000000..60b2b59e719a --- /dev/null +++ b/src/diffusers/pipelines/magi/pipeline_output.py @@ -0,0 +1,34 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import torch + +from ...utils import BaseOutput + + +@dataclass +class MagiPipelineOutput(BaseOutput): + """ + Output class for MAGI-1 pipeline. + + Args: + frames (`torch.Tensor` or `np.ndarray`): + List of denoised frames from the diffusion process, as a NumPy array of shape `(batch_size, num_frames, height, width, num_channels)` or a PyTorch tensor of shape `(batch_size, num_channels, num_frames, height, width)`. + """ + + frames: Union[torch.Tensor, np.ndarray, List[List[np.ndarray]]] \ No newline at end of file diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_magi.py b/tests/models/autoencoders/test_models_autoencoder_kl_magi.py new file mode 100644 index 000000000000..0593209c5427 --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_kl_magi.py @@ -0,0 +1,155 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import AutoencoderKLMagi +from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device + +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin + + +enable_full_determinism() + + +class AutoencoderKLMagiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): + model_class = AutoencoderKLMagi + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_magi_config(self): + return { + "base_dim": 3, + "z_dim": 16, + "dim_mult": [1, 1, 1, 1], + "num_res_blocks": 1, + "temperal_downsample": [False, True, True], + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (16, 16) + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + return {"sample": image} + + @property + def dummy_input_tiling(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (128, 128) + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + return {"sample": image} + + @property + def input_shape(self): + return (3, 9, 16, 16) + + @property + def output_shape(self): + return (3, 9, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_magi_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def prepare_init_args_and_inputs_for_tiling(self): + init_dict = self.get_autoencoder_kl_magi_config() + inputs_dict = self.dummy_input_tiling + return init_dict, inputs_dict + + def test_enable_disable_tiling(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_tiling() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + + torch.manual_seed(0) + output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + torch.manual_seed(0) + model.enable_tiling(96, 96, 64, 64) + output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertLess( + (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(), + 0.5, + "VAE tiling should not affect the inference results", + ) + + torch.manual_seed(0) + model.disable_tiling() + output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertEqual( + output_without_tiling.detach().cpu().numpy().all(), + output_without_tiling_2.detach().cpu().numpy().all(), + "Without tiling outputs should match with the outputs when tiling is manually disabled.", + ) + + def test_enable_disable_slicing(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + + torch.manual_seed(0) + output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + 
torch.manual_seed(0) + model.enable_slicing() + output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertLess( + (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(), + 0.05, + "VAE slicing should not affect the inference results", + ) + + torch.manual_seed(0) + model.disable_slicing() + output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertEqual( + output_without_slicing.detach().cpu().numpy().all(), + output_without_slicing_2.detach().cpu().numpy().all(), + "Without slicing outputs should match with the outputs when slicing is manually disabled.", + ) + + @unittest.skip("Gradient checkpointing has not been implemented yet") + def test_gradient_checkpointing_is_applied(self): + pass + + @unittest.skip("Test not supported") + def test_forward_with_norm_groups(self): + pass + + @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'") + def test_layerwise_casting_inference(self): + pass + + @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'") + def test_layerwise_casting_training(self): + pass \ No newline at end of file diff --git a/tests/models/transformers/test_models_transformer_magi.py b/tests/models/transformers/test_models_transformer_magi.py new file mode 100644 index 000000000000..bbe677b9b1d1 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_magi.py @@ -0,0 +1,91 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import MagiTransformer3DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin + + +enable_full_determinism() + + +class MagiTransformer3DTests(ModelTesterMixin, unittest.TestCase): + model_class = MagiTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 4 + num_frames = 2 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + } + + @property + def input_shape(self): + return (4, 1, 16, 16) + + @property + def output_shape(self): + return (4, 1, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 12, + "in_channels": 4, + "out_channels": 4, + "text_dim": 16, + "freq_dim": 256, + "ffn_dim": 32, + "num_layers": 2, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "rope_max_seq_len": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"MagiTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class MagiTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = MagiTransformer3DModel + + def prepare_init_args_and_inputs_for_common(self): + return MagiTransformer3DTests().prepare_init_args_and_inputs_for_common() \ No newline at end of file diff --git a/tests/pipelines/magi/test_magi.py b/tests/pipelines/magi/test_magi.py new file mode 100644 index 000000000000..af09350806fe --- /dev/null +++ b/tests/pipelines/magi/test_magi.py @@ -0,0 +1,158 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLMagi, FlowMatchEulerDiscreteScheduler, MagiPipeline, MagiTransformer3DModel +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + slow, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import ( + PipelineTesterMixin, +) + + +enable_full_determinism() + + +class MagiPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = MagiPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLMagi( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + # TODO: impl FlowDPMSolverMultistepScheduler + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = MagiTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "negative", # TODO + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + +@slow +@require_torch_accelerator +class MagiPipelineIntegrationTests(unittest.TestCase): + prompt = "A painting of a squirrel eating a burger." 
+ + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + @unittest.skip("TODO: test needs to be implemented") + def test_Magi(self): + pass \ No newline at end of file diff --git a/tests/pipelines/magi/test_magi_image_to_video.py b/tests/pipelines/magi/test_magi_image_to_video.py new file mode 100644 index 000000000000..70780c9c5cb4 --- /dev/null +++ b/tests/pipelines/magi/test_magi_image_to_video.py @@ -0,0 +1,215 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import PIL +import torch +from transformers import AutoTokenizer, CLIPVisionModel, T5EncoderModel + +from diffusers import AutoencoderKLMagi, FlowMatchEulerDiscreteScheduler, MagiImageToVideoPipeline, MagiTransformer3DModel +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + load_image, + load_numpy, + nightly, + require_torch_accelerator, + torch_device, +) + +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_PARAMS, +) +from ..test_pipelines_common import ( + PipelineTesterMixin, +) + + +enable_full_determinism() + + +class MagiImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = MagiImageToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLMagi( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + # TODO: impl FlowDPMSolverMultistepScheduler + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + image_encoder = CLIPVisionModel.from_pretrained("hf-internal-testing/tiny-random-clip") + + torch.manual_seed(0) + transformer = MagiTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "image_encoder": image_encoder, + } + return components + 
+ def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "negative", # TODO + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + "image": PIL.Image.new("RGB", (16, 16)), + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + +class MagiFLFToVideoPipelineFastTests(MagiImageToVideoPipelineFastTests): + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLMagi( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + # TODO: impl FlowDPMSolverMultistepScheduler + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + image_encoder = CLIPVisionModel.from_pretrained("hf-internal-testing/tiny-random-clip") + + torch.manual_seed(0) + transformer = MagiTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "image_encoder": image_encoder, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "negative", # TODO + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + "image": PIL.Image.new("RGB", (16, 16)), + "last_image": PIL.Image.new("RGB", (16, 16)), + } + return inputs \ No newline at end of file diff --git a/tests/pipelines/magi/test_magi_video_to_video.py b/tests/pipelines/magi/test_magi_video_to_video.py new file mode 100644 index 000000000000..ae11830997b2 --- /dev/null +++ b/tests/pipelines/magi/test_magi_video_to_video.py @@ -0,0 +1,148 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLMagi, UniPCMultistepScheduler, MagiTransformer3DModel, MagiVideoToVideoPipeline +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + slow, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import ( + PipelineTesterMixin, +) + + +enable_full_determinism() + + +class MagiVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = MagiVideoToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLMagi( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = UniPCMultistepScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 + ) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = MagiTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "negative", # TODO + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + "video": torch.randn((1, 3, 9, 16, 16)), + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) + max_diff 
= np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + @unittest.skip("MagiVideoToVideoPipeline has to run in mixed precision. Casting the entire pipeline will result in errors") + def test_model_cpu_offload_forward_pass(self): + pass + + @unittest.skip("MagiVideoToVideoPipeline has to run in mixed precision. Save/Load the entire pipeline in FP16 will result in errors") + def test_save_load_float16(self): + pass \ No newline at end of file diff --git a/tests/single_file/test_model_magi_autoencoder_single_file.py b/tests/single_file/test_model_magi_autoencoder_single_file.py new file mode 100644 index 000000000000..732ab1c223d7 --- /dev/null +++ b/tests/single_file/test_model_magi_autoencoder_single_file.py @@ -0,0 +1,64 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from huggingface_hub import hf_hub_download + +from diffusers import AutoencoderKLMagi +from diffusers.utils.testing_utils import ( + require_torch_gpu, + slow, + torch_device, +) + + +class AutoencoderKLMagiSingleFileTests(unittest.TestCase): + model_class = AutoencoderKLMagi + ckpt_path = ( + "https://huggingface.co/sand-ai/MAGI-1/blob/main/vae/diffusion_pytorch_model.safetensors" + ) + repo_id = "sand-ai/MAGI-1" + + @slow + @require_torch_gpu + def test_single_file_components(self): + model = self.model_class.from_single_file(self.ckpt_path) + model.to(torch_device) + + batch_size = 1 + num_frames = 2 + num_channels = 3 + sizes = (16, 16) + image = torch.randn((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + with torch.no_grad(): + model(image, return_dict=False) + + @slow + @require_torch_gpu + def test_single_file_components_from_hub(self): + model = self.model_class.from_pretrained(self.repo_id, subfolder="vae") + model.to(torch_device) + + batch_size = 1 + num_frames = 2 + num_channels = 3 + sizes = (16, 16) + image = torch.randn((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + with torch.no_grad(): + model(image, return_dict=False) \ No newline at end of file diff --git a/tests/single_file/test_model_magi_transformer3d_single_file.py b/tests/single_file/test_model_magi_transformer3d_single_file.py new file mode 100644 index 000000000000..4c7a026b899b --- /dev/null +++ b/tests/single_file/test_model_magi_transformer3d_single_file.py @@ -0,0 +1,84 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from huggingface_hub import hf_hub_download + +from diffusers import MagiTransformer3DModel +from diffusers.utils.testing_utils import ( + require_torch_gpu, + slow, + torch_device, +) + + +class MagiTransformer3DModelText2VideoSingleFileTest(unittest.TestCase): + model_class = MagiTransformer3DModel + ckpt_path = "https://huggingface.co/sand-ai/MAGI-1/blob/main/transformer/diffusion_pytorch_model.safetensors" + repo_id = "sand-ai/MAGI-1" + + @slow + @require_torch_gpu + def test_single_file_components(self): + model = self.model_class.from_single_file(self.ckpt_path) + model.to(torch_device) + + batch_size = 1 + num_channels = 4 + num_frames = 2 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + + with torch.no_grad(): + model( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + return_dict=False, + ) + + @slow + @require_torch_gpu + def test_single_file_components_from_hub(self): + model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer") + model.to(torch_device) + + batch_size = 1 + num_channels = 4 + num_frames = 2 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + + with torch.no_grad(): + model( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + return_dict=False, + ) \ No newline at end of file From 89806ea7805926542a6fc3b3abb7152f9b2b0766 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 13 Jun 2025 22:10:32 +0300 Subject: [PATCH 2/8] style --- scripts/convert_magi_to_diffusers.py | 28 +++++++++---------- .../autoencoders/autoencoder_kl_magi.py | 5 ++-- .../models/transformers/transformer_magi.py | 3 +- src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/magi/__init__.py | 2 +- src/diffusers/pipelines/magi/pipeline_magi.py | 11 ++------ .../pipelines/magi/pipeline_magi_i2v.py | 9 ++---- .../pipelines/magi/pipeline_magi_v2v.py | 8 ++---- .../pipelines/magi/pipeline_output.py | 4 +-- .../test_models_autoencoder_kl_magi.py | 2 +- .../test_models_transformer_magi.py | 2 +- tests/pipelines/magi/test_magi.py | 2 +- .../magi/test_magi_image_to_video.py | 17 +++++------ .../magi/test_magi_video_to_video.py | 9 ++---- ...test_model_magi_autoencoder_single_file.py | 3 +- ...st_model_magi_transformer3d_single_file.py | 3 +- 16 files changed, 43 insertions(+), 67 deletions(-) diff --git a/scripts/convert_magi_to_diffusers.py b/scripts/convert_magi_to_diffusers.py index 38c50271f632..3a616a2a2ae1 100644 --- a/scripts/convert_magi_to_diffusers.py +++ b/scripts/convert_magi_to_diffusers.py @@ -19,19 +19,17 @@ import argparse import json import os -from pathlib import Path import torch -from huggingface_hub import hf_hub_download from safetensors import safe_open 
from safetensors.torch import load_file from transformers import AutoTokenizer, UMT5EncoderModel from diffusers import ( AutoencoderKLMagi, + FlowMatchEulerDiscreteScheduler, MagiPipeline, MagiTransformer3DModel, - FlowMatchEulerDiscreteScheduler, ) @@ -583,13 +581,13 @@ def main(): else: dtype = torch.float32 - print(f"Starting MAGI-1 conversion to diffusers format...") + print("Starting MAGI-1 conversion to diffusers format...") print(f"Output will be saved to: {args.output_path}") print(f"Using dtype: {args.dtype}") try: # Convert the VAE - print(f"Converting VAE checkpoint...") + print("Converting VAE checkpoint...") if args.vae_checkpoint_path: vae_path = args.vae_checkpoint_path else: @@ -599,10 +597,10 @@ def main(): print(f"VAE checkpoint path: {vae_path}") vae = convert_magi_vae_checkpoint(vae_path, dtype=dtype) - print(f"VAE conversion complete.") + print("VAE conversion complete.") # Convert the transformer - print(f"Converting transformer checkpoint...") + print("Converting transformer checkpoint...") if args.transformer_checkpoint_path: transformer_path = args.transformer_checkpoint_path else: @@ -610,7 +608,7 @@ def main(): print(f"Transformer checkpoint path: {transformer_path}") transformer = convert_magi_transformer_checkpoint(transformer_path, dtype=dtype) - print(f"Transformer conversion complete.") + print("Transformer conversion complete.") # Load the text encoder and tokenizer print(f"Loading text encoder and tokenizer from {args.t5_model_name}...") @@ -619,15 +617,15 @@ def main(): if dtype is not None: text_encoder = text_encoder.to(dtype=dtype) - print(f"Text encoder and tokenizer loaded successfully.") + print("Text encoder and tokenizer loaded successfully.") # Create the scheduler - print(f"Creating scheduler...") + print("Creating scheduler...") scheduler = FlowMatchEulerDiscreteScheduler() - print(f"Scheduler created successfully.") + print("Scheduler created successfully.") # Create the pipeline - print(f"Creating MAGI pipeline...") + print("Creating MAGI pipeline...") pipeline = MagiPipeline( vae=vae, text_encoder=text_encoder, @@ -635,12 +633,12 @@ def main(): transformer=transformer, scheduler=scheduler, ) - print(f"MAGI pipeline created successfully.") + print("MAGI pipeline created successfully.") # Save the pipeline print(f"Saving pipeline to {args.output_path}...") pipeline.save_pretrained(args.output_path) - print(f"Pipeline saved successfully.") + print("Pipeline saved successfully.") print(f"Conversion complete! MAGI-1 pipeline saved to {args.output_path}") @@ -654,4 +652,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_magi.py b/src/diffusers/models/autoencoders/autoencoder_kl_magi.py index 5b15298cd238..11ce5edd0db6 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_magi.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_magi.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn as nn -import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput, logging @@ -716,4 +715,4 @@ def forward( if not return_dict: return (dec,) - return BaseOutput(sample=dec.sample) \ No newline at end of file + return BaseOutput(sample=dec.sample) diff --git a/src/diffusers/models/transformers/transformer_magi.py b/src/diffusers/models/transformers/transformer_magi.py index 4e527cecf228..070f2921d522 100644 --- a/src/diffusers/models/transformers/transformer_magi.py +++ b/src/diffusers/models/transformers/transformer_magi.py @@ -17,7 +17,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput, logging @@ -665,4 +664,4 @@ def forward( if not return_dict: return (hidden_states,) - return MagiTransformerOutput(sample=hidden_states) \ No newline at end of file + return MagiTransformerOutput(sample=hidden_states) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index c7d9a2f29a1f..ea264cb6b914 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -652,12 +652,12 @@ from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline + from .magi import MagiImageToVideoPipeline, MagiPipeline, MagiVideoToVideoPipeline from .marigold import ( MarigoldDepthPipeline, MarigoldIntrinsicsPipeline, MarigoldNormalsPipeline, ) - from .magi import MagiPipeline, MagiImageToVideoPipeline, MagiVideoToVideoPipeline from .mochi import MochiPipeline from .musicldm import MusicLDMPipeline from .omnigen import OmniGenPipeline diff --git a/src/diffusers/pipelines/magi/__init__.py b/src/diffusers/pipelines/magi/__init__.py index 4fb6cc376987..2aa0f44f6f08 100644 --- a/src/diffusers/pipelines/magi/__init__.py +++ b/src/diffusers/pipelines/magi/__init__.py @@ -48,4 +48,4 @@ globals()["__file__"], _import_structure, module_spec=__spec__, - ) \ No newline at end of file + ) diff --git a/src/diffusers/pipelines/magi/pipeline_magi.py b/src/diffusers/pipelines/magi/pipeline_magi.py index 92805fec11e0..26957c9ccdff 100644 --- a/src/diffusers/pipelines/magi/pipeline_magi.py +++ b/src/diffusers/pipelines/magi/pipeline_magi.py @@ -12,22 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect -import math from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np import torch -from transformers import UMT5EncoderModel, AutoTokenizer +from transformers import AutoTokenizer, UMT5EncoderModel -from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...models import AutoencoderKLMagi, MagiTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( is_torch_xla_available, logging, - replace_example_docstring, randn_tensor, + replace_example_docstring, ) from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline @@ -35,7 +31,6 @@ if is_torch_xla_available(): - import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: @@ -638,4 +633,4 @@ def __call__( if not return_dict: return (video,) - return MagiPipelineOutput(frames=video) \ No newline at end of file + return MagiPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/magi/pipeline_magi_i2v.py b/src/diffusers/pipelines/magi/pipeline_magi_i2v.py index 662a3940987e..40a66e21a5c7 100644 --- a/src/diffusers/pipelines/magi/pipeline_magi_i2v.py +++ b/src/diffusers/pipelines/magi/pipeline_magi_i2v.py @@ -12,22 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np import PIL import torch -from transformers import UMT5EncoderModel, AutoTokenizer +from transformers import AutoTokenizer, UMT5EncoderModel from ...models import AutoencoderKLMagi, MagiTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( logging, - replace_example_docstring, randn_tensor, + replace_example_docstring, ) -from ...image_processor import PipelineImageInput from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import MagiPipelineOutput @@ -543,4 +540,4 @@ def __call__( if not return_dict: return (video,) - return MagiPipelineOutput(frames=video) \ No newline at end of file + return MagiPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/magi/pipeline_magi_v2v.py b/src/diffusers/pipelines/magi/pipeline_magi_v2v.py index d0ee5d67dc94..8c7f7a821c0d 100644 --- a/src/diffusers/pipelines/magi/pipeline_magi_v2v.py +++ b/src/diffusers/pipelines/magi/pipeline_magi_v2v.py @@ -12,19 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np import torch -from transformers import UMT5EncoderModel, AutoTokenizer +from transformers import AutoTokenizer, UMT5EncoderModel from ...models import AutoencoderKLMagi, MagiTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( logging, - replace_example_docstring, randn_tensor, + replace_example_docstring, ) from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline @@ -549,4 +547,4 @@ def __call__( if not return_dict: return (video,) - return MagiPipelineOutput(frames=video) \ No newline at end of file + return MagiPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/magi/pipeline_output.py b/src/diffusers/pipelines/magi/pipeline_output.py index 60b2b59e719a..1a71e788246a 100644 --- a/src/diffusers/pipelines/magi/pipeline_output.py +++ b/src/diffusers/pipelines/magi/pipeline_output.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import List, Optional, Union +from typing import List, Union import numpy as np import torch @@ -31,4 +31,4 @@ class MagiPipelineOutput(BaseOutput): List of denoised frames from the diffusion process, as a NumPy array of shape `(batch_size, num_frames, height, width, num_channels)` or a PyTorch tensor of shape `(batch_size, num_channels, num_frames, height, width)`. """ - frames: Union[torch.Tensor, np.ndarray, List[List[np.ndarray]]] \ No newline at end of file + frames: Union[torch.Tensor, np.ndarray, List[List[np.ndarray]]] diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_magi.py b/tests/models/autoencoders/test_models_autoencoder_kl_magi.py index 0593209c5427..4b38579508e6 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl_magi.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl_magi.py @@ -152,4 +152,4 @@ def test_layerwise_casting_inference(self): @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'") def test_layerwise_casting_training(self): - pass \ No newline at end of file + pass diff --git a/tests/models/transformers/test_models_transformer_magi.py b/tests/models/transformers/test_models_transformer_magi.py index bbe677b9b1d1..cf2dd091cb13 100644 --- a/tests/models/transformers/test_models_transformer_magi.py +++ b/tests/models/transformers/test_models_transformer_magi.py @@ -88,4 +88,4 @@ class MagiTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = MagiTransformer3DModel def prepare_init_args_and_inputs_for_common(self): - return MagiTransformer3DTests().prepare_init_args_and_inputs_for_common() \ No newline at end of file + return MagiTransformer3DTests().prepare_init_args_and_inputs_for_common() diff --git a/tests/pipelines/magi/test_magi.py b/tests/pipelines/magi/test_magi.py index af09350806fe..ba3145e41b1d 100644 --- a/tests/pipelines/magi/test_magi.py +++ b/tests/pipelines/magi/test_magi.py @@ -155,4 +155,4 @@ def tearDown(self): @unittest.skip("TODO: test needs to be implemented") def test_Magi(self): - pass \ No newline at end of file + pass diff --git a/tests/pipelines/magi/test_magi_image_to_video.py b/tests/pipelines/magi/test_magi_image_to_video.py index 70780c9c5cb4..7b5d5b721da9 100644 --- a/tests/pipelines/magi/test_magi_image_to_video.py +++ b/tests/pipelines/magi/test_magi_image_to_video.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations 
under the License. -import gc import unittest import numpy as np @@ -20,19 +19,17 @@ import torch from transformers import AutoTokenizer, CLIPVisionModel, T5EncoderModel -from diffusers import AutoencoderKLMagi, FlowMatchEulerDiscreteScheduler, MagiImageToVideoPipeline, MagiTransformer3DModel +from diffusers import ( + AutoencoderKLMagi, + FlowMatchEulerDiscreteScheduler, + MagiImageToVideoPipeline, + MagiTransformer3DModel, +) from diffusers.utils.testing_utils import ( - backend_empty_cache, enable_full_determinism, - load_image, - load_numpy, - nightly, - require_torch_accelerator, - torch_device, ) from ..pipeline_params import ( - IMAGE_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS, @@ -212,4 +209,4 @@ def get_dummy_inputs(self, device, seed=0): "image": PIL.Image.new("RGB", (16, 16)), "last_image": PIL.Image.new("RGB", (16, 16)), } - return inputs \ No newline at end of file + return inputs diff --git a/tests/pipelines/magi/test_magi_video_to_video.py b/tests/pipelines/magi/test_magi_video_to_video.py index ae11830997b2..db3c6a0299a6 100644 --- a/tests/pipelines/magi/test_magi_video_to_video.py +++ b/tests/pipelines/magi/test_magi_video_to_video.py @@ -12,20 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc import unittest import numpy as np import torch from transformers import AutoTokenizer, T5EncoderModel -from diffusers import AutoencoderKLMagi, UniPCMultistepScheduler, MagiTransformer3DModel, MagiVideoToVideoPipeline +from diffusers import AutoencoderKLMagi, MagiTransformer3DModel, MagiVideoToVideoPipeline, UniPCMultistepScheduler from diffusers.utils.testing_utils import ( - backend_empty_cache, enable_full_determinism, - require_torch_accelerator, - slow, - torch_device, ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS @@ -145,4 +140,4 @@ def test_model_cpu_offload_forward_pass(self): @unittest.skip("MagiVideoToVideoPipeline has to run in mixed precision. 
Save/Load the entire pipeline in FP16 will result in errors") def test_save_load_float16(self): - pass \ No newline at end of file + pass diff --git a/tests/single_file/test_model_magi_autoencoder_single_file.py b/tests/single_file/test_model_magi_autoencoder_single_file.py index 732ab1c223d7..07e5701dbb79 100644 --- a/tests/single_file/test_model_magi_autoencoder_single_file.py +++ b/tests/single_file/test_model_magi_autoencoder_single_file.py @@ -16,7 +16,6 @@ import unittest import torch -from huggingface_hub import hf_hub_download from diffusers import AutoencoderKLMagi from diffusers.utils.testing_utils import ( @@ -61,4 +60,4 @@ def test_single_file_components_from_hub(self): image = torch.randn((batch_size, num_channels, num_frames) + sizes).to(torch_device) with torch.no_grad(): - model(image, return_dict=False) \ No newline at end of file + model(image, return_dict=False) diff --git a/tests/single_file/test_model_magi_transformer3d_single_file.py b/tests/single_file/test_model_magi_transformer3d_single_file.py index 4c7a026b899b..151f3e3997b4 100644 --- a/tests/single_file/test_model_magi_transformer3d_single_file.py +++ b/tests/single_file/test_model_magi_transformer3d_single_file.py @@ -16,7 +16,6 @@ import unittest import torch -from huggingface_hub import hf_hub_download from diffusers import MagiTransformer3DModel from diffusers.utils.testing_utils import ( @@ -81,4 +80,4 @@ def test_single_file_components_from_hub(self): timestep=timestep, encoder_hidden_states=encoder_hidden_states, return_dict=False, - ) \ No newline at end of file + ) From f4b5748c419d48f3d4c4c9d2a21f4900d11a500e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 13 Jun 2025 22:20:04 +0300 Subject: [PATCH 3/8] upp --- scripts/convert_magi_to_diffusers.py | 1 - src/diffusers/pipelines/magi/pipeline_magi.py | 107 +++++++----------- 2 files changed, 38 insertions(+), 70 deletions(-) diff --git a/scripts/convert_magi_to_diffusers.py b/scripts/convert_magi_to_diffusers.py index 3a616a2a2ae1..3b9dcda4c1a8 100644 --- a/scripts/convert_magi_to_diffusers.py +++ b/scripts/convert_magi_to_diffusers.py @@ -78,7 +78,6 @@ def convert_magi_vae_checkpoint(checkpoint_path, vae_config_file=None, dtype=Non "block_out_channels": [1024], # Hidden dimension in transformer blocks "layers_per_block": 24, # 24 transformer blocks in encoder/decoder "act_fn": "gelu", - "latent_channels": 16, "norm_num_groups": 32, "scaling_factor": 0.18215, "sample_size": 256, # Typical image size diff --git a/src/diffusers/pipelines/magi/pipeline_magi.py b/src/diffusers/pipelines/magi/pipeline_magi.py index 26957c9ccdff..44b51738bae1 100644 --- a/src/diffusers/pipelines/magi/pipeline_magi.py +++ b/src/diffusers/pipelines/magi/pipeline_magi.py @@ -151,7 +151,7 @@ def _encode_prompt( # Process special tokens if present (following MAGI-1's approach) # In diffusers style, we don't need to explicitly handle special tokens as they're part of the tokenizer # But we can ensure proper mask handling similar to MAGI-1 - seq_len = prompt_embeds.shape[1] + # Shape of prompt_embeds: [batch_size, seq_len, hidden_size] # Duplicate text embeddings for each generation per prompt prompt_embeds = prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) @@ -375,58 +375,7 @@ def t_resolution_transform(x, shift_value=shift): # Default: use scheduler's default timesteps return timesteps - def denoise_latents( - self, - latents: torch.Tensor, - prompt_embeds: torch.Tensor, - timesteps: List[int], - callback: Optional[Callable[[int, int, 
torch.Tensor], None]] = None, - callback_steps: int = 1, - guidance_scale: float = 7.5, - ) -> torch.Tensor: - """ - Denoise the latents using the transformer model. - Args: - latents (`torch.Tensor`): The initial noisy latents. - prompt_embeds (`torch.Tensor`): The text embeddings for conditioning. - timesteps (`List[int]`): The timesteps for the diffusion process. - callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps. - callback_steps (`int`, *optional*, defaults to 1): The frequency at which the callback is called. - guidance_scale (`float`, *optional*, defaults to 7.5): The scale for classifier-free guidance. - - Returns: - `torch.Tensor`: The denoised latents. - """ - do_classifier_free_guidance = guidance_scale > 1.0 - batch_size = latents.shape[0] // (2 if do_classifier_free_guidance else 1) - - for i, t in enumerate(timesteps): - # Expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # Predict the noise residual - noise_pred = self.transformer( - latent_model_input, - timesteps=torch.tensor([t], device=latents.device), - encoder_hidden_states=prompt_embeds, - ).sample - - # Perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # Compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents).prev_sample - - # Call the callback, if provided - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) - - return latents @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -579,8 +528,8 @@ def __call__( # For subsequent chunks, implement proper autoregressive conditioning # In MAGI-1, each chunk conditions the next in an autoregressive manner # We use the previous chunk's output as conditioning for the current chunk - prev_chunk_end = chunk_indices[chunk_idx - 1][1] - overlap_start = max(0, start_idx - self.vae_scale_factor_temporal) # Add overlap for conditioning + # Calculate overlap for conditioning + overlap_frames = min(self.vae_scale_factor_temporal, start_idx) # Initialize with noise chunk_latents = randn_tensor( @@ -593,21 +542,41 @@ def __call__( chunk_latents = chunk_latents * self.scheduler.init_noise_sigma # Use previous chunk output as conditioning by copying overlapping frames - if start_idx > 0 and chunk_idx > 0: - overlap_frames = min(self.vae_scale_factor_temporal, start_idx) - if overlap_frames > 0: - # Copy overlapping frames from previous chunk's output - chunk_latents[:, :, :overlap_frames, :, :] = all_latents[chunk_idx - 1][:, :, -overlap_frames:, :, :] - - # Denoise this chunk - chunk_latents = self.denoise_latents( - chunk_latents, - prompt_embeds, - timesteps, - callback=callback if chunk_idx == 0 else None, # Only use callback for first chunk - callback_steps=callback_steps, - guidance_scale=guidance_scale, - ) + if start_idx > 0 and chunk_idx > 0 and overlap_frames > 0: + # Copy overlapping frames from previous chunk's output + chunk_latents[:, :, :overlap_frames, :, :] = all_latents[chunk_idx - 1][:, :, -overlap_frames:, :, :] + + # Denoising loop for this chunk + with self.progress_bar(total=len(timesteps)) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free 
guidance + latent_model_input = torch.cat([chunk_latents] * 2) if do_classifier_free_guidance else chunk_latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + timesteps=torch.tensor([t], device=chunk_latents.device), + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + chunk_latents = self.scheduler.step(noise_pred, t, chunk_latents).prev_sample + + # call the callback, if provided + if callback is not None and chunk_idx == 0 and i % callback_steps == 0: + callback(i, t, chunk_latents) + + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() all_latents.append(chunk_latents) From 9b4531746fc7c09d295d8d78153123f943810956 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 13 Jun 2025 22:21:57 +0300 Subject: [PATCH 4/8] style --- scripts/convert_magi_to_diffusers.py | 137 +++++++++++++----- .../autoencoders/autoencoder_kl_magi.py | 92 +++++++----- .../models/transformers/transformer_magi.py | 48 +++--- src/diffusers/pipelines/magi/pipeline_magi.py | 88 +++++++---- .../pipelines/magi/pipeline_magi_i2v.py | 70 ++++++--- .../pipelines/magi/pipeline_magi_v2v.py | 73 +++++++--- .../pipelines/magi/pipeline_output.py | 4 +- .../magi/test_magi_video_to_video.py | 8 +- ...test_model_magi_autoencoder_single_file.py | 4 +- 9 files changed, 349 insertions(+), 175 deletions(-) diff --git a/scripts/convert_magi_to_diffusers.py b/scripts/convert_magi_to_diffusers.py index 3b9dcda4c1a8..55d6422a63e0 100644 --- a/scripts/convert_magi_to_diffusers.py +++ b/scripts/convert_magi_to_diffusers.py @@ -150,19 +150,33 @@ def convert_vae_state_dict(checkpoint): continue # Attention components - state_dict[f"encoder.transformer_blocks.{i}.attn1.to_qkv.weight"] = checkpoint[f"encoder.blocks.{i}.attn.qkv.weight"] - state_dict[f"encoder.transformer_blocks.{i}.attn1.to_qkv.bias"] = checkpoint[f"encoder.blocks.{i}.attn.qkv.bias"] - state_dict[f"encoder.transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint[f"encoder.blocks.{i}.attn.proj.weight"] - state_dict[f"encoder.transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint[f"encoder.blocks.{i}.attn.proj.bias"] + state_dict[f"encoder.transformer_blocks.{i}.attn1.to_qkv.weight"] = checkpoint[ + f"encoder.blocks.{i}.attn.qkv.weight" + ] + state_dict[f"encoder.transformer_blocks.{i}.attn1.to_qkv.bias"] = checkpoint[ + f"encoder.blocks.{i}.attn.qkv.bias" + ] + state_dict[f"encoder.transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint[ + f"encoder.blocks.{i}.attn.proj.weight" + ] + state_dict[f"encoder.transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint[ + f"encoder.blocks.{i}.attn.proj.bias" + ] # Normalization state_dict[f"encoder.transformer_blocks.{i}.norm2.weight"] = checkpoint[f"encoder.blocks.{i}.norm2.weight"] state_dict[f"encoder.transformer_blocks.{i}.norm2.bias"] = checkpoint[f"encoder.blocks.{i}.norm2.bias"] # MLP components - state_dict[f"encoder.transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint[f"encoder.blocks.{i}.mlp.fc1.weight"] - state_dict[f"encoder.transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint[f"encoder.blocks.{i}.mlp.fc1.bias"] - 
state_dict[f"encoder.transformer_blocks.{i}.ff.net.2.weight"] = checkpoint[f"encoder.blocks.{i}.mlp.fc2.weight"] + state_dict[f"encoder.transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint[ + f"encoder.blocks.{i}.mlp.fc1.weight" + ] + state_dict[f"encoder.transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint[ + f"encoder.blocks.{i}.mlp.fc1.bias" + ] + state_dict[f"encoder.transformer_blocks.{i}.ff.net.2.weight"] = checkpoint[ + f"encoder.blocks.{i}.mlp.fc2.weight" + ] state_dict[f"encoder.transformer_blocks.{i}.ff.net.2.bias"] = checkpoint[f"encoder.blocks.{i}.mlp.fc2.bias"] # Encoder norm @@ -196,19 +210,33 @@ def convert_vae_state_dict(checkpoint): continue # Attention components - state_dict[f"decoder.transformer_blocks.{i}.attn1.to_qkv.weight"] = checkpoint[f"decoder.blocks.{i}.attn.qkv.weight"] - state_dict[f"decoder.transformer_blocks.{i}.attn1.to_qkv.bias"] = checkpoint[f"decoder.blocks.{i}.attn.qkv.bias"] - state_dict[f"decoder.transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint[f"decoder.blocks.{i}.attn.proj.weight"] - state_dict[f"decoder.transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint[f"decoder.blocks.{i}.attn.proj.bias"] + state_dict[f"decoder.transformer_blocks.{i}.attn1.to_qkv.weight"] = checkpoint[ + f"decoder.blocks.{i}.attn.qkv.weight" + ] + state_dict[f"decoder.transformer_blocks.{i}.attn1.to_qkv.bias"] = checkpoint[ + f"decoder.blocks.{i}.attn.qkv.bias" + ] + state_dict[f"decoder.transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint[ + f"decoder.blocks.{i}.attn.proj.weight" + ] + state_dict[f"decoder.transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint[ + f"decoder.blocks.{i}.attn.proj.bias" + ] # Normalization state_dict[f"decoder.transformer_blocks.{i}.norm2.weight"] = checkpoint[f"decoder.blocks.{i}.norm2.weight"] state_dict[f"decoder.transformer_blocks.{i}.norm2.bias"] = checkpoint[f"decoder.blocks.{i}.norm2.bias"] # MLP components - state_dict[f"decoder.transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint[f"decoder.blocks.{i}.mlp.fc1.weight"] - state_dict[f"decoder.transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint[f"decoder.blocks.{i}.mlp.fc1.bias"] - state_dict[f"decoder.transformer_blocks.{i}.ff.net.2.weight"] = checkpoint[f"decoder.blocks.{i}.mlp.fc2.weight"] + state_dict[f"decoder.transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint[ + f"decoder.blocks.{i}.mlp.fc1.weight" + ] + state_dict[f"decoder.transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint[ + f"decoder.blocks.{i}.mlp.fc1.bias" + ] + state_dict[f"decoder.transformer_blocks.{i}.ff.net.2.weight"] = checkpoint[ + f"decoder.blocks.{i}.mlp.fc2.weight" + ] state_dict[f"decoder.transformer_blocks.{i}.ff.net.2.bias"] = checkpoint[f"decoder.blocks.{i}.mlp.fc2.bias"] # Decoder norm @@ -390,60 +418,98 @@ def convert_transformer_state_dict(checkpoint): continue # FF norm (AdaLN projection) - state_dict[f"transformer_blocks.{i}.ff_norm.weight"] = checkpoint[f"{layer_prefix}.ada_modulate_layer.proj.0.weight"] - state_dict[f"transformer_blocks.{i}.ff_norm.bias"] = checkpoint[f"{layer_prefix}.ada_modulate_layer.proj.0.bias"] + state_dict[f"transformer_blocks.{i}.ff_norm.weight"] = checkpoint[ + f"{layer_prefix}.ada_modulate_layer.proj.0.weight" + ] + state_dict[f"transformer_blocks.{i}.ff_norm.bias"] = checkpoint[ + f"{layer_prefix}.ada_modulate_layer.proj.0.bias" + ] # Self-attention components # Query normalization if f"{layer_prefix}.self_attention.q_layernorm.weight" in checkpoint: - state_dict[f"transformer_blocks.{i}.attn1.norm_q.weight"] = 
checkpoint[f"{layer_prefix}.self_attention.q_layernorm.weight"] - state_dict[f"transformer_blocks.{i}.attn1.norm_q.bias"] = checkpoint[f"{layer_prefix}.self_attention.q_layernorm.bias"] + state_dict[f"transformer_blocks.{i}.attn1.norm_q.weight"] = checkpoint[ + f"{layer_prefix}.self_attention.q_layernorm.weight" + ] + state_dict[f"transformer_blocks.{i}.attn1.norm_q.bias"] = checkpoint[ + f"{layer_prefix}.self_attention.q_layernorm.bias" + ] # Key normalization if f"{layer_prefix}.self_attention.k_layernorm.weight" in checkpoint: - state_dict[f"transformer_blocks.{i}.attn1.norm_k.weight"] = checkpoint[f"{layer_prefix}.self_attention.k_layernorm.weight"] - state_dict[f"transformer_blocks.{i}.attn1.norm_k.bias"] = checkpoint[f"{layer_prefix}.self_attention.k_layernorm.bias"] + state_dict[f"transformer_blocks.{i}.attn1.norm_k.weight"] = checkpoint[ + f"{layer_prefix}.self_attention.k_layernorm.weight" + ] + state_dict[f"transformer_blocks.{i}.attn1.norm_k.bias"] = checkpoint[ + f"{layer_prefix}.self_attention.k_layernorm.bias" + ] # Cross-attention key normalization if f"{layer_prefix}.self_attention.k_layernorm_xattn.weight" in checkpoint: - state_dict[f"transformer_blocks.{i}.attn1.norm_k_xattn.weight"] = checkpoint[f"{layer_prefix}.self_attention.k_layernorm_xattn.weight"] - state_dict[f"transformer_blocks.{i}.attn1.norm_k_xattn.bias"] = checkpoint[f"{layer_prefix}.self_attention.k_layernorm_xattn.bias"] + state_dict[f"transformer_blocks.{i}.attn1.norm_k_xattn.weight"] = checkpoint[ + f"{layer_prefix}.self_attention.k_layernorm_xattn.weight" + ] + state_dict[f"transformer_blocks.{i}.attn1.norm_k_xattn.bias"] = checkpoint[ + f"{layer_prefix}.self_attention.k_layernorm_xattn.bias" + ] # Cross-attention query normalization if f"{layer_prefix}.self_attention.q_layernorm_xattn.weight" in checkpoint: - state_dict[f"transformer_blocks.{i}.attn1.norm_q_xattn.weight"] = checkpoint[f"{layer_prefix}.self_attention.q_layernorm_xattn.weight"] - state_dict[f"transformer_blocks.{i}.attn1.norm_q_xattn.bias"] = checkpoint[f"{layer_prefix}.self_attention.q_layernorm_xattn.bias"] + state_dict[f"transformer_blocks.{i}.attn1.norm_q_xattn.weight"] = checkpoint[ + f"{layer_prefix}.self_attention.q_layernorm_xattn.weight" + ] + state_dict[f"transformer_blocks.{i}.attn1.norm_q_xattn.bias"] = checkpoint[ + f"{layer_prefix}.self_attention.q_layernorm_xattn.bias" + ] # QKV linear projections if f"{layer_prefix}.self_attention.linear_qkv.q.weight" in checkpoint: - state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = checkpoint[f"{layer_prefix}.self_attention.linear_qkv.q.weight"] + state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = checkpoint[ + f"{layer_prefix}.self_attention.linear_qkv.q.weight" + ] if f"{layer_prefix}.self_attention.linear_qkv.k.weight" in checkpoint: - state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = checkpoint[f"{layer_prefix}.self_attention.linear_qkv.k.weight"] + state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = checkpoint[ + f"{layer_prefix}.self_attention.linear_qkv.k.weight" + ] if f"{layer_prefix}.self_attention.linear_qkv.v.weight" in checkpoint: - state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = checkpoint[f"{layer_prefix}.self_attention.linear_qkv.v.weight"] + state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = checkpoint[ + f"{layer_prefix}.self_attention.linear_qkv.v.weight" + ] if f"{layer_prefix}.self_attention.linear_qkv.qx.weight" in checkpoint: - state_dict[f"transformer_blocks.{i}.attn1.to_q_xattn.weight"] = 
checkpoint[f"{layer_prefix}.self_attention.linear_qkv.qx.weight"] + state_dict[f"transformer_blocks.{i}.attn1.to_q_xattn.weight"] = checkpoint[ + f"{layer_prefix}.self_attention.linear_qkv.qx.weight" + ] # QKV layer norm if f"{layer_prefix}.self_attention.linear_qkv.layer_norm.weight" in checkpoint: - state_dict[f"transformer_blocks.{i}.attn1.qkv_norm.weight"] = checkpoint[f"{layer_prefix}.self_attention.linear_qkv.layer_norm.weight"] - state_dict[f"transformer_blocks.{i}.attn1.qkv_norm.bias"] = checkpoint[f"{layer_prefix}.self_attention.linear_qkv.layer_norm.bias"] + state_dict[f"transformer_blocks.{i}.attn1.qkv_norm.weight"] = checkpoint[ + f"{layer_prefix}.self_attention.linear_qkv.layer_norm.weight" + ] + state_dict[f"transformer_blocks.{i}.attn1.qkv_norm.bias"] = checkpoint[ + f"{layer_prefix}.self_attention.linear_qkv.layer_norm.bias" + ] # KV cross-attention if f"{layer_prefix}.self_attention.linear_kv_xattn.weight" in checkpoint: - state_dict[f"transformer_blocks.{i}.attn1.to_kv_xattn.weight"] = checkpoint[f"{layer_prefix}.self_attention.linear_kv_xattn.weight"] + state_dict[f"transformer_blocks.{i}.attn1.to_kv_xattn.weight"] = checkpoint[ + f"{layer_prefix}.self_attention.linear_kv_xattn.weight" + ] # Output projection if f"{layer_prefix}.self_attention.linear_proj.weight" in checkpoint: - state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint[f"{layer_prefix}.self_attention.linear_proj.weight"] + state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint[ + f"{layer_prefix}.self_attention.linear_proj.weight" + ] # Self-attention post normalization if f"{layer_prefix}.self_attn_post_norm.weight" in checkpoint: - state_dict[f"transformer_blocks.{i}.norm1.weight"] = checkpoint[f"{layer_prefix}.self_attn_post_norm.weight"] + state_dict[f"transformer_blocks.{i}.norm1.weight"] = checkpoint[ + f"{layer_prefix}.self_attn_post_norm.weight" + ] state_dict[f"transformer_blocks.{i}.norm1.bias"] = checkpoint[f"{layer_prefix}.self_attn_post_norm.bias"] # MLP components @@ -454,7 +520,9 @@ def convert_transformer_state_dict(checkpoint): # MLP FC1 (projection) if f"{layer_prefix}.mlp.linear_fc1.weight" in checkpoint: - state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint[f"{layer_prefix}.mlp.linear_fc1.weight"] + state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint[ + f"{layer_prefix}.mlp.linear_fc1.weight" + ] # MLP FC2 (projection) if f"{layer_prefix}.mlp.linear_fc2.weight" in checkpoint: @@ -644,6 +712,7 @@ def main(): except Exception as e: print(f"Error during conversion: {str(e)}") import traceback + traceback.print_exc() return 1 diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_magi.py b/src/diffusers/models/autoencoders/autoencoder_kl_magi.py index 11ce5edd0db6..a4c9a5eb81fd 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_magi.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_magi.py @@ -89,6 +89,7 @@ class ManualLayerNorm(nn.Module): """ Manual implementation of LayerNorm for better compatibility. """ + def __init__(self, normalized_shape, eps=1e-5): super().__init__() self.normalized_shape = normalized_shape @@ -105,6 +106,7 @@ class Mlp(nn.Module): """ MLP module used in the transformer architecture. """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): super().__init__() out_features = out_features or in_features @@ -127,11 +129,12 @@ class Attention(nn.Module): """ Multi-head attention module. 
""" + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) @@ -157,6 +160,7 @@ class Block(nn.Module): """ Transformer block with attention and MLP. """ + def __init__( self, dim, @@ -195,6 +199,7 @@ class PatchEmbed(nn.Module): """ Image to Patch Embedding for 3D data. """ + def __init__(self, video_size=224, video_length=16, patch_size=16, patch_length=1, in_chans=3, embed_dim=768): super().__init__() self.video_size = video_size @@ -207,17 +212,18 @@ def __init__(self, video_size=224, video_length=16, patch_size=16, patch_length= self.num_patches = self.grid_length * self.grid_size * self.grid_size self.proj = nn.Conv3d( - in_chans, embed_dim, + in_chans, + embed_dim, kernel_size=(patch_length, patch_size, patch_size), - stride=(patch_length, patch_size, patch_size) + stride=(patch_length, patch_size, patch_size), ) def forward(self, x): B, C, T, H, W = x.shape - assert H == self.video_size and W == self.video_size, \ + assert H == self.video_size and W == self.video_size, ( f"Input image size ({H}*{W}) doesn't match model ({self.video_size}*{self.video_size})." - assert T == self.video_length, \ - f"Input video length ({T}) doesn't match model ({self.video_length})." + ) + assert T == self.video_length, f"Input video length ({T}) doesn't match model ({self.video_length})." x = self.proj(x).flatten(2).transpose(1, 2) return x @@ -227,6 +233,7 @@ class ViTEncoder(nn.Module): """ Vision Transformer Encoder for MAGI-1 VAE. """ + def __init__( self, video_size=256, @@ -280,20 +287,22 @@ def __init__( # Transformer blocks dpr = [x.item() for x in torch.linspace(0, 0.0, depth)] # stochastic depth decay rule - self.blocks = nn.ModuleList([ - Block( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=dpr[i], - norm_layer=norm_layer, - ) - for i in range(depth) - ]) + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) self.norm = norm_layer(embed_dim) @@ -359,6 +368,7 @@ class ViTDecoder(nn.Module): """ Vision Transformer Decoder for MAGI-1 VAE. 
""" + def __init__( self, video_size=256, @@ -406,20 +416,22 @@ def __init__( # Transformer blocks dpr = [x.item() for x in torch.linspace(0, 0.0, depth)] # stochastic depth decay rule - self.blocks = nn.ModuleList([ - Block( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=dpr[i], - norm_layer=norm_layer, - ) - for i in range(depth) - ]) + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) self.norm = norm_layer(embed_dim) @@ -428,7 +440,7 @@ def __init__( embed_dim, in_chans, kernel_size=(patch_length, patch_size, patch_size), - stride=(patch_length, patch_size, patch_size) + stride=(patch_length, patch_size, patch_size), ) # Initialize weights @@ -491,8 +503,8 @@ class AutoencoderKLMagi(ModelMixin, ConfigMixin): """ Variational Autoencoder (VAE) model with KL loss for MAGI-1. - This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic methods - implemented for all models (downloading, saving, loading, etc.) + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods implemented for all models (downloading, saving, loading, etc.) Parameters: in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. @@ -511,9 +523,9 @@ class AutoencoderKLMagi(ModelMixin, ConfigMixin): The component-wise standard deviation of the trained latent space computed using the first batch of the training set. This is used to scale the latent space to have unit variance when training the diffusion model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the - diffusion model. When decoding, the latents are scaled back to the original scale with the formula: - `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution - Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. temporal_downsample_factor (`Tuple[int]`, *optional*, defaults to (1, 2, 1, 1)): Tuple of temporal downsampling factors for each block. """ diff --git a/src/diffusers/models/transformers/transformer_magi.py b/src/diffusers/models/transformers/transformer_magi.py index 070f2921d522..1ea78053f349 100644 --- a/src/diffusers/models/transformers/transformer_magi.py +++ b/src/diffusers/models/transformers/transformer_magi.py @@ -44,8 +44,8 @@ class MagiAttention(nn.Module): """ A cross attention layer for MAGI-1. - This implements the specialized attention mechanism from the MAGI-1 model, - including query/key normalization and proper handling of rotary embeddings. + This implements the specialized attention mechanism from the MAGI-1 model, including query/key normalization and + proper handling of rotary embeddings. """ def __init__( @@ -324,9 +324,10 @@ class LearnableRotaryEmbedding(nn.Module): """ Learnable rotary position embeddings similar to the one used in MAGI-1. 
- This implementation is based on MAGI-1's LearnableRotaryEmbeddingCat class, - which creates rotary embeddings for 3D data (frames, height, width). + This implementation is based on MAGI-1's LearnableRotaryEmbeddingCat class, which creates rotary embeddings for 3D + data (frames, height, width). """ + def __init__( self, dim: int, @@ -394,17 +395,17 @@ def get_embed(self, shape: List[int]) -> torch.Tensor: # Compute embeddings for each dimension # Temporal dimension - t_emb = torch.outer(grid[:, 0], freqs[:self.dim//6]) + t_emb = torch.outer(grid[:, 0], freqs[: self.dim // 6]) t_sin = torch.sin(t_emb) t_cos = torch.cos(t_emb) # Height dimension - h_emb = torch.outer(grid[:, 1], freqs[:self.dim//6]) + h_emb = torch.outer(grid[:, 1], freqs[: self.dim // 6]) h_sin = torch.sin(h_emb) h_cos = torch.cos(h_emb) # Width dimension - w_emb = torch.outer(grid[:, 2], freqs[:self.dim//6]) + w_emb = torch.outer(grid[:, 2], freqs[: self.dim // 6]) w_sin = torch.sin(w_emb) w_cos = torch.cos(w_emb) @@ -430,8 +431,8 @@ class MagiTransformer3DModel(ModelMixin, ConfigMixin): """ Transformer model for MAGI-1. - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods implemented - for all models (downloading, saving, loading, etc.) + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods implemented for + all models (downloading, saving, loading, etc.) Parameters: sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): @@ -489,12 +490,7 @@ def __init__( self.time_embedding = TimestepEmbedding(time_embed_dim, time_embed_dim) # Input projection - self.input_proj = nn.Conv3d( - in_channels, - time_embed_dim, - kernel_size=patch_size, - stride=patch_size - ) + self.input_proj = nn.Conv3d(in_channels, time_embed_dim, kernel_size=patch_size, stride=patch_size) # Rotary position embeddings self.rotary_embedding = LearnableRotaryEmbedding( @@ -527,11 +523,7 @@ def __init__( # Output projection self.out_channels = out_channels - self.output_proj = nn.Conv3d( - time_embed_dim, - out_channels, - kernel_size=1 - ) + self.output_proj = nn.Conv3d(time_embed_dim, out_channels, kernel_size=1) self.gradient_checkpointing = False @@ -541,10 +533,10 @@ def set_attention_slice(self, slice_size): Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. - If `"max"`, maximum amount of memory is saved by running only one slice at a time. - If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. """ logger.warning( "Calling `set_attention_slice` is deprecated and will be removed in a future version. Use" @@ -586,8 +578,8 @@ def forward( Returns: `MagiTransformerOutput` or `tuple`: - If `return_dict` is `True`, a `MagiTransformerOutput` is returned, otherwise a tuple - `(sample,)` is returned where `sample` is the output tensor. + If `return_dict` is `True`, a `MagiTransformerOutput` is returned, otherwise a tuple `(sample,)` is + returned where `sample` is the output tensor. 
""" # 1. Input processing batch_size, channels, frames, height, width = hidden_states.shape @@ -612,7 +604,9 @@ def forward( # 4. Reshape for transformer blocks hidden_states = hidden_states.permute(0, 2, 3, 4, 1) # [B, C, F, H, W] -> [B, F, H, W, C] - hidden_states = hidden_states.reshape(batch_size, patched_frames * patched_height * patched_width, -1) # [B, F*H*W, C] + hidden_states = hidden_states.reshape( + batch_size, patched_frames * patched_height * patched_width, -1 + ) # [B, F*H*W, C] # 5. Add time embeddings if provided if time_embeds is not None: diff --git a/src/diffusers/pipelines/magi/pipeline_magi.py b/src/diffusers/pipelines/magi/pipeline_magi.py index 44b51738bae1..928a8a14fb58 100644 --- a/src/diffusers/pipelines/magi/pipeline_magi.py +++ b/src/diffusers/pipelines/magi/pipeline_magi.py @@ -31,6 +31,7 @@ if is_torch_xla_available(): + import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: @@ -49,7 +50,9 @@ >>> # Text-to-video generation >>> pipeline = MagiPipeline.from_pretrained("sand-ai/MAGI-1-4.5B", torch_dtype=torch.float16) >>> pipeline = pipeline.to("cuda") - >>> prompt = "A cat and a dog playing in a garden. The cat is chasing a butterfly while the dog is digging a hole." + >>> prompt = ( + ... "A cat and a dog playing in a garden. The cat is chasing a butterfly while the dog is digging a hole." + ... ) >>> output = pipeline( ... prompt=prompt, ... num_frames=24, @@ -124,7 +127,8 @@ def _encode_prompt( device: The device to place the encoded prompt on. num_videos_per_prompt (`int`): The number of videos that should be generated per prompt. do_classifier_free_guidance (`bool`): Whether to use classifier-free guidance or not. - negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the video generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. max_length (`int`, *optional*): The maximum length of the prompt to be encoded. Returns: @@ -211,7 +215,8 @@ def _prepare_latents( dtype (`torch.dtype`): The data type of the latents. device (`torch.device`): The device to place the latents on. generator (`torch.Generator`, *optional*): A generator to use for random number generation. - latents (`torch.Tensor`, *optional*): Pre-generated latent vectors. If not provided, latents will be generated randomly. + latents (`torch.Tensor`, *optional*): + Pre-generated latent vectors. If not provided, latents will be generated randomly. Returns: `torch.Tensor`: The prepared latent vectors. @@ -338,7 +343,7 @@ def _prepare_timesteps( if t_schedule_func == "sd3": # Apply quadratic transformation t = torch.linspace(0, 1, num_inference_steps + 1, device=device) - t = t ** 2 + t = t**2 # Apply SD3-style transformation def t_resolution_transform(x, shift_value=shift): @@ -356,7 +361,7 @@ def t_resolution_transform(x, shift_value=shift): elif t_schedule_func == "square": # Simple quadratic scheduling t = torch.linspace(0, 1, num_inference_steps + 1, device=device) - t = t ** 2 + t = t**2 return self.scheduler.timesteps elif t_schedule_func == "piecewise": @@ -375,8 +380,6 @@ def t_resolution_transform(x, shift_value=shift): # Default: use scheduler's default timesteps return timesteps - - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -416,34 +419,52 @@ def __call__( num_frames (`int`, *optional*, defaults to 24): The number of video frames to generate. num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. 
More denoising steps usually lead to a higher quality video at the expense of slower inference. + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, usually at the expense of lower video quality. + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, + usually at the expense of lower video quality. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"np"`): - The output format of the generate video. Choose between `np` for `numpy.array`, `pt` for `torch.Tensor` or `latent` to get the latent space output. + The output format of the generate video. Choose between `np` for `numpy.array`, `pt` for `torch.Tensor` + or `latent` to get the latent space output. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.magi.MagiPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [`diffusers.cross_attention`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [`diffusers.cross_attention`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). chunk_size (`int`, *optional*, defaults to 24): The chunk size to use for autoregressive generation. Measured in frames. t_schedule_func (`str`, *optional*, defaults to "sd3"): @@ -455,7 +476,8 @@ def __call__( Returns: [`~pipelines.magi.MagiPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.magi.MagiPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + If `return_dict` is `True`, [`~pipelines.magi.MagiPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated frames. """ # 0. Default height and width to unet height = height or self.transformer.config.sample_size @@ -491,10 +513,7 @@ def __call__( # 4. Prepare timesteps timesteps = self._prepare_timesteps( - num_inference_steps, - device, - t_schedule_func=t_schedule_func, - shift=t_schedule_shift + num_inference_steps, device, t_schedule_func=t_schedule_func, shift=t_schedule_shift ) # 5. Prepare latent variables @@ -514,7 +533,9 @@ def __call__( ) # 6. Process in chunks for autoregressive generation - chunk_indices = self._get_chunk_indices(num_frames // self.vae_scale_factor_temporal, chunk_size // self.vae_scale_factor_temporal) + chunk_indices = self._get_chunk_indices( + num_frames // self.vae_scale_factor_temporal, chunk_size // self.vae_scale_factor_temporal + ) all_latents = [] # 7. 
Denoise the latents @@ -533,8 +554,13 @@ def __call__( # Initialize with noise chunk_latents = randn_tensor( - (batch_size * num_videos_per_prompt, num_channels_latents, chunk_frames, - height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial), + ( + batch_size * num_videos_per_prompt, + num_channels_latents, + chunk_frames, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ), generator=generator, device=device, dtype=prompt_embeds.dtype, @@ -544,13 +570,17 @@ def __call__( # Use previous chunk output as conditioning by copying overlapping frames if start_idx > 0 and chunk_idx > 0 and overlap_frames > 0: # Copy overlapping frames from previous chunk's output - chunk_latents[:, :, :overlap_frames, :, :] = all_latents[chunk_idx - 1][:, :, -overlap_frames:, :, :] + chunk_latents[:, :, :overlap_frames, :, :] = all_latents[chunk_idx - 1][ + :, :, -overlap_frames:, :, : + ] # Denoising loop for this chunk with self.progress_bar(total=len(timesteps)) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([chunk_latents] * 2) if do_classifier_free_guidance else chunk_latents + latent_model_input = ( + torch.cat([chunk_latents] * 2) if do_classifier_free_guidance else chunk_latents + ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual diff --git a/src/diffusers/pipelines/magi/pipeline_magi_i2v.py b/src/diffusers/pipelines/magi/pipeline_magi_i2v.py index 40a66e21a5c7..79e70e147191 100644 --- a/src/diffusers/pipelines/magi/pipeline_magi_i2v.py +++ b/src/diffusers/pipelines/magi/pipeline_magi_i2v.py @@ -70,6 +70,7 @@ class MagiImageToVideoPipeline(DiffusionPipeline): scheduler (`FlowMatchEulerDiscreteScheduler`): A scheduler to be used in combination with `transformer` to denoise the encoded video latents. """ + model_cpu_offload_seq = "text_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -92,7 +93,9 @@ def __init__( ) self.vae_scale_factor_temporal = 2 ** (1 if hasattr(self.vae, "temporal_downsample") else 0) - self.vae_scale_factor_spatial = 2 ** (3 if hasattr(self.vae, "config") else 8) # Default to 8 for 3 downsamples + self.vae_scale_factor_spatial = 2 ** ( + 3 if hasattr(self.vae, "config") else 8 + ) # Default to 8 for 3 downsamples self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) def _encode_prompt( @@ -279,7 +282,8 @@ def _prepare_image_based_latents( dtype (`torch.dtype`): The data type of the latents. device (`torch.device`): The device to place the latents on. generator (`torch.Generator`, *optional*): A generator to use for random number generation. - latents (`torch.Tensor`, *optional*): Pre-generated latent vectors. If not provided, latents will be generated randomly. + latents (`torch.Tensor`, *optional*): + Pre-generated latent vectors. If not provided, latents will be generated randomly. image_latents (`torch.Tensor`, *optional*): Image latents for conditioning. Returns: @@ -366,34 +370,52 @@ def __call__( num_frames (`int`, *optional*, defaults to 24): The number of video frames to generate. num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference. + The number of denoising steps. 
More denoising steps usually lead to a higher quality video at the + expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, usually at the expense of lower video quality. + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, + usually at the expense of lower video quality. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + Pre-generated negative text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"np"`): - The output format of the generate video. Choose between `np` for `numpy.array`, `pt` for `torch.Tensor` or `latent` to get the latent space output. + The output format of the generate video. Choose between `np` for `numpy.array`, `pt` for `torch.Tensor` + or `latent` to get the latent space output. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.magi.MagiPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [`diffusers.cross_attention`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [`diffusers.cross_attention`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). chunk_size (`int`, *optional*, defaults to 24): The chunk size to use for autoregressive generation. Measured in frames. @@ -401,7 +423,8 @@ def __call__( Returns: [`~pipelines.magi.MagiPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.magi.MagiPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + If `return_dict` is `True`, [`~pipelines.magi.MagiPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated frames. """ # 0. Default height and width to unet height = height or self.transformer.config.sample_size @@ -464,7 +487,9 @@ def __call__( ) # 7. Process in chunks for autoregressive generation - chunk_indices = self._get_chunk_indices(num_frames // self.vae_scale_factor_temporal, chunk_size // self.vae_scale_factor_temporal) + chunk_indices = self._get_chunk_indices( + num_frames // self.vae_scale_factor_temporal, chunk_size // self.vae_scale_factor_temporal + ) all_latents = [] # 8. 
Process each chunk @@ -479,8 +504,13 @@ def __call__( # This is a simplified version - in a real implementation, we would need to handle # the autoregressive conditioning properly chunk_latents = randn_tensor( - (batch_size * num_videos_per_prompt, num_channels_latents, chunk_frames, - height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial), + ( + batch_size * num_videos_per_prompt, + num_channels_latents, + chunk_frames, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ), generator=generator, device=device, dtype=prompt_embeds.dtype, @@ -491,7 +521,9 @@ def __call__( with self.progress_bar(total=len(timesteps)) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([chunk_latents] * 2) if do_classifier_free_guidance else chunk_latents + latent_model_input = ( + torch.cat([chunk_latents] * 2) if do_classifier_free_guidance else chunk_latents + ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual diff --git a/src/diffusers/pipelines/magi/pipeline_magi_v2v.py b/src/diffusers/pipelines/magi/pipeline_magi_v2v.py index 8c7f7a821c0d..681aa1557b75 100644 --- a/src/diffusers/pipelines/magi/pipeline_magi_v2v.py +++ b/src/diffusers/pipelines/magi/pipeline_magi_v2v.py @@ -69,6 +69,7 @@ class MagiVideoToVideoPipeline(DiffusionPipeline): scheduler (`FlowMatchEulerDiscreteScheduler`): A scheduler to be used in combination with `transformer` to denoise the encoded video latents. """ + model_cpu_offload_seq = "text_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -91,7 +92,9 @@ def __init__( ) self.vae_scale_factor_temporal = 2 ** (1 if hasattr(self.vae, "temporal_downsample") else 0) - self.vae_scale_factor_spatial = 2 ** (3 if hasattr(self.vae, "config") else 8) # Default to 8 for 3 downsamples + self.vae_scale_factor_spatial = 2 ** ( + 3 if hasattr(self.vae, "config") else 8 + ) # Default to 8 for 3 downsamples self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) def _encode_prompt( @@ -281,7 +284,8 @@ def _prepare_video_based_latents( dtype (`torch.dtype`): The data type of the latents. device (`torch.device`): The device to place the latents on. generator (`torch.Generator`, *optional*): A generator to use for random number generation. - latents (`torch.Tensor`, *optional*): Pre-generated latent vectors. If not provided, latents will be generated randomly. + latents (`torch.Tensor`, *optional*): + Pre-generated latent vectors. If not provided, latents will be generated randomly. video_latents (`torch.Tensor`, *optional*): Video latents for conditioning. num_frames_to_condition (`int`, *optional*): Number of frames to use for conditioning. @@ -370,44 +374,64 @@ def __call__( num_frames (`int`, *optional*, defaults to 24): The number of video frames to generate. num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference. + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. 
of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, usually at the expense of lower video quality. + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, + usually at the expense of lower video quality. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"np"`): - The output format of the generate video. 
Choose between `np` for `numpy.array`, `pt` for `torch.Tensor` or `latent` to get the latent space output. + The output format of the generate video. Choose between `np` for `numpy.array`, `pt` for `torch.Tensor` + or `latent` to get the latent space output. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.magi.MagiPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [`diffusers.cross_attention`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [`diffusers.cross_attention`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). chunk_size (`int`, *optional*, defaults to 24): The chunk size to use for autoregressive generation. Measured in frames. num_frames_to_condition (`int`, *optional*): - Number of frames from the input video to use for conditioning. If not provided, all frames from the input video will be used. + Number of frames from the input video to use for conditioning. If not provided, all frames from the + input video will be used. Examples: Returns: [`~pipelines.magi.MagiPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.magi.MagiPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + If `return_dict` is `True`, [`~pipelines.magi.MagiPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated frames. """ # 0. Default height and width to unet height = height or self.transformer.config.sample_size @@ -471,7 +495,9 @@ def __call__( ) # 7. Process in chunks for autoregressive generation - chunk_indices = self._get_chunk_indices(num_frames // self.vae_scale_factor_temporal, chunk_size // self.vae_scale_factor_temporal) + chunk_indices = self._get_chunk_indices( + num_frames // self.vae_scale_factor_temporal, chunk_size // self.vae_scale_factor_temporal + ) all_latents = [] # 8. 
Process each chunk @@ -486,8 +512,13 @@ def __call__( # This is a simplified version - in a real implementation, we would need to handle # the autoregressive conditioning properly chunk_latents = randn_tensor( - (batch_size * num_videos_per_prompt, num_channels_latents, chunk_frames, - height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial), + ( + batch_size * num_videos_per_prompt, + num_channels_latents, + chunk_frames, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ), generator=generator, device=device, dtype=prompt_embeds.dtype, @@ -498,7 +529,9 @@ def __call__( with self.progress_bar(total=len(timesteps)) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([chunk_latents] * 2) if do_classifier_free_guidance else chunk_latents + latent_model_input = ( + torch.cat([chunk_latents] * 2) if do_classifier_free_guidance else chunk_latents + ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual diff --git a/src/diffusers/pipelines/magi/pipeline_output.py b/src/diffusers/pipelines/magi/pipeline_output.py index 1a71e788246a..f14a2050df33 100644 --- a/src/diffusers/pipelines/magi/pipeline_output.py +++ b/src/diffusers/pipelines/magi/pipeline_output.py @@ -28,7 +28,9 @@ class MagiPipelineOutput(BaseOutput): Args: frames (`torch.Tensor` or `np.ndarray`): - List of denoised frames from the diffusion process, as a NumPy array of shape `(batch_size, num_frames, height, width, num_channels)` or a PyTorch tensor of shape `(batch_size, num_channels, num_frames, height, width)`. + List of denoised frames from the diffusion process, as a NumPy array of shape `(batch_size, num_frames, + height, width, num_channels)` or a PyTorch tensor of shape `(batch_size, num_channels, num_frames, height, + width)`. """ frames: Union[torch.Tensor, np.ndarray, List[List[np.ndarray]]] diff --git a/tests/pipelines/magi/test_magi_video_to_video.py b/tests/pipelines/magi/test_magi_video_to_video.py index db3c6a0299a6..57d47f7456cd 100644 --- a/tests/pipelines/magi/test_magi_video_to_video.py +++ b/tests/pipelines/magi/test_magi_video_to_video.py @@ -134,10 +134,14 @@ def test_inference(self): def test_attention_slicing_forward_pass(self): pass - @unittest.skip("MagiVideoToVideoPipeline has to run in mixed precision. Casting the entire pipeline will result in errors") + @unittest.skip( + "MagiVideoToVideoPipeline has to run in mixed precision. Casting the entire pipeline will result in errors" + ) def test_model_cpu_offload_forward_pass(self): pass - @unittest.skip("MagiVideoToVideoPipeline has to run in mixed precision. Save/Load the entire pipeline in FP16 will result in errors") + @unittest.skip( + "MagiVideoToVideoPipeline has to run in mixed precision. 
Save/Load the entire pipeline in FP16 will result in errors" + ) def test_save_load_float16(self): pass diff --git a/tests/single_file/test_model_magi_autoencoder_single_file.py b/tests/single_file/test_model_magi_autoencoder_single_file.py index 07e5701dbb79..b3a3f7e0c48c 100644 --- a/tests/single_file/test_model_magi_autoencoder_single_file.py +++ b/tests/single_file/test_model_magi_autoencoder_single_file.py @@ -27,9 +27,7 @@ class AutoencoderKLMagiSingleFileTests(unittest.TestCase): model_class = AutoencoderKLMagi - ckpt_path = ( - "https://huggingface.co/sand-ai/MAGI-1/blob/main/vae/diffusion_pytorch_model.safetensors" - ) + ckpt_path = "https://huggingface.co/sand-ai/MAGI-1/blob/main/vae/diffusion_pytorch_model.safetensors" repo_id = "sand-ai/MAGI-1" @slow From 03d50e2628f81fa691b6e5afeb937d69e8384438 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 14 Jun 2025 08:20:06 +0300 Subject: [PATCH 5/8] 2nd draft --- src/diffusers/pipelines/magi/__init__.py | 54 +- src/diffusers/pipelines/magi/pipeline_magi.py | 830 ++++++++-------- .../pipelines/magi/pipeline_magi_i2v.py | 934 +++++++++++------- .../pipelines/magi/pipeline_magi_v2v.py | 928 +++++++++-------- .../pipelines/magi/pipeline_output.py | 2 +- 5 files changed, 1511 insertions(+), 1237 deletions(-) diff --git a/src/diffusers/pipelines/magi/__init__.py b/src/diffusers/pipelines/magi/__init__.py index 2aa0f44f6f08..842593a9f24a 100644 --- a/src/diffusers/pipelines/magi/__init__.py +++ b/src/diffusers/pipelines/magi/__init__.py @@ -1,45 +1,44 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- from typing import TYPE_CHECKING -from ....utils import ( +from ...utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, + get_objects_from_module, is_torch_available, is_transformers_available, ) +_dummy_objects = {} _import_structure = {} -if is_torch_available() and is_transformers_available(): - _import_structure["pipeline_magi"] = ["MagiPipeline"] - _import_structure["pipeline_magi_i2v"] = ["MagiImageToVideoPipeline"] - _import_structure["pipeline_magi_v2v"] = ["MagiVideoToVideoPipeline"] - _import_structure["pipeline_output"] = ["MagiPipelineOutput"] +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_magi"] = ["Magi1Pipeline"] + _import_structure["pipeline_magi_i2v"] = ["Magi1ImageToVideoPipeline"] + _import_structure["pipeline_magi_v2v"] = ["Magi1VideoToVideoPipeline"] + _import_structure["pipeline_output"] = ["Magi1PipelineOutput"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: - if is_torch_available() and is_transformers_available(): - from .pipeline_magi import MagiPipeline - from .pipeline_magi_i2v import MagiImageToVideoPipeline - from .pipeline_magi_v2v import MagiVideoToVideoPipeline - from .pipeline_output import MagiPipelineOutput + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: - pass + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_magi import Magi1Pipeline + from .pipeline_magi_i2v import Magi1ImageToVideoPipeline + from .pipeline_magi_v2v import Magi1VideoToVideoPipeline + from .pipeline_output import Magi1PipelineOutput + else: import sys @@ -49,3 +48,6 @@ _import_structure, module_spec=__spec__, ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/magi/pipeline_magi.py b/src/diffusers/pipelines/magi/pipeline_magi.py index 928a8a14fb58..555fe6b779b0 100644 --- a/src/diffusers/pipelines/magi/pipeline_magi.py +++ b/src/diffusers/pipelines/magi/pipeline_magi.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright 2025 The SandAI Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,22 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import html +import re +from typing import Any, Callable, Dict, List, Optional, Union +import ftfy import torch from transformers import AutoTokenizer, UMT5EncoderModel -from ...models import AutoencoderKLMagi, MagiTransformer3DModel +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import Magi1LoraLoaderMixin +from ...models import AutoencoderKLMagi1, Magi1Transformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import ( - is_torch_xla_available, - logging, - randn_tensor, - replace_example_docstring, -) +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline -from .pipeline_output import MagiPipelineOutput +from .pipeline_output import Magi1PipelineOutput if is_torch_xla_available(): @@ -39,50 +40,77 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +if is_ftfy_available(): + import ftfy EXAMPLE_DOC_STRING = """ Examples: ```python >>> import torch - >>> from diffusers import MagiPipeline >>> from diffusers.utils import export_to_video + >>> from diffusers import AutoencoderKLMagi1, Magi1Pipeline + >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler - >>> # Text-to-video generation - >>> pipeline = MagiPipeline.from_pretrained("sand-ai/MAGI-1-4.5B", torch_dtype=torch.float16) - >>> pipeline = pipeline.to("cuda") - >>> prompt = ( - ... "A cat and a dog playing in a garden. The cat is chasing a butterfly while the dog is digging a hole." - ... ) - >>> output = pipeline( + >>> model_id = "SandAI/Magi1-T2V-14B-480P-Diffusers" + >>> vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = Magi1Pipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + >>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) + >>> pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + >>> output = pipe( ... prompt=prompt, - ... num_frames=24, + ... negative_prompt=negative_prompt, ... height=720, - ... width=720, - ... t_schedule_func="sd3", + ... width=1280, + ... num_frames=81, + ... guidance_scale=5.0, ... 
).frames[0] - >>> export_to_video(output, "magi_output.mp4", fps=8) + >>> export_to_video(output, "output.mp4", fps=16) ``` """ -class MagiPipeline(DiffusionPipeline): +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +class Magi1Pipeline(DiffusionPipeline, Magi1LoraLoaderMixin): r""" - Pipeline for text-to-video generation using MAGI-1. + Pipeline for text-to-video generation using Magi1. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Args: - tokenizer (`AutoTokenizer`): - Tokenizer for the text encoder. - text_encoder (`UMT5EncoderModel`): - Text encoder for conditioning. - transformer (`MagiTransformer3DModel`): - Conditional Transformer to denoise the latent video. - vae (`AutoencoderKLMagi`): + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`Magi1Transformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLMagi1`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. - scheduler (`FlowMatchEulerDiscreteScheduler`): - A scheduler to be used in combination with `transformer` to denoise the encoded video latents. """ model_cpu_offload_seq = "text_encoder->transformer->vae" @@ -92,8 +120,8 @@ def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - transformer: MagiTransformer3DModel, - vae: AutoencoderKLMagi, + transformer: Magi1Transformer3DModel, + vae: AutoencoderKLMagi1, scheduler: FlowMatchEulerDiscreteScheduler, ): super().__init__() @@ -106,188 +134,150 @@ def __init__( scheduler=scheduler, ) - self.vae_scale_factor_temporal = vae.temporal_downsample_factor - self.vae_scale_factor_spatial = vae.spatial_downsample_factor + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - def _encode_prompt( + def _get_t5_prompt_embeds( self, - prompt: Union[str, List[str]], - device: torch.device, - num_videos_per_prompt: int, - do_classifier_free_guidance: bool, - negative_prompt: Optional[Union[str, List[str]]] = None, - max_length: Optional[int] = None, - ) -> torch.Tensor: - """ - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`): The prompt or prompts to guide the video generation. - device: The device to place the encoded prompt on. - num_videos_per_prompt (`int`): The number of videos that should be generated per prompt. 
- do_classifier_free_guidance (`bool`): Whether to use classifier-free guidance or not. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the video generation. - max_length (`int`, *optional*): The maximum length of the prompt to be encoded. - - Returns: - `torch.Tensor`: A tensor containing the encoded text embeddings. - """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype - # Default to 77 if not specified - if max_length is None: - max_length = self.tokenizer.model_max_length + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) text_inputs = self.tokenizer( prompt, padding="max_length", - max_length=max_length, + max_length=max_sequence_length, truncation=True, + add_special_tokens=True, + return_attention_mask=True, return_tensors="pt", ) - text_input_ids = text_inputs.input_ids.to(device) - attention_mask = text_inputs.attention_mask.to(device) - - prompt_embeds = self.text_encoder(text_input_ids, attention_mask=attention_mask)[0] - - # Process special tokens if present (following MAGI-1's approach) - # In diffusers style, we don't need to explicitly handle special tokens as they're part of the tokenizer - # But we can ensure proper mask handling similar to MAGI-1 - # Shape of prompt_embeds: [batch_size, seq_len, hidden_size] - - # Duplicate text embeddings for each generation per prompt - prompt_embeds = prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) - attention_mask = attention_mask.repeat_interleave(num_videos_per_prompt, dim=0) - - # Get unconditional embeddings for classifier-free guidance - if do_classifier_free_guidance: - uncond_tokens = [""] * batch_size - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - uncond_input_ids = uncond_input.input_ids.to(device) - uncond_attention_mask = uncond_input.attention_mask.to(device) - negative_prompt_embeds = self.text_encoder(uncond_input_ids, attention_mask=uncond_attention_mask)[0] - - # Duplicate unconditional embeddings for each generation per prompt - negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) - uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_videos_per_prompt, dim=0) - - # Ensure null embeddings have proper attention mask handling (similar to MAGI-1's null_emb_masks) - # In MAGI-1, they set attention to first 50 tokens and zero for the rest - if uncond_attention_mask.shape[1] > 50: - uncond_attention_mask[:, :50] = 1 - uncond_attention_mask[:, 50:] = 0 + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) - # Concatenate unconditional and text embeddings for classifier-free guidance - prompt_embeds = 
torch.cat([negative_prompt_embeds, prompt_embeds]) - attention_mask = torch.cat([uncond_attention_mask, attention_mask]) + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) return prompt_embeds - def _prepare_latents( + def encode_prompt( self, - batch_size: int, - num_channels_latents: int, - num_frames: int, - height: int, - width: int, - dtype: torch.dtype, - device: torch.device, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Prepare latents for diffusion. + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. Args: - batch_size (`int`): The batch size. - num_channels_latents (`int`): The number of channels in the latent space. - num_frames (`int`): The number of frames to generate. - height (`int`): The height of the video. - width (`int`): The width of the video. - dtype (`torch.dtype`): The data type of the latents. - device (`torch.device`): The device to place the latents on. - generator (`torch.Generator`, *optional*): A generator to use for random number generation. - latents (`torch.Tensor`, *optional*): - Pre-generated latent vectors. If not provided, latents will be generated randomly. - - Returns: - `torch.Tensor`: The prepared latent vectors. + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype """ - shape = ( - batch_size, - num_channels_latents, - num_frames // self.vae_scale_factor_temporal, - height // self.vae_scale_factor_spatial, - width // self.vae_scale_factor_spatial, - ) + device = device or self._execution_device - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - latents = latents.to(device) + batch_size = prompt_embeds.shape[0] - # Scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) - return latents + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - def _get_chunk_indices(self, num_frames: int, chunk_size: int = 24) -> List[Tuple[int, int]]: - """ - Get the indices for processing video in chunks. + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) - Args: - num_frames (`int`): Total number of frames. - chunk_size (`int`, *optional*, defaults to 24): Size of each chunk. + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) - Returns: - `List[Tuple[int, int]]`: List of (start_idx, end_idx) tuples for each chunk. - """ - chunks = [] - for i in range(0, num_frames, chunk_size): - chunks.append((i, min(i + chunk_size, num_frames))) - return chunks + return prompt_embeds, negative_prompt_embeds def check_inputs( self, prompt, + negative_prompt, height, width, - callback_steps, - negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, ): - """ - Validate the inputs for the pipeline. - - Args: - prompt (`str` or `List[str]`): The prompt or prompts to guide generation. - height (`int`): The height in pixels of the generated video. - width (`int`): The width in pixels of the generated video. - callback_steps (`int`): The frequency at which the callback function is called. - negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide generation. - prompt_embeds (`torch.Tensor`, *optional*): Pre-computed text embeddings. - negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-computed negative text embeddings. 
- """ - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: @@ -295,199 +285,197 @@ def check_inputs( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - f"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - def _prepare_timesteps( + def prepare_latents( self, - num_inference_steps: int, - device: torch.device, - t_schedule_func: str = "sd3", - shift: float = 3.0, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """ - Prepare timesteps for diffusion process, with scheduling options similar to MAGI-1. - - Args: - num_inference_steps (`int`): Number of diffusion steps. - device (`torch.device`): Device to place timesteps on. - t_schedule_func (`str`, optional, defaults to "sd3"): - Timestep scheduling function. Options: "sd3", "square", "piecewise", "linear". - shift (`float`, optional, defaults to 3.0): Shift parameter for sd3 scheduler. 
- - Returns: - `torch.Tensor`: Prepared timesteps. - """ - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # Apply custom scheduling similar to MAGI-1 if needed - if t_schedule_func == "sd3": - # Apply quadratic transformation - t = torch.linspace(0, 1, num_inference_steps + 1, device=device) - t = t**2 + if latents is not None: + return latents.to(device=device, dtype=dtype) - # Apply SD3-style transformation - def t_resolution_transform(x, shift_value=shift): - assert shift_value >= 1.0, "shift should >=1" - shift_inv = 1.0 / shift_value - return shift_inv * x / (1 + (shift_inv - 1) * x) + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) - t = t_resolution_transform(t, shift) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents - # Map to scheduler timesteps - # Note: This is a simplified approach - in a full implementation, - # we would need to properly map these values to the scheduler's timesteps - return self.scheduler.timesteps + @property + def guidance_scale(self): + return self._guidance_scale - elif t_schedule_func == "square": - # Simple quadratic scheduling - t = torch.linspace(0, 1, num_inference_steps + 1, device=device) - t = t**2 - return self.scheduler.timesteps + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 - elif t_schedule_func == "piecewise": - # Piecewise scheduling as in MAGI-1 - t = torch.linspace(0, 1, num_inference_steps + 1, device=device) + @property + def num_timesteps(self): + return self._num_timesteps - # Apply piecewise transformation - mask = t < 0.875 - t_transformed = torch.zeros_like(t) - t_transformed[mask] = t[mask] * (0.5 / 0.875) - t_transformed[~mask] = 0.5 + (t[~mask] - 0.875) * (0.5 / (1 - 0.875)) + @property + def current_timestep(self): + return self._current_timestep - # Map to scheduler timesteps (simplified) - return self.scheduler.timesteps + @property + def interrupt(self): + return self._interrupt - # Default: use scheduler's default timesteps - return timesteps + @property + def attention_kwargs(self): + return self._attention_kwargs @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], - height: Optional[int] = 720, - width: Optional[int] = 720, - num_frames: Optional[int] = 24, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, + guidance_scale: float = 5.0, num_videos_per_prompt: Optional[int] = 1, - eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, - callback: Optional[Callable[[int, int, 
torch.Tensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - chunk_size: int = 24, - t_schedule_func: str = "sd3", - t_schedule_shift: float = 3.0, - ) -> Union[MagiPipelineOutput, Tuple]: - """ - Function invoked when calling the pipeline for generation. + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the video generation. - height (`int`, *optional*, defaults to 720): - The height in pixels of the generated video. - width (`int`, *optional*, defaults to 720): - The width in pixels of the generated video. - num_frames (`int`, *optional*, defaults to 24): - The number of video frames to generate. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality video at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, - usually at the expense of lower video quality. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the video generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). + The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of videos to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. 
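For readers comparing against the original MAGI-1 timestep schedule, the `sd3` option removed above applies a shift to a squared uniform grid; a standalone reconstruction of that transform, with the removed default `shift=3.0`:

```python
import torch

def t_resolution_transform(x: torch.Tensor, shift_value: float = 3.0) -> torch.Tensor:
    # Pushes a grid in [0, 1] towards 0; same form as the removed helper.
    assert shift_value >= 1.0, "shift should be >= 1"
    shift_inv = 1.0 / shift_value
    return shift_inv * x / (1 + (shift_inv - 1) * x)

t = torch.linspace(0, 1, 11) ** 2      # the removed code first squares a uniform grid
print(t_resolution_transform(t))       # shifted grid, still spanning [0, 1]
```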
+ The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will be generated by sampling using the supplied random `generator`. + tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. output_type (`str`, *optional*, defaults to `"np"`): - The output format of the generate video. Choose between `np` for `numpy.array`, `pt` for `torch.Tensor` - or `latent` to get the latent space output. + The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.magi.MagiPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, *optional*): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in - [`diffusers.cross_attention`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - chunk_size (`int`, *optional*, defaults to 24): - The chunk size to use for autoregressive generation. Measured in frames. - t_schedule_func (`str`, *optional*, defaults to "sd3"): - Timestep scheduling function. Options: "sd3", "square", "piecewise", "linear". - t_schedule_shift (`float`, *optional*, defaults to 3.0): - Shift parameter for sd3 scheduler. + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. 
with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. Examples: Returns: - [`~pipelines.magi.MagiPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.magi.MagiPipelineOutput`] is returned, otherwise a `tuple` is - returned where the first element is a list with the generated frames. + [`~Magi1PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`Magi1PipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ - # 0. Default height and width to unet - height = height or self.transformer.config.sample_size - width = width or self.transformer.config.sample_size + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 1. Check inputs. Raise error if not correct self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, ) + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -496,140 +484,114 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - # 3. 
Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_videos_per_prompt, - do_classifier_free_guidance, - negative_prompt, + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, ) + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + # 4. Prepare timesteps - timesteps = self._prepare_timesteps( - num_inference_steps, device, t_schedule_func=t_schedule_func, shift=t_schedule_shift - ) + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.transformer.in_channels - - # Regular text-to-video case - latents = self._prepare_latents( + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, - num_frames, height, width, - prompt_embeds.dtype, + num_frames, + torch.float32, device, generator, latents, ) - # 6. Process in chunks for autoregressive generation - chunk_indices = self._get_chunk_indices( - num_frames // self.vae_scale_factor_temporal, chunk_size // self.vae_scale_factor_temporal - ) - all_latents = [] - - # 7. Denoise the latents - for chunk_idx, (start_idx, end_idx) in enumerate(chunk_indices): - # Extract the current chunk - chunk_frames = end_idx - start_idx - if chunk_idx == 0: - # For the first chunk, use the initial latents - chunk_latents = latents[:, :, start_idx:end_idx, :, :] - else: - # For subsequent chunks, implement proper autoregressive conditioning - # In MAGI-1, each chunk conditions the next in an autoregressive manner - # We use the previous chunk's output as conditioning for the current chunk - # Calculate overlap for conditioning - overlap_frames = min(self.vae_scale_factor_temporal, start_idx) - - # Initialize with noise - chunk_latents = randn_tensor( - ( - batch_size * num_videos_per_prompt, - num_channels_latents, - chunk_frames, - height // self.vae_scale_factor_spatial, - width // self.vae_scale_factor_spatial, - ), - generator=generator, - device=device, - dtype=prompt_embeds.dtype, - ) - chunk_latents = chunk_latents * self.scheduler.init_noise_sigma - - # Use previous chunk output as conditioning by copying overlapping frames - if start_idx > 0 and chunk_idx > 0 and overlap_frames > 0: - # Copy overlapping frames from previous chunk's output - chunk_latents[:, :, :overlap_frames, :, :] = all_latents[chunk_idx - 1][ - :, :, -overlap_frames:, :, : - ] - - # Denoising loop for this chunk - with self.progress_bar(total=len(timesteps)) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([chunk_latents] * 2) if do_classifier_free_guidance else chunk_latents - ) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.transformer( - latent_model_input, - timesteps=torch.tensor([t], device=chunk_latents.device), - encoder_hidden_states=prompt_embeds, - 
cross_attention_kwargs=cross_attention_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - chunk_latents = self.scheduler.step(noise_pred, t, chunk_latents).prev_sample - - # call the callback, if provided - if callback is not None and chunk_idx == 0 and i % callback_steps == 0: - callback(i, t, chunk_latents) - + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if XLA_AVAILABLE: - xm.mark_step() - - all_latents.append(chunk_latents) + if XLA_AVAILABLE: + xm.mark_step() - # 8. Concatenate all chunks - latents = torch.cat(all_latents, dim=2) + self._current_timestep = None - # 9. Post-processing - if output_type == "latent": - video = latents + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) else: - # Decode the latents - latents = 1 / self.vae.scaling_factor * latents - video = self.vae.decode(latents).sample - video = (video / 2 + 0.5).clamp(0, 1) - - # Convert to the desired output format - if output_type == "pt": - video = video - else: - video = video.cpu().permute(0, 2, 3, 4, 1).float().numpy() - - # 10. 
Return output + video = latents + + # Offload all models + self.maybe_free_model_hooks() + if not return_dict: return (video,) - return MagiPipelineOutput(frames=video) + return Magi1PipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/magi/pipeline_magi_i2v.py b/src/diffusers/pipelines/magi/pipeline_magi_i2v.py index 79e70e147191..667e9467fda4 100644 --- a/src/diffusers/pipelines/magi/pipeline_magi_i2v.py +++ b/src/diffusers/pipelines/magi/pipeline_magi_i2v.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright 2025 The SandAI Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,181 +12,338 @@ # See the License for the specific language governing permissions and # limitations under the License. +import html from typing import Any, Callable, Dict, List, Optional, Tuple, Union import PIL +import regex as re import torch -from transformers import AutoTokenizer, UMT5EncoderModel +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel -from ...models import AutoencoderKLMagi, MagiTransformer3DModel +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import Magi1LoraLoaderMixin +from ...models import AutoencoderKLMagi1, Magi1Transformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import ( - logging, - randn_tensor, - replace_example_docstring, -) +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline -from .pipeline_output import MagiPipelineOutput +from .pipeline_output import Magi1PipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name +if is_ftfy_available(): + import ftfy EXAMPLE_DOC_STRING = """ Examples: ```python >>> import torch - >>> from diffusers import MagiImageToVideoPipeline + >>> import numpy as np + >>> from diffusers import AutoencoderKLMagi1, Magi1ImageToVideoPipeline >>> from diffusers.utils import export_to_video, load_image - - >>> pipeline = MagiImageToVideoPipeline.from_pretrained("sand-ai/MAGI-1-4.5B", torch_dtype=torch.float16) - >>> pipeline = pipeline.to("cuda") - - >>> image = load_image("path/to/image.jpg") - >>> prompt = "A cat playing in a garden. The cat is chasing a butterfly." - >>> output = pipeline(prompt=prompt, image=image, num_frames=24, height=720, width=720).frames[0] - >>> export_to_video(output, "magi_i2v_output.mp4", fps=8) + >>> from transformers import CLIPVisionModel + + >>> model_id = "SandAI/Magi1-I2V-14B-480P-Diffusers" + >>> image_encoder = CLIPVisionModel.from_pretrained( + ... model_id, subfolder="image_encoder", torch_dtype=torch.float32 + ... ) + >>> vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = Magi1ImageToVideoPipeline.from_pretrained( + ... model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... 
) + >>> max_area = 480 * 832 + >>> aspect_ratio = image.height / image.width + >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + >>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + >>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + >>> image = image.resize((width, height)) + >>> prompt = ( + ... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in " + ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + ... ) + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=height, + ... width=width, + ... num_frames=81, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=16) ``` """ -class MagiImageToVideoPipeline(DiffusionPipeline): +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Magi1ImageToVideoPipeline(DiffusionPipeline, Magi1LoraLoaderMixin): r""" - Pipeline for image-to-video generation using MAGI-1. + Pipeline for image-to-video generation using Magi1. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Args: - tokenizer (`AutoTokenizer`): - Tokenizer for the text encoder. - text_encoder (`UMT5EncoderModel`): - Text encoder for conditioning. - transformer (`MagiTransformer3DModel`): - Conditional Transformer to denoise the latent video. - vae (`AutoencoderKLMagi`): + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. 
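A quick numeric walk-through of the resizing logic in the example above, using a hypothetical 1024x768 input image and an assumed `mod_value` of 16 (an 8x spatial VAE factor times a patch size of 2; the real values come from the loaded VAE and transformer configs):

```python
import numpy as np

max_area = 480 * 832
height_in, width_in = 768, 1024              # hypothetical input image size
aspect_ratio = height_in / width_in          # 0.75
mod_value = 8 * 2                            # vae_scale_factor_spatial * patch_size (assumed values)

height = int(round(np.sqrt(max_area * aspect_ratio))) // mod_value * mod_value
width = int(round(np.sqrt(max_area / aspect_ratio))) // mod_value * mod_value
print(height, width)  # 544 720 -> both divisible by 16, with height * width <= max_area
```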
+ image_encoder ([`CLIPVisionModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically + the + [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) + variant. + transformer ([`Magi1Transformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLMagi1`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. - scheduler (`FlowMatchEulerDiscreteScheduler`): - A scheduler to be used in combination with `transformer` to denoise the encoded video latents. """ - model_cpu_offload_seq = "text_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - transformer: MagiTransformer3DModel, - vae: AutoencoderKLMagi, + image_encoder: CLIPVisionModel, + image_processor: CLIPImageProcessor, + transformer: Magi1Transformer3DModel, + vae: AutoencoderKLMagi1, scheduler: FlowMatchEulerDiscreteScheduler, ): super().__init__() self.register_modules( - tokenizer=tokenizer, + vae=vae, text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, transformer=transformer, - vae=vae, scheduler=scheduler, + image_processor=image_processor, ) - self.vae_scale_factor_temporal = 2 ** (1 if hasattr(self.vae, "temporal_downsample") else 0) - self.vae_scale_factor_spatial = 2 ** ( - 3 if hasattr(self.vae, "config") else 8 - ) # Default to 8 for 3 downsamples + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.image_processor = image_processor - def _encode_prompt( + def _get_t5_prompt_embeds( self, - prompt: Union[str, List[str]], - device: torch.device, - num_videos_per_prompt: int, - do_classifier_free_guidance: bool, - negative_prompt: Optional[Union[str, List[str]]] = None, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, max_sequence_length: int = 512, - ) -> torch.Tensor: - """ - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`): - prompt to be encoded - device (`torch.device`): - torch device - num_videos_per_prompt (`int`): - number of videos that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the video generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - max_sequence_length (`int`, *optional*, defaults to 512): - The maximum length of the sequence to be processed by the text encoder. 
+ device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype - Returns: - `torch.Tensor`: text embeddings - """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, + add_special_tokens=True, + return_attention_mask=True, return_tensors="pt", ) - text_input_ids = text_inputs.input_ids.to(device) - - prompt_embeds = self.text_encoder(text_input_ids).last_hidden_state - - # duplicate text embeddings for each generation per prompt - prompt_embeds = prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) - - # get unconditional embeddings for classifier-free guidance - if do_classifier_free_guidance: - uncond_tokens = [""] * batch_size - max_length = text_inputs.input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_image( + self, + image: PipelineImageInput, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. 
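`_get_t5_prompt_embeds` slices each sample down to its true token count (taken from the attention mask) and then zero-pads back to `max_sequence_length`, so padding positions carry exact zeros rather than encoder outputs for pad tokens. A self-contained sketch of that trim-and-pad step, with random tensors standing in for the encoder output:

```python
import torch

max_sequence_length = 512
# Pretend encoder output for a batch of 2 prompts, with 7 and 12 real tokens respectively.
hidden = torch.randn(2, max_sequence_length, 4096)
mask = torch.zeros(2, max_sequence_length, dtype=torch.long)
mask[0, :7] = 1
mask[1, :12] = 1

seq_lens = mask.gt(0).sum(dim=1).long()
trimmed = [u[:v] for u, v in zip(hidden, seq_lens)]
padded = torch.stack(
    [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in trimmed], dim=0
)

assert padded.shape == hidden.shape
assert padded[0, 7:].abs().sum() == 0  # padded positions are exactly zero
```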
torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, ) - uncond_input_ids = uncond_input.input_ids.to(device) - negative_prompt_embeds = self.text_encoder(uncond_input_ids).last_hidden_state - # duplicate unconditional embeddings for each generation per prompt - negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - # For classifier-free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) - return prompt_embeds + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds def check_inputs( self, prompt, + negative_prompt, + image, height, width, - callback_steps, - negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." 
+ ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: @@ -194,247 +351,259 @@ def check_inputs( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - f"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - def prepare_image_latents( + def prepare_latents( self, - image: Union[torch.Tensor, PIL.Image.Image], + image: PipelineImageInput, batch_size: int, - num_videos_per_prompt: int, - do_classifier_free_guidance: bool, - device: torch.device, - ) -> torch.Tensor: - """ - Encode an input image to latent space. - - Args: - image (`torch.Tensor` or `PIL.Image.Image`): - Input image to be encoded. - batch_size (`int`): - Batch size. - num_videos_per_prompt (`int`): - Number of videos per prompt. - do_classifier_free_guidance (`bool`): - Whether to use classifier-free guidance. - device (`torch.device`): - Device to place the latents on. - - Returns: - `torch.Tensor`: Encoded image latents. 
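Both pipelines normalize VAE latents with per-channel statistics from the VAE config: the conditioning path below computes `(x - mean) * latents_std`, and the decode path earlier in the patch inverts it with `z / latents_std + mean`, where the local `latents_std` holds the reciprocal of the configured std. A minimal round-trip sketch with made-up statistics:

```python
import torch

# Hypothetical per-channel statistics; the real values live in vae.config.latents_mean / latents_std.
z_dim = 16
latents_mean = torch.randn(1, z_dim, 1, 1, 1) * 0.1
latents_std_cfg = torch.rand(1, z_dim, 1, 1, 1) + 0.5

x = torch.randn(1, z_dim, 21, 60, 104)   # raw VAE-encoded latents

# Encode-side normalization (as in prepare_latents): note the reciprocal.
latents_std = 1.0 / latents_std_cfg
z = (x - latents_mean) * latents_std

# Decode-side de-normalization (as in the text-to-video __call__):
x_recovered = z / latents_std + latents_mean
assert torch.allclose(x, x_recovered, atol=1e-5)
```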
- """ - # Convert PIL image to tensor - if isinstance(image, PIL.Image.Image): - image = self.video_processor.preprocess_image(image) - image = image.to(device=device) + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + last_image: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) - # Encode image - image_latents = self.vae.encode(image).latent_dist.sample() - image_latents = image_latents * self.vae.scaling_factor + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) - # Expand for batch size and classifier-free guidance - image_latents = image_latents.repeat(batch_size * num_videos_per_prompt, 1, 1, 1) - if do_classifier_free_guidance: - image_latents = torch.cat([image_latents, image_latents], dim=0) + image = image.unsqueeze(2) + if last_image is None: + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) + else: + last_image = last_image.unsqueeze(2) + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], + dim=2, + ) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) - return image_latents + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) - def _prepare_image_based_latents( - self, - batch_size: int, - num_channels_latents: int, - num_frames: int, - height: int, - width: int, - dtype: torch.dtype, - device: torch.device, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.Tensor] = None, - image_latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Prepare latents for diffusion with image conditioning. + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) - Args: - batch_size (`int`): The batch size. - num_channels_latents (`int`): The number of channels in the latent space. - num_frames (`int`): The number of frames to generate. - height (`int`): The height of the video. - width (`int`): The width of the video. - dtype (`torch.dtype`): The data type of the latents. 
- device (`torch.device`): The device to place the latents on. - generator (`torch.Generator`, *optional*): A generator to use for random number generation. - latents (`torch.Tensor`, *optional*): - Pre-generated latent vectors. If not provided, latents will be generated randomly. - image_latents (`torch.Tensor`, *optional*): Image latents for conditioning. + latent_condition = latent_condition.to(dtype) + latent_condition = (latent_condition - latents_mean) * latents_std - Returns: - `torch.Tensor`: The prepared latent vectors. - """ - shape = ( - batch_size, - num_channels_latents, - num_frames // self.vae_scale_factor_temporal, - height // self.vae_scale_factor_spatial, - width // self.vae_scale_factor_spatial, - ) + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if last_image is None: + mask_lat_size[:, :, list(range(1, num_frames))] = 0 else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - latents = latents.to(device) + mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) - # Scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma + return latents, torch.concat([mask_lat_size, latent_condition], dim=1) - # If we have image latents, use them to condition the first frame - if image_latents is not None: - # Expand image latents to match the temporal dimension of the first frame - image_latents = image_latents.unsqueeze(2) # [B, C, 1, H, W] + @property + def guidance_scale(self): + return self._guidance_scale - # Only replace the first frame with the image latents - latents[:, :, 0, :, :] = image_latents.squeeze(2) + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 - return latents + @property + def num_timesteps(self): + return self._num_timesteps - def _get_chunk_indices(self, num_frames: int, chunk_size: int) -> List[Tuple[int, int]]: - """ - Get the start and end indices for each chunk. + @property + def current_timestep(self): + return self._current_timestep - Args: - num_frames (`int`): Total number of frames. - chunk_size (`int`): Size of each chunk. + @property + def interrupt(self): + return self._interrupt - Returns: - `List[Tuple[int, int]]`: List of (start_idx, end_idx) tuples for each chunk. 
- """ - return [(i, min(i + chunk_size, num_frames)) for i in range(0, num_frames, chunk_size)] + @property + def attention_kwargs(self): + return self._attention_kwargs @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], - image: Union[torch.Tensor, PIL.Image.Image], - height: Optional[int] = 720, - width: Optional[int] = 720, - num_frames: Optional[int] = 24, + image: PipelineImageInput, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, + guidance_scale: float = 5.0, num_videos_per_prompt: Optional[int] = 1, - eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + last_image: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - chunk_size: int = 24, - ) -> Union[MagiPipelineOutput, Tuple]: - """ - Function invoked when calling the pipeline for generation. + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the video generation. - image (`torch.Tensor` or `PIL.Image.Image`): - The input image to guide the video generation. - height (`int`, *optional*, defaults to 720): - The height in pixels of the generated video. - width (`int`, *optional*, defaults to 720): - The width in pixels of the generated video. - num_frames (`int`, *optional*, defaults to 24): - The number of video frames to generate. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality video at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, - usually at the expense of lower video quality. + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the video generation. If not defined, one has to pass + The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + height (`int`, defaults to `480`): + The height of the generated video. + width (`int`, defaults to `832`): + The width of the generated video. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of videos to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. + The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will be generated by sampling using the supplied random `generator`. + tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. output_type (`str`, *optional*, defaults to `"np"`): - The output format of the generate video. Choose between `np` for `numpy.array`, `pt` for `torch.Tensor` - or `latent` to get the latent space output. + The output format of the generated image. Choose between `PIL.Image` or `np.array`. 
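As in the text-to-video `__call__` above, guidance is applied with two separate transformer passes (conditional and unconditional) rather than one batched pass, then combined with the usual extrapolation; a schematic sketch in which `transformer_step` is only a stand-in for the real forward call:

```python
import torch

def transformer_step(latents: torch.Tensor, embeds: torch.Tensor) -> torch.Tensor:
    # Placeholder for self.transformer(hidden_states=..., encoder_hidden_states=..., ...)[0]
    return latents * 0.1 + embeds.mean() * 0.01

guidance_scale = 5.0
latents = torch.randn(1, 16, 21, 60, 104)
prompt_embeds = torch.randn(1, 512, 4096)
negative_prompt_embeds = torch.zeros_like(prompt_embeds)

noise_pred = transformer_step(latents, prompt_embeds)
if guidance_scale > 1.0:
    noise_uncond = transformer_step(latents, negative_prompt_embeds)
    # Push the conditional prediction away from the unconditional one.
    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
```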
return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.magi.MagiPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, *optional*): + Whether or not to return a [`Magi1PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in - [`diffusers.cross_attention`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - chunk_size (`int`, *optional*, defaults to 24): - The chunk size to use for autoregressive generation. Measured in frames. + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during inference, with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. Examples: Returns: - [`~pipelines.magi.MagiPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.magi.MagiPipelineOutput`] is returned, otherwise a `tuple` is - returned where the first element is a list with the generated frames. + [`~Magi1PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`Magi1PipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated video frames. """ - # 0. Default height and width to unet - height = height or self.transformer.config.sample_size - width = width or self.transformer.config.sample_size + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 1. Check inputs.
Raise error if not correct self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, ) + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -443,133 +612,132 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_videos_per_prompt, - do_classifier_free_guidance, - negative_prompt, + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, ) + # Encode image embedding + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + if image_embeds is None: + if last_image is None: + image_embeds = self.encode_image(image, device) + else: + image_embeds = self.encode_image([image, last_image], device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - # 5. Prepare image latents - image_latents = self.prepare_image_latents( + # 5. Prepare latent variables + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 + ) + latents, condition = self.prepare_latents( image, - batch_size, - num_videos_per_prompt, - do_classifier_free_guidance, - device, - ) - - # 6. Prepare latent variables with image conditioning - num_channels_latents = self.transformer.in_channels - latents = self._prepare_image_based_latents( batch_size * num_videos_per_prompt, num_channels_latents, - num_frames, height, width, - prompt_embeds.dtype, + num_frames, + torch.float32, device, generator, latents, - image_latents, + last_image, ) - # 7. 
Process in chunks for autoregressive generation - chunk_indices = self._get_chunk_indices( - num_frames // self.vae_scale_factor_temporal, chunk_size // self.vae_scale_factor_temporal - ) - all_latents = [] - - # 8. Process each chunk - for chunk_idx, (start_idx, end_idx) in enumerate(chunk_indices): - # Extract the current chunk - chunk_frames = end_idx - start_idx - if chunk_idx == 0: - # For the first chunk, use the initial latents - chunk_latents = latents[:, :, start_idx:end_idx, :, :] - else: - # For subsequent chunks, use the previous chunk as conditioning - # This is a simplified version - in a real implementation, we would need to handle - # the autoregressive conditioning properly - chunk_latents = randn_tensor( - ( - batch_size * num_videos_per_prompt, - num_channels_latents, - chunk_frames, - height // self.vae_scale_factor_spatial, - width // self.vae_scale_factor_spatial, - ), - generator=generator, - device=device, - dtype=prompt_embeds.dtype, - ) - chunk_latents = chunk_latents * self.scheduler.init_noise_sigma - - # 9. Denoising loop for this chunk - with self.progress_bar(total=len(timesteps)) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([chunk_latents] * 2) if do_classifier_free_guidance else chunk_latents - ) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.transformer( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - chunk_latents = self.scheduler.step(noise_pred, t, chunk_latents).prev_sample - - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, chunk_latents) - + # 6. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - all_latents.append(chunk_latents) + if XLA_AVAILABLE: + xm.mark_step() - # 10. Concatenate all chunks - latents = torch.cat(all_latents, dim=2) + self._current_timestep = None - # 11. Post-processing - if output_type == "latent": - video = latents + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) else: - # Decode the latents - latents = 1 / self.vae.scaling_factor * latents - video = self.vae.decode(latents).sample - video = (video / 2 + 0.5).clamp(0, 1) - - # Convert to the desired output format - if output_type == "pt": - video = video - else: - video = video.cpu().permute(0, 2, 3, 4, 1).float().numpy() + video = latents + + # Offload all models + self.maybe_free_model_hooks() - # 12. Return output if not return_dict: return (video,) - return MagiPipelineOutput(frames=video) + return Magi1PipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/magi/pipeline_magi_v2v.py b/src/diffusers/pipelines/magi/pipeline_magi_v2v.py index 681aa1557b75..65efb1757a83 100644 --- a/src/diffusers/pipelines/magi/pipeline_magi_v2v.py +++ b/src/diffusers/pipelines/magi/pipeline_magi_v2v.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright 2025 The SandAI Team and The HuggingFace Team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,62 +12,185 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import html +import inspect +from typing import Any, Callable, Dict, List, Optional, Union +import regex as re import torch +from PIL import Image from transformers import AutoTokenizer, UMT5EncoderModel -from ...models import AutoencoderKLMagi, MagiTransformer3DModel +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import Magi1LoraLoaderMixin +from ...models import AutoencoderKLMagi1, Magi1Transformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import ( - logging, - randn_tensor, - replace_example_docstring, -) +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline -from .pipeline_output import MagiPipelineOutput +from .pipeline_output import Magi1PipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name +if is_ftfy_available(): + import ftfy + EXAMPLE_DOC_STRING = """ Examples: ```python >>> import torch - >>> from diffusers import MagiVideoToVideoPipeline - >>> from diffusers.utils import export_to_video, load_video - - >>> pipeline = MagiVideoToVideoPipeline.from_pretrained("sand-ai/MAGI-1-4.5B", torch_dtype=torch.float16) - >>> pipeline = pipeline.to("cuda") - - >>> input_video = load_video("path/to/video.mp4") - >>> prompt = "A cat playing in a garden. The cat is chasing a butterfly." - >>> output = pipeline(prompt=prompt, video=input_video, num_frames=24, height=720, width=720).frames[0] - >>> export_to_video(output, "magi_v2v_output.mp4", fps=8) + >>> from diffusers.utils import export_to_video, load_video + >>> from diffusers import AutoencoderKLMagi1, Magi1VideoToVideoPipeline + >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler + + >>> # Available models: SandAI/Magi1-T2V-14B-480P-Diffusers, SandAI/Magi1-T2V-1.3B-480P-Diffusers + >>> model_id = "SandAI/Magi1-T2V-1.3B-480P-Diffusers" + >>> vae = AutoencoderKLMagi1.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = Magi1VideoToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + >>> flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) + >>> pipe.to("cuda") + + >>> prompt = "A robot standing on a mountain top. The sun is setting in the background" + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + >>> video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" + ... ) + >>> output = pipe( + ...
video=video, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=480, + ... width=720, + ... guidance_scale=5.0, + ... strength=0.7, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=16) ``` """ -class MagiVideoToVideoPipeline(DiffusionPipeline): +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): r""" - Pipeline for video-to-video generation using MAGI-1. + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Magi1VideoToVideoPipeline(DiffusionPipeline, Magi1LoraLoaderMixin): + r""" + Pipeline for video-to-video generation using Magi1. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Args: - tokenizer (`AutoTokenizer`): - Tokenizer for the text encoder. - text_encoder (`UMT5EncoderModel`): - Text encoder for conditioning. - transformer (`MagiTransformer3DModel`): - Conditional Transformer to denoise the latent video. - vae (`AutoencoderKLMagi`): + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`Magi1Transformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLMagi1`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. - scheduler (`FlowMatchEulerDiscreteScheduler`): - A scheduler to be used in combination with `transformer` to denoise the encoded video latents. 
""" model_cpu_offload_seq = "text_encoder->transformer->vae" @@ -77,116 +200,168 @@ def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - transformer: MagiTransformer3DModel, - vae: AutoencoderKLMagi, + transformer: Magi1Transformer3DModel, + vae: AutoencoderKLMagi1, scheduler: FlowMatchEulerDiscreteScheduler, ): super().__init__() self.register_modules( - tokenizer=tokenizer, + vae=vae, text_encoder=text_encoder, + tokenizer=tokenizer, transformer=transformer, - vae=vae, scheduler=scheduler, ) - self.vae_scale_factor_temporal = 2 ** (1 if hasattr(self.vae, "temporal_downsample") else 0) - self.vae_scale_factor_spatial = 2 ** ( - 3 if hasattr(self.vae, "config") else 8 - ) # Default to 8 for 3 downsamples + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - def _encode_prompt( + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( self, - prompt: Union[str, List[str]], - device: torch.device, - num_videos_per_prompt: int, - do_classifier_free_guidance: bool, - negative_prompt: Optional[Union[str, List[str]]] = None, - max_sequence_length: int = 512, - ) -> torch.Tensor: - """ - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`): - prompt to be encoded - device (`torch.device`): - torch device - num_videos_per_prompt (`int`): - number of videos that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the video generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - max_sequence_length (`int`, *optional*, defaults to 512): - The maximum length of the sequence to be processed by the text encoder. 
+ prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype - Returns: - `torch.Tensor`: text embeddings - """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, + add_special_tokens=True, + return_attention_mask=True, return_tensors="pt", ) - text_input_ids = text_inputs.input_ids.to(device) - - prompt_embeds = self.text_encoder(text_input_ids).last_hidden_state - - # duplicate text embeddings for each generation per prompt - prompt_embeds = prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) - - # get unconditional embeddings for classifier-free guidance - if do_classifier_free_guidance: - uncond_tokens = [""] * batch_size - max_length = text_inputs.input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. 
+ negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, ) - uncond_input_ids = uncond_input.input_ids.to(device) - negative_prompt_embeds = self.text_encoder(uncond_input_ids).last_hidden_state - # duplicate unconditional embeddings for each generation per prompt - negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - # For classifier-free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) - return prompt_embeds + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds def check_inputs( self, prompt, + negative_prompt, height, width, - callback_steps, - negative_prompt=None, + video=None, + latents=None, prompt_embeds=None, negative_prompt_embeds=None, - video=None, + callback_on_step_end_tensor_inputs=None, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." 
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: @@ -194,254 +369,233 @@ def check_inputs( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - f"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - if video is None: - raise ValueError("`video` input cannot be undefined.") - - def prepare_video_latents( - self, - video: torch.Tensor, - batch_size: int, - num_videos_per_prompt: int, - do_classifier_free_guidance: bool, - device: torch.device, - ) -> torch.Tensor: - """ - Encode an input video to latent space. - - Args: - video (`torch.Tensor`): - Input video to be encoded. - batch_size (`int`): - Batch size. - num_videos_per_prompt (`int`): - Number of videos per prompt. - do_classifier_free_guidance (`bool`): - Whether to use classifier-free guidance. - device (`torch.device`): - Device to place the latents on. - - Returns: - `torch.Tensor`: Encoded video latents. 
- """ - # Ensure video is on the correct device - video = video.to(device=device) - - # Encode video - video_latents = self.vae.encode(video).latent_dist.sample() - video_latents = video_latents * self.vae.scaling_factor - - # Expand for batch size and classifier-free guidance - video_latents = video_latents.repeat(batch_size * num_videos_per_prompt, 1, 1, 1, 1) - if do_classifier_free_guidance: - video_latents = torch.cat([video_latents, video_latents], dim=0) - - return video_latents + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` should be provided") - def _prepare_video_based_latents( + def prepare_latents( self, - batch_size: int, - num_channels_latents: int, - num_frames: int, - height: int, - width: int, - dtype: torch.dtype, - device: torch.device, + video: Optional[torch.Tensor] = None, + batch_size: int = 1, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, - video_latents: Optional[torch.Tensor] = None, - num_frames_to_condition: Optional[int] = None, - ) -> torch.Tensor: - """ - Prepare latents for diffusion with video conditioning. - - Args: - batch_size (`int`): The batch size. - num_channels_latents (`int`): The number of channels in the latent space. - num_frames (`int`): The number of frames to generate. - height (`int`): The height of the video. - width (`int`): The width of the video. - dtype (`torch.dtype`): The data type of the latents. - device (`torch.device`): The device to place the latents on. - generator (`torch.Generator`, *optional*): A generator to use for random number generation. - latents (`torch.Tensor`, *optional*): - Pre-generated latent vectors. If not provided, latents will be generated randomly. - video_latents (`torch.Tensor`, *optional*): Video latents for conditioning. - num_frames_to_condition (`int`, *optional*): Number of frames to use for conditioning. + timestep: Optional[torch.Tensor] = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) - Returns: - `torch.Tensor`: The prepared latent vectors. 
- """ + num_latent_frames = ( + (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1) + ) shape = ( batch_size, num_channels_latents, - num_frames // self.vae_scale_factor_temporal, + num_latent_frames, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial, ) if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, dtype + ) + + init_latents = (init_latents - latents_mean) * latents_std + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if hasattr(self.scheduler, "add_noise"): + latents = self.scheduler.add_noise(init_latents, noise, timestep) + else: + latents = self.scheduler.scale_noise(init_latents, timestep, noise) else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) - # Scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma + return latents - # If we have video latents, use them to condition the first frames - if video_latents is not None: - if num_frames_to_condition is None: - num_frames_to_condition = video_latents.shape[2] + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - # Only replace the first N frames with the video latents - latents[:, :, :num_frames_to_condition, :, :] = video_latents[:, :, :num_frames_to_condition, :, :] + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] - return latents + return timesteps, num_inference_steps - t_start - def _get_chunk_indices(self, num_frames: int, chunk_size: int) -> List[Tuple[int, int]]: - """ - Get the start and end indices for each chunk. + @property + def guidance_scale(self): + return self._guidance_scale - Args: - num_frames (`int`): Total number of frames. - chunk_size (`int`): Size of each chunk. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 - Returns: - `List[Tuple[int, int]]`: List of (start_idx, end_idx) tuples for each chunk. 
- """ - return [(i, min(i + chunk_size, num_frames)) for i in range(0, num_frames, chunk_size)] + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], - video: torch.Tensor, - height: Optional[int] = 720, - width: Optional[int] = 720, - num_frames: Optional[int] = 24, + video: List[Image.Image] = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 5.0, + strength: float = 0.8, num_videos_per_prompt: Optional[int] = 1, - eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - chunk_size: int = 24, - num_frames_to_condition: Optional[int] = None, - ) -> Union[MagiPipelineOutput, Tuple]: - """ - Function invoked when calling the pipeline for generation. + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the video generation. - video (`torch.Tensor`): - The input video to guide the video generation. Should be a tensor of shape (B, C, F, H, W). - height (`int`, *optional*, defaults to 720): - The height in pixels of the generated video. - width (`int`, *optional*, defaults to 720): - The width in pixels of the generated video. - num_frames (`int`, *optional*, defaults to 24): - The number of video frames to generate. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality video at the + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). 
Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, - usually at the expense of lower video quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the video generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + strength (`float`, defaults to `0.8`): + Higher strength leads to more differences between original image and generated video. num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of videos to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. + The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will be generated by sampling using the supplied random `generator`. + tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. output_type (`str`, *optional*, defaults to `"np"`): - The output format of the generate video. Choose between `np` for `numpy.array`, `pt` for `torch.Tensor` - or `latent` to get the latent space output. + The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.magi.MagiPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. 
The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, *optional*): + Whether or not to return a [`Magi1PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in - [`diffusers.cross_attention`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - chunk_size (`int`, *optional*, defaults to 24): - The chunk size to use for autoregressive generation. Measured in frames. - num_frames_to_condition (`int`, *optional*): - Number of frames from the input video to use for conditioning. If not provided, all frames from the - input video will be used. + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during inference, with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. Examples: Returns: - [`~pipelines.magi.MagiPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.magi.MagiPipelineOutput`] is returned, otherwise a `tuple` is - returned where the first element is a list with the generated frames. + [`~Magi1PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`Magi1PipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated video frames. """ - # 0. Default height and width to unet - height = height or self.transformer.config.sample_size - width = width or self.transformer.config.sample_size + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 # 1. Check inputs.
Raise error if not correct self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, video + prompt, + negative_prompt, + height, + width, + video, + latents, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -450,134 +604,122 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_videos_per_prompt, - do_classifier_free_guidance, - negative_prompt, + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, ) + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + self._num_timesteps = len(timesteps) - # 5. Prepare video latents - video_latents = self.prepare_video_latents( - video, - batch_size, - num_videos_per_prompt, - do_classifier_free_guidance, - device, - ) + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width).to( + device, dtype=torch.float32 + ) - # 6. Prepare latent variables with video conditioning - num_channels_latents = self.transformer.in_channels - latents = self._prepare_video_based_latents( + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + video, batch_size * num_videos_per_prompt, num_channels_latents, - num_frames, height, width, - prompt_embeds.dtype, + torch.float32, device, generator, latents, - video_latents, - num_frames_to_condition, + latent_timestep, ) - # 7. Process in chunks for autoregressive generation - chunk_indices = self._get_chunk_indices( - num_frames // self.vae_scale_factor_temporal, chunk_size // self.vae_scale_factor_temporal - ) - all_latents = [] - - # 8. 
Process each chunk - for chunk_idx, (start_idx, end_idx) in enumerate(chunk_indices): - # Extract the current chunk - chunk_frames = end_idx - start_idx - if chunk_idx == 0: - # For the first chunk, use the initial latents - chunk_latents = latents[:, :, start_idx:end_idx, :, :] - else: - # For subsequent chunks, use the previous chunk as conditioning - # This is a simplified version - in a real implementation, we would need to handle - # the autoregressive conditioning properly - chunk_latents = randn_tensor( - ( - batch_size * num_videos_per_prompt, - num_channels_latents, - chunk_frames, - height // self.vae_scale_factor_spatial, - width // self.vae_scale_factor_spatial, - ), - generator=generator, - device=device, - dtype=prompt_embeds.dtype, - ) - chunk_latents = chunk_latents * self.scheduler.init_noise_sigma - - # 9. Denoising loop for this chunk - with self.progress_bar(total=len(timesteps)) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([chunk_latents] * 2) if do_classifier_free_guidance else chunk_latents - ) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.transformer( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - chunk_latents = self.scheduler.step(noise_pred, t, chunk_latents).prev_sample - - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, chunk_latents) - + # 6. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - all_latents.append(chunk_latents) + if XLA_AVAILABLE: + xm.mark_step() - # 10. Concatenate all chunks - latents = torch.cat(all_latents, dim=2) + self._current_timestep = None - # 11. Post-processing - if output_type == "latent": - video = latents + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) else: - # Decode the latents - latents = 1 / self.vae.scaling_factor * latents - video = self.vae.decode(latents).sample - video = (video / 2 + 0.5).clamp(0, 1) - - # Convert to the desired output format - if output_type == "pt": - video = video - else: - video = video.cpu().permute(0, 2, 3, 4, 1).float().numpy() + video = latents + + # Offload all models + self.maybe_free_model_hooks() - # 12. Return output if not return_dict: return (video,) - return MagiPipelineOutput(frames=video) + return Magi1PipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/magi/pipeline_output.py b/src/diffusers/pipelines/magi/pipeline_output.py index f14a2050df33..200156cffac9 100644 --- a/src/diffusers/pipelines/magi/pipeline_output.py +++ b/src/diffusers/pipelines/magi/pipeline_output.py @@ -22,7 +22,7 @@ @dataclass -class MagiPipelineOutput(BaseOutput): +class Magi1PipelineOutput(BaseOutput): """ Output class for MAGI-1 pipeline. 
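The video-to-video path above relies on three mechanisms that are easy to miss in the diff: `strength` truncates the scheduler's timestep schedule via `get_timesteps` (higher strength keeps more of the noisy steps, so the output departs further from the input video), classifier-free guidance runs the transformer twice per step and blends the conditional and unconditional predictions, and the final latents are de-normalized with the VAE's `latents_mean`/`latents_std` before decoding. The snippet below is a minimal sketch of those three steps using plain tensors; the `denoise` helper, the latent shape, and the statistics values are assumptions for illustration and are not part of the pipeline code.

```python
# Minimal sketch (not the pipeline's actual implementation): strength-based
# timestep truncation, classifier-free guidance, and latent de-normalization.
import torch

num_inference_steps = 50
strength = 0.7          # same meaning as the pipeline argument
guidance_scale = 5.0

# Stand-in descending timestep schedule, e.g. 1000 -> 0.
timesteps = torch.linspace(1000, 0, num_inference_steps)

# get_timesteps(): keep only the last `strength * num_inference_steps` steps,
# so a higher strength starts from noisier latents and changes the video more.
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = timesteps[t_start:]

# Assumed latent shape: (batch, channels, latent_frames, height/8, width/8).
latents = torch.randn(1, 16, 21, 60, 104)


def denoise(latents, t, conditional):
    # Hypothetical stand-in for the transformer's prediction at timestep t.
    return torch.randn_like(latents)


for t in timesteps:
    noise_pred = denoise(latents, t, conditional=True)
    if guidance_scale > 1.0:
        noise_uncond = denoise(latents, t, conditional=False)
        # Classifier-free guidance: push the conditional prediction further
        # away from the unconditional one by the guidance weight.
        noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
    # A real pipeline would now call scheduler.step(noise_pred, t, latents).

# Before decoding, latents are mapped back to the VAE's scale. The pipeline
# stores latents_std as 1/std, hence the division below; values here are fake.
latents_mean = torch.zeros(1, 16, 1, 1, 1)
latents_std = 1.0 / torch.full((1, 16, 1, 1, 1), 2.0)
decodable = latents / latents_std + latents_mean
```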
From 08287a951f4b67aee5ae87ed2950b4da8a71c0d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 14 Jun 2025 10:10:07 +0300 Subject: [PATCH 6/8] 2nd draft --- src/diffusers/models/autoencoders/__init__.py | 2 +- .../autoencoders/autoencoder_kl_magi.py | 730 ----------- .../autoencoders/autoencoder_kl_magi1.py | 1085 +++++++++++++++++ .../models/transformers/transformer_magi.py | 661 ---------- .../models/transformers/transformer_magi1.py | 485 ++++++++ 5 files changed, 1571 insertions(+), 1392 deletions(-) delete mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_magi.py create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_magi1.py delete mode 100644 src/diffusers/models/transformers/transformer_magi.py create mode 100644 src/diffusers/models/transformers/transformer_magi1.py diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 1f4d412f3bbf..409f5d06ffc1 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -6,7 +6,7 @@ from .autoencoder_kl_cosmos import AutoencoderKLCosmos from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo from .autoencoder_kl_ltx import AutoencoderKLLTXVideo -from .autoencoder_kl_magi import AutoencoderKLMagi +from .autoencoder_kl_magi1 import AutoencoderKLMagi from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_magi.py b/src/diffusers/models/autoencoders/autoencoder_kl_magi.py deleted file mode 100644 index a4c9a5eb81fd..000000000000 --- a/src/diffusers/models/autoencoders/autoencoder_kl_magi.py +++ /dev/null @@ -1,730 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import BaseOutput, logging -from ...utils.accelerate_utils import apply_forward_hook -from ..modeling_utils import ModelMixin - - -logger = logging.get_logger(__name__) - - -@dataclass -class AutoencoderKLMagiOutput(BaseOutput): - """ - Output of AutoencoderKLMagi encoding method. - - Args: - latent_dist (`DiagonalGaussianDistribution`): - Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. - `DiagonalGaussianDistribution` allows for sampling from the encoded latent vector. - """ - - latent_dist: "DiagonalGaussianDistribution" - - -class DiagonalGaussianDistribution(object): - """ - Diagonal Gaussian distribution with mean and logvar. 
- """ - - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean) - - def sample(self, generator=None): - x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device, generator=generator) - return x - - def kl(self, other=None): - if self.deterministic: - return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3, 4]) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3, 4], - ) - - def nll(self, sample, dims=[1, 2, 3, 4]): - if self.deterministic: - return torch.Tensor([0.0]) - logtwopi = torch.log(torch.tensor(2.0 * 3.141592653589793)) - return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) - - def mode(self): - return self.mean - - -class ManualLayerNorm(nn.Module): - """ - Manual implementation of LayerNorm for better compatibility. - """ - - def __init__(self, normalized_shape, eps=1e-5): - super().__init__() - self.normalized_shape = normalized_shape - self.eps = eps - - def forward(self, x): - mean = x.mean(dim=-1, keepdim=True) - std = x.std(dim=-1, keepdim=True, unbiased=False) - x_normalized = (x - mean) / (std + self.eps) - return x_normalized - - -class Mlp(nn.Module): - """ - MLP module used in the transformer architecture. - """ - - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class Attention(nn.Module): - """ - Multi-head attention module. - """ - - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim**-0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x): - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class Block(nn.Module): - """ - Transformer block with attention and MLP. 
- """ - - def __init__( - self, - dim, - num_heads, - mlp_ratio=4.0, - qkv_bias=False, - qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - ): - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop, - ) - self.drop_path = nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - def forward(self, x): - x = x + self.drop_path(self.attn(self.norm1(x))) - x = x + self.drop_path(self.mlp(self.norm2(x))) - return x - - -class PatchEmbed(nn.Module): - """ - Image to Patch Embedding for 3D data. - """ - - def __init__(self, video_size=224, video_length=16, patch_size=16, patch_length=1, in_chans=3, embed_dim=768): - super().__init__() - self.video_size = video_size - self.video_length = video_length - self.patch_size = patch_size - self.patch_length = patch_length - - self.grid_size = video_size // patch_size - self.grid_length = video_length // patch_length - self.num_patches = self.grid_length * self.grid_size * self.grid_size - - self.proj = nn.Conv3d( - in_chans, - embed_dim, - kernel_size=(patch_length, patch_size, patch_size), - stride=(patch_length, patch_size, patch_size), - ) - - def forward(self, x): - B, C, T, H, W = x.shape - assert H == self.video_size and W == self.video_size, ( - f"Input image size ({H}*{W}) doesn't match model ({self.video_size}*{self.video_size})." - ) - assert T == self.video_length, f"Input video length ({T}) doesn't match model ({self.video_length})." - - x = self.proj(x).flatten(2).transpose(1, 2) - return x - - -class ViTEncoder(nn.Module): - """ - Vision Transformer Encoder for MAGI-1 VAE. 
- """ - - def __init__( - self, - video_size=256, - video_length=16, - patch_size=8, - patch_length=4, - in_chans=3, - z_chans=4, - double_z=True, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4.0, - qkv_bias=False, - qk_scale=None, - drop_rate=0.0, - attn_drop_rate=0.0, - norm_layer=nn.LayerNorm, - with_cls_token=True, - ): - super().__init__() - self.video_size = video_size - self.video_length = video_length - self.patch_size = patch_size - self.patch_length = patch_length - self.z_chans = z_chans - self.double_z = double_z - self.with_cls_token = with_cls_token - - # Patch embedding - self.patch_embed = PatchEmbed( - video_size=video_size, - video_length=video_length, - patch_size=patch_size, - patch_length=patch_length, - in_chans=in_chans, - embed_dim=embed_dim, - ) - - num_patches = self.patch_embed.num_patches - - # Class token and position embedding - if with_cls_token: - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.num_tokens = 1 - else: - self.num_tokens = 0 - - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) - self.pos_drop = nn.Dropout(p=drop_rate) - - # Transformer blocks - dpr = [x.item() for x in torch.linspace(0, 0.0, depth)] # stochastic depth decay rule - self.blocks = nn.ModuleList( - [ - Block( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=dpr[i], - norm_layer=norm_layer, - ) - for i in range(depth) - ] - ) - - self.norm = norm_layer(embed_dim) - - # Projection to latent space - self.proj = nn.Linear(embed_dim, z_chans * 2 if double_z else z_chans) - - # Initialize weights - self._init_weights() - - def _init_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) - w = self.patch_embed.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - - # Initialize position embedding - nn.init.normal_(self.pos_embed, std=0.02) - - # Initialize cls token if used - if self.with_cls_token: - nn.init.normal_(self.cls_token, std=0.02) - - def forward(self, x): - # Patch embedding - x = self.patch_embed(x) - - # Add class token if used - if self.with_cls_token: - cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) - x = torch.cat((cls_tokens, x), dim=1) - - # Add position embedding and apply dropout - x = x + self.pos_embed - x = self.pos_drop(x) - - # Apply transformer blocks - for blk in self.blocks: - x = blk(x) - - x = self.norm(x) - - # Use class token for output if available, otherwise use patch tokens - if self.with_cls_token: - x = x[:, 0] - else: - x = x.mean(dim=1) - - # Project to latent space - x = self.proj(x) - - # Reshape to [B, C, T, H, W] - B = x.shape[0] - T = self.video_length // self.patch_length - H = self.video_size // self.patch_size - W = self.video_size // self.patch_size - C = self.z_chans * 2 if self.double_z else self.z_chans - - x = x.view(B, C, T, H, W) - - return x - - -class ViTDecoder(nn.Module): - """ - Vision Transformer Decoder for MAGI-1 VAE. 
- """ - - def __init__( - self, - video_size=256, - video_length=16, - patch_size=8, - patch_length=4, - in_chans=3, - z_chans=4, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4.0, - qkv_bias=False, - qk_scale=None, - drop_rate=0.0, - attn_drop_rate=0.0, - norm_layer=nn.LayerNorm, - with_cls_token=True, - ): - super().__init__() - self.video_size = video_size - self.video_length = video_length - self.patch_size = patch_size - self.patch_length = patch_length - self.z_chans = z_chans - self.with_cls_token = with_cls_token - - # Calculate patch dimensions - self.grid_size = video_size // patch_size - self.grid_length = video_length // patch_length - num_patches = self.grid_length * self.grid_size * self.grid_size - - # Input projection from latent space to embedding dimension - self.proj_in = nn.Linear(z_chans, embed_dim) - - # Class token and position embedding - if with_cls_token: - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.num_tokens = 1 - else: - self.num_tokens = 0 - - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) - self.pos_drop = nn.Dropout(p=drop_rate) - - # Transformer blocks - dpr = [x.item() for x in torch.linspace(0, 0.0, depth)] # stochastic depth decay rule - self.blocks = nn.ModuleList( - [ - Block( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=dpr[i], - norm_layer=norm_layer, - ) - for i in range(depth) - ] - ) - - self.norm = norm_layer(embed_dim) - - # Output projection to image space - self.proj_out = nn.ConvTranspose3d( - embed_dim, - in_chans, - kernel_size=(patch_length, patch_size, patch_size), - stride=(patch_length, patch_size, patch_size), - ) - - # Initialize weights - self._init_weights() - - def _init_weights(self): - # Initialize position embedding - nn.init.normal_(self.pos_embed, std=0.02) - - # Initialize cls token if used - if self.with_cls_token: - nn.init.normal_(self.cls_token, std=0.02) - - # Initialize output projection - w = self.proj_out.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - - def forward(self, z): - # Get dimensions - B, C, T, H, W = z.shape - - # Flatten spatial dimensions and transpose to [B, T*H*W, C] - z = z.flatten(2).transpose(1, 2) - - # Project to embedding dimension - x = self.proj_in(z) - - # Add class token if used - if self.with_cls_token: - cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) - x = torch.cat((cls_tokens, x), dim=1) - - # Add position embedding and apply dropout - x = x + self.pos_embed - x = self.pos_drop(x) - - # Apply transformer blocks - for blk in self.blocks: - x = blk(x) - - x = self.norm(x) - - # Remove class token if used - if self.with_cls_token: - x = x[:, 1:] - - # Reshape to [B, T, H, W, C] - x = x.reshape(B, T, H, W, -1) - - # Transpose to [B, C, T, H, W] - x = x.permute(0, 4, 1, 2, 3) - - # Project to image space - x = self.proj_out(x) - - return x - - -class AutoencoderKLMagi(ModelMixin, ConfigMixin): - """ - Variational Autoencoder (VAE) model with KL loss for MAGI-1. - - This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic - methods implemented for all models (downloading, saving, loading, etc.) - - Parameters: - in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. - out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. 
- down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock3D",)`): - Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock3D",)`): - Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): - Tuple of block output channels. - layers_per_block (`int`, *optional*, defaults to 1): Number of layers per block. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - latent_channels (`int`, *optional*, defaults to 8): Number of channels in the latent space. - norm_num_groups (`int`, *optional*, defaults to 32): Number of groups for the normalization. - scaling_factor (`float`, *optional*, defaults to 0.18215): - The component-wise standard deviation of the trained latent space computed using the first batch of the - training set. This is used to scale the latent space to have unit variance when training the diffusion - model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the - diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 - / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image - Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. - temporal_downsample_factor (`Tuple[int]`, *optional*, defaults to (1, 2, 1, 1)): - Tuple of temporal downsampling factors for each block. - """ - - @register_to_config - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock3D",), - up_block_types: Tuple[str] = ("UpDecoderBlock3D",), - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, - act_fn: str = "silu", - latent_channels: int = 8, - norm_num_groups: int = 32, - scaling_factor: float = 0.18215, - temporal_downsample_factor: Tuple[int] = (1, 2, 1, 1), - video_size: int = 256, - video_length: int = 16, - patch_size: int = 8, - patch_length: int = 4, - embed_dim: int = 768, - depth: int = 12, - num_heads: int = 12, - mlp_ratio: float = 4.0, - qkv_bias: bool = False, - qk_scale: Optional[float] = None, - drop_rate: float = 0.0, - attn_drop_rate: float = 0.0, - with_cls_token: bool = True, - ): - super().__init__() - - # Save important parameters - self.latent_channels = latent_channels - self.scaling_factor = scaling_factor - self.temperal_downsample = temporal_downsample_factor - - # Create encoder and decoder - self.encoder = ViTEncoder( - video_size=video_size, - video_length=video_length, - patch_size=patch_size, - patch_length=patch_length, - in_chans=in_channels, - z_chans=latent_channels, - embed_dim=embed_dim, - depth=depth, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - norm_layer=nn.LayerNorm, - with_cls_token=with_cls_token, - double_z=True, - ) - - self.decoder = ViTDecoder( - video_size=video_size, - video_length=video_length, - patch_size=patch_size, - patch_length=patch_length, - in_chans=out_channels, - z_chans=latent_channels, - embed_dim=embed_dim, - depth=depth, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - norm_layer=nn.LayerNorm, - with_cls_token=with_cls_token, - ) - - # Enable tiling - self._enable_tiling = False - self._tile_sample_min_size = None - self._tile_sample_stride = None - - @property 
- def spatial_downsample_factor(self) -> int: - """ - Returns the spatial downsample factor for the VAE. - """ - return self.encoder.patch_size # MAGI-1 uses patch_size as spatial downsampling - - @property - def temporal_downsample_factor(self) -> int: - """ - Returns the temporal downsample factor for the VAE. - """ - return self.encoder.patch_length # MAGI-1 uses patch_length as temporal downsampling - - @apply_forward_hook - def encode( - self, x: torch.FloatTensor, return_dict: bool = True - ) -> Union[AutoencoderKLMagiOutput, Tuple[DiagonalGaussianDistribution]]: - """ - Encode a batch of videos. - - Args: - x (`torch.FloatTensor`): Input batch of videos. - return_dict (`bool`, *optional*, defaults to `True`): Whether to return a dictionary or tuple. - - Returns: - `AutoencoderKLMagiOutput` or `tuple`: - If return_dict is True, returns an `AutoencoderKLMagiOutput` object, otherwise returns a tuple. - """ - h = self.encoder(x) - posterior = DiagonalGaussianDistribution(h) - - if not return_dict: - return (posterior,) - - return AutoencoderKLMagiOutput(latent_dist=posterior) - - @apply_forward_hook - def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[torch.FloatTensor, BaseOutput]: - """ - Decode a batch of latent vectors. - - Args: - z (`torch.FloatTensor`): Input batch of latent vectors. - return_dict (`bool`, *optional*, defaults to `True`): Whether to return a dictionary or tuple. - - Returns: - `BaseOutput` or `torch.FloatTensor`: - If return_dict is True, returns a `BaseOutput` object, otherwise returns the decoded tensor. - """ - dec = self.decoder(z) - - if not return_dict: - return (dec,) - - return BaseOutput(sample=dec) - - def enable_tiling( - self, - tile_sample_min_height: Optional[int] = None, - tile_sample_min_width: Optional[int] = None, - ) -> None: - """ - Enable tiled processing for large videos. - - Args: - tile_sample_min_height (`int`, *optional*): Minimum tile height. - tile_sample_min_width (`int`, *optional*): Minimum tile width. - """ - self._enable_tiling = True - self._tile_sample_min_size = (tile_sample_min_height, tile_sample_min_width) - - def disable_tiling(self) -> None: - """ - Disable tiled processing. - """ - self._enable_tiling = False - self._tile_sample_min_size = None - self._tile_sample_stride = None - - def forward( - self, - sample: torch.FloatTensor, - sample_posterior: bool = False, - return_dict: bool = True, - generator: Optional[torch.Generator] = None, - ) -> Union[BaseOutput, Tuple]: - """ - Forward pass of the model. - - Args: - sample (`torch.FloatTensor`): Input batch of videos. - sample_posterior (`bool`, *optional*, defaults to `False`): - Whether to sample from the posterior distribution. - return_dict (`bool`, *optional*, defaults to `True`): Whether to return a dictionary or tuple. - generator (`torch.Generator`, *optional*): Generator for random sampling. - - Returns: - `BaseOutput` or `tuple`: - If return_dict is True, returns a `BaseOutput` object, otherwise returns a tuple. 
- """ - posterior = self.encode(sample, return_dict=True).latent_dist - - if sample_posterior: - z = posterior.sample(generator=generator) - else: - z = posterior.mode() - - # Scale latents by the scaling factor - z = self.scaling_factor * z - - # Decode the latents - dec = self.decode(z, return_dict=return_dict) - - if not return_dict: - return (dec,) - - return BaseOutput(sample=dec.sample) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py b/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py new file mode 100644 index 000000000000..cdd05298dcb2 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py @@ -0,0 +1,1085 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +CACHE_T = 2 + + +class Magi1CausalConv3d(nn.Conv3d): + r""" + A custom 3D causal convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + return super().forward(x) + + +class Magi1RMS_norm(nn.Module): + r""" + A custom RMS normalization layer. 
+ + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class Magi1Upsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Args: + x (torch.Tensor): Input tensor to be upsampled. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. + """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class Magi1Resample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. 
+ """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Magi1Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Magi1Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + self.time_conv = Magi1CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = Magi1CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class Magi1ResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0. + non_linearity (str, optional): Type of non-linearity to use. Default is "silu". 
+ """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = get_activation(non_linearity) + + # layers + self.norm1 = Magi1RMS_norm(in_dim, images=False) + self.conv1 = Magi1CausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = Magi1RMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = Magi1CausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = Magi1CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + +class Magi1AttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = Magi1RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + +class Magi1MidBlock(nn.Module): + """ + Middle block for WanVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. 
+ """ + + def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + super().__init__() + self.dim = dim + + # Create the components + resnets = [Magi1ResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(Magi1AttentionBlock(dim)) + resnets.append(Magi1ResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # First residual block + x = self.resnets[0](x, feat_cache, feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache, feat_idx) + + return x + + +class Magi1Encoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = Magi1CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append(Magi1ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(Magi1AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(Magi1Resample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = Magi1MidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = Magi1RMS_norm(out_dim, images=False) + self.conv_out = Magi1CausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = 
layer(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class Magi1UpBlock(nn.Module): + """ + A block that handles upsampling for the WanVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(Magi1ResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([Magi1Resample(out_dim, mode=upsample_mode)]) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) + return x + + +class Magi1Decoder3d(nn.Module): + r""" + A 3D decoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. 
+ """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv_in = Magi1CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = Magi1MidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0: + in_dim = in_dim // 2 + + # Determine if we need upsampling + upsample_mode = None + if i != len(dim_mult) - 1: + upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + + # Create and add the upsampling block + up_block = Magi1UpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # Update scale for next iteration + if upsample_mode is not None: + scale *= 2.0 + + # output blocks + self.norm_out = Magi1RMS_norm(out_dim, images=False) + self.conv_out = Magi1CausalConv3d(out_dim, 3, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class AutoencoderKLMagi1(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + Introduced in [Wan 2.1]. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). 
+ """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: List[float] = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ], + latents_std: List[float] = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ], + ) -> None: + super().__init__() + + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + self.encoder = Magi1Encoder3d( + base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout + ) + self.quant_conv = Magi1CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = Magi1CausalConv3d(z_dim, z_dim, 1) + + self.decoder = Magi1Decoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout + ) + + self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. 
+ """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def clear_cache(self): + def _count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, Magi1CausalConv3d): + count += 1 + return count + + self._conv_num = _count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + def _encode(self, x: torch.Tensor): + _, _, num_frame, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + self.clear_cache() + iter_ = 1 + (num_frame - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + + enc = self.quant_conv(out) + self.clear_cache() + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frame, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + self.clear_cache() + x = self.post_quant_conv(z) + for i in range(num_frame): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + + out = torch.clamp(out, min=-1.0, max=1.0) + self.clear_cache() + if not return_dict: + return (out,) + + return DecoderOutput(sample=out) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. 
+ """ + _, _, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + self.clear_cache() + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+ rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + self.clear_cache() + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx) + time.append(decoded) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/src/diffusers/models/transformers/transformer_magi.py b/src/diffusers/models/transformers/transformer_magi.py deleted file mode 100644 index 1ea78053f349..000000000000 --- a/src/diffusers/models/transformers/transformer_magi.py +++ /dev/null @@ -1,661 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import BaseOutput, logging -from ..embeddings import TimestepEmbedding, Timesteps -from ..modeling_utils import ModelMixin - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class MagiTransformerOutput(BaseOutput): - """ - The output of [`MagiTransformer3DModel`]. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch_size, num_channels, frames, height, width)`): - The hidden states output from the last layer of the model. 
- """ - - sample: torch.FloatTensor - - -class MagiAttention(nn.Module): - """ - A cross attention layer for MAGI-1. - - This implements the specialized attention mechanism from the MAGI-1 model, including query/key normalization and - proper handling of rotary embeddings. - """ - - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - added_kv_proj_dim: Optional[int] = None, - norm_num_groups: Optional[int] = None, - ): - super().__init__() - inner_dim = dim_head * heads - cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - - self.scale = dim_head**-0.5 - self.heads = heads - self.dim_head = dim_head - - # Projection layers - self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - - # Normalization layers for query and key - important part of MAGI-1's attention mechanism - self.norm_q = nn.LayerNorm(dim_head, eps=1e-5) - self.norm_k = nn.LayerNorm(dim_head, eps=1e-5) - - self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) - - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def forward( - self, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - rotary_pos_emb=None, - **cross_attention_kwargs, - ): - batch_size, sequence_length, _ = hidden_states.shape - - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - - # Project to query, key, value - query = self.to_q(hidden_states) - key = self.to_k(encoder_hidden_states) - value = self.to_v(encoder_hidden_states) - - # Reshape for multi-head attention - query = query.reshape(batch_size, sequence_length, self.heads, self.dim_head) - key = key.reshape(batch_size, -1, self.heads, self.dim_head) - value = value.reshape(batch_size, -1, self.heads, self.dim_head) - - # Apply layer normalization to query and key (as in MAGI-1) - # Convert to float32 for better numerical stability during normalization - orig_dtype = query.dtype - query = self.norm_q(query.float()).to(orig_dtype) - key = self.norm_k(key.float()).to(orig_dtype) - - # Transpose for attention - # [batch_size, seq_len, heads, dim_head] -> [batch_size, heads, seq_len, dim_head] - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - # Apply rotary position embeddings if provided - if rotary_pos_emb is not None: - # Apply rotary embeddings using the same method as in MAGI-1 - def apply_rotary_emb(hidden_states, freqs): - dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64 - # Convert to complex numbers - x_complex = 
torch.view_as_complex(hidden_states.to(dtype).unflatten(-1, (-1, 2))) - # Apply rotation in complex space - x_rotated = x_complex * freqs - # Convert back to real - x_out = torch.view_as_real(x_rotated).flatten(-2) - return x_out.type_as(hidden_states) - - # Apply rotary embeddings - query = apply_rotary_emb(query, rotary_pos_emb) - key = apply_rotary_emb(key, rotary_pos_emb) - - # Use scaled_dot_product_attention if available (PyTorch 2.0+) - if hasattr(F, "scaled_dot_product_attention"): - # Apply scaled dot product attention - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - # [batch_size, heads, seq_len, dim_head] -> [batch_size, seq_len, heads*dim_head] - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, sequence_length, -1) - else: - # Manual implementation of attention - # Reshape for bmm - query = query.reshape(batch_size * self.heads, sequence_length, self.dim_head) - key = key.reshape(batch_size * self.heads, -1, self.dim_head) - value = value.reshape(batch_size * self.heads, -1, self.dim_head) - - # Compute attention scores - if self.upcast_attention: - query = query.float() - key = key.float() - - attention_scores = torch.bmm(query, key.transpose(-1, -2)) * self.scale - - if attention_mask is not None: - attention_scores = attention_scores + attention_mask - - if self.upcast_softmax: - attention_scores = attention_scores.float() - - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = attention_probs.to(value.dtype) - - # Compute output - hidden_states = torch.bmm(attention_probs, value) - hidden_states = hidden_states.reshape(batch_size, self.heads, sequence_length, self.dim_head) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, sequence_length, -1) - - # Project to output - hidden_states = self.to_out(hidden_states) - - return hidden_states - - -class MagiTransformerBlock(nn.Module): - """ - A transformer block for MAGI-1. - - This is a simplified version of the MAGI-1 transformer block. 
- """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout: float = 0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "gelu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", - norm_eps: float = 1e-5, - final_dropout: bool = False, - ): - super().__init__() - self.only_cross_attention = only_cross_attention - - # Self-attention - self.norm1 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) - self.attn1 = MagiAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) - - # Cross-attention - if cross_attention_dim is not None: - self.norm2 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) - self.attn2 = MagiAttention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) - else: - self.norm2 = None - self.attn2 = None - - # Feed-forward - self.norm3 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) - - # Choose activation function - if activation_fn == "gelu": - self.ff = nn.Sequential( - nn.Linear(dim, dim * 4), - nn.GELU(), - nn.Dropout(dropout) if final_dropout else nn.Identity(), - nn.Linear(dim * 4, dim), - ) - elif activation_fn == "gelu-approximate": - self.ff = nn.Sequential( - nn.Linear(dim, dim * 4), - nn.GELU(approximate="tanh"), - nn.Dropout(dropout) if final_dropout else nn.Identity(), - nn.Linear(dim * 4, dim), - ) - elif activation_fn == "silu": - self.ff = nn.Sequential( - nn.Linear(dim, dim * 4), - nn.SiLU(), - nn.Dropout(dropout) if final_dropout else nn.Identity(), - nn.Linear(dim * 4, dim), - ) - else: - raise ValueError(f"Unsupported activation function: {activation_fn}") - - self.final_dropout = nn.Dropout(dropout) if final_dropout else nn.Identity() - - def forward( - self, - hidden_states, - encoder_hidden_states=None, - timestep=None, - attention_mask=None, - cross_attention_kwargs=None, - rotary_pos_emb=None, - **kwargs, - ): - # Self-attention - norm_hidden_states = self.norm1(hidden_states) - - if self.only_cross_attention: - hidden_states = hidden_states + self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - ) - else: - hidden_states = hidden_states + self.attn1( - norm_hidden_states, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - ) - - # Cross-attention - if self.attn2 is not None: - norm_hidden_states = self.norm2(hidden_states) - hidden_states = hidden_states + self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - **(cross_attention_kwargs if cross_attention_kwargs is not None else {}), - ) - - # Feed-forward - norm_hidden_states = self.norm3(hidden_states) - ff_output = self.ff(norm_hidden_states) - hidden_states = hidden_states + self.final_dropout(ff_output) - - return hidden_states - - -class LearnableRotaryEmbedding(nn.Module): - """ - Learnable rotary position embeddings similar to the one used in MAGI-1. 
- - This implementation is based on MAGI-1's LearnableRotaryEmbeddingCat class, which creates rotary embeddings for 3D - data (frames, height, width). - """ - - def __init__( - self, - dim: int, - max_seq_len: int = 1024, - temperature: float = 10000.0, - in_pixels: bool = True, - linear_bands: bool = False, - ): - super().__init__() - self.dim = dim - self.max_seq_len = max_seq_len - self.temperature = temperature - self.in_pixels = in_pixels - self.linear_bands = linear_bands - - # Initialize frequency bands - self.register_buffer("freqs", self._get_default_bands()) - - def _get_default_bands(self): - """Generate default frequency bands""" - if self.linear_bands: - # Linear spacing - bands = torch.linspace(1.0, self.max_seq_len / 2, self.dim // 2, dtype=torch.float32) - else: - # Log spacing (as in original RoPE) - bands = 1.0 / (self.temperature ** (torch.arange(0, self.dim // 2, dtype=torch.float32) / (self.dim // 2))) - - return bands * torch.pi - - def get_embed(self, shape: List[int]) -> torch.Tensor: - """ - Generate rotary position embeddings for the given shape. - - Args: - shape: List of dimensions [frames, height, width] - - Returns: - Rotary position embeddings (sin and cos components) - """ - frames, height, width = shape - seq_len = frames * height * width - - # Generate position indices - if self.in_pixels: - # Normalize positions to [-1, 1] - t = torch.linspace(-1.0, 1.0, steps=frames, device=self.freqs.device) - h = torch.linspace(-1.0, 1.0, steps=height, device=self.freqs.device) - w = torch.linspace(-1.0, 1.0, steps=width, device=self.freqs.device) - else: - # Use integer positions - t = torch.arange(frames, device=self.freqs.device, dtype=torch.float32) - h = torch.arange(height, device=self.freqs.device, dtype=torch.float32) - w = torch.arange(width, device=self.freqs.device, dtype=torch.float32) - - # Center spatial dimensions (as in MAGI-1) - h = h - (height - 1) / 2 - w = w - (width - 1) / 2 - - # Create position grid - grid = torch.stack(torch.meshgrid(t, h, w, indexing="ij"), dim=-1) - grid = grid.reshape(-1, 3) # [seq_len, 3] - - # Get frequency bands - freqs = self.freqs.to(grid.device) - - # Compute embeddings for each dimension - # Temporal dimension - t_emb = torch.outer(grid[:, 0], freqs[: self.dim // 6]) - t_sin = torch.sin(t_emb) - t_cos = torch.cos(t_emb) - - # Height dimension - h_emb = torch.outer(grid[:, 1], freqs[: self.dim // 6]) - h_sin = torch.sin(h_emb) - h_cos = torch.cos(h_emb) - - # Width dimension - w_emb = torch.outer(grid[:, 2], freqs[: self.dim // 6]) - w_sin = torch.sin(w_emb) - w_cos = torch.cos(w_emb) - - # Concatenate all embeddings - sin_emb = torch.cat([t_sin, h_sin, w_sin], dim=-1) - cos_emb = torch.cat([t_cos, h_cos, w_cos], dim=-1) - - # Pad or trim to match expected dimension - target_dim = self.dim // 2 - if sin_emb.shape[1] < target_dim: - pad_size = target_dim - sin_emb.shape[1] - sin_emb = F.pad(sin_emb, (0, pad_size)) - cos_emb = F.pad(cos_emb, (0, pad_size)) - elif sin_emb.shape[1] > target_dim: - sin_emb = sin_emb[:, :target_dim] - cos_emb = cos_emb[:, :target_dim] - - # Combine sin and cos for rotary embeddings - return torch.cat([cos_emb.unsqueeze(-1), sin_emb.unsqueeze(-1)], dim=-1).reshape(seq_len, target_dim, 2) - - -class MagiTransformer3DModel(ModelMixin, ConfigMixin): - """ - Transformer model for MAGI-1. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods implemented for - all models (downloading, saving, loading, etc.) 
- - Parameters: - sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. - num_layers (`int`, *optional*, defaults to 24): Number of transformer blocks. - num_attention_heads (`int`, *optional*, defaults to 16): Number of attention heads. - attention_head_dim (`int`, *optional*, defaults to 64): Dimension of attention heads. - cross_attention_dim (`int`, *optional*, defaults to 1280): Dimension of cross-attention conditioning. - activation_fn (`str`, *optional*, defaults to `"gelu"`): Activation function. - norm_elementwise_affine (`bool`, *optional*, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. - norm_type (`str`, *optional*, defaults to `"layer_norm"`): Type of normalization. - norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon for normalization. - attention_bias (`bool`, *optional*, defaults to `False`): Whether to use bias in attention. - num_embeds_ada_norm (`int`, *optional*, defaults to `None`): Number of embeddings for AdaLayerNorm. - only_cross_attention (`bool`, *optional*, defaults to `False`): Whether to only use cross-attention. - upcast_attention (`bool`, *optional*, defaults to `False`): Whether to upcast attention operations. - dropout (`float`, *optional*, defaults to 0.0): Dropout probability. - """ - - @register_to_config - def __init__( - self, - sample_size: Optional[Union[int, Tuple[int, int]]] = None, - in_channels: int = 4, - out_channels: int = 4, - num_layers: int = 24, - num_attention_heads: int = 16, - attention_head_dim: int = 64, - cross_attention_dim: int = 1280, - activation_fn: str = "gelu", - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", - norm_eps: float = 1e-5, - attention_bias: bool = False, - num_embeds_ada_norm: Optional[int] = None, - only_cross_attention: bool = False, - upcast_attention: bool = False, - dropout: float = 0.0, - patch_size: Tuple[int, int, int] = (1, 1, 1), - max_seq_len: int = 1024, - ): - super().__init__() - - self.sample_size = sample_size - self.patch_size = patch_size - self.max_seq_len = max_seq_len - - # Input embedding - self.in_channels = in_channels - time_embed_dim = attention_head_dim * num_attention_heads - self.time_proj = Timesteps(time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0) - self.time_embedding = TimestepEmbedding(time_embed_dim, time_embed_dim) - - # Input projection - self.input_proj = nn.Conv3d(in_channels, time_embed_dim, kernel_size=patch_size, stride=patch_size) - - # Rotary position embeddings - self.rotary_embedding = LearnableRotaryEmbedding( - dim=attention_head_dim, - max_seq_len=max_seq_len, - temperature=10000.0, - ) - - # Transformer blocks - self.transformer_blocks = nn.ModuleList( - [ - MagiTransformerBlock( - dim=time_embed_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - norm_elementwise_affine=norm_elementwise_affine, - norm_type=norm_type, - norm_eps=norm_eps, - ) - for _ in range(num_layers) - ] - ) - - # Output projection - self.out_channels = out_channels - self.output_proj 
= nn.Conv3d(time_embed_dim, out_channels, kernel_size=1) - - self.gradient_checkpointing = False - - def set_attention_slice(self, slice_size): - """ - Enable sliced attention computation. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If - `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - logger.warning( - "Calling `set_attention_slice` is deprecated and will be removed in a future version. Use" - " `set_attention_processor` instead." - ) - - # Not implemented for MAGI-1 yet - pass - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, MagiTransformerBlock): - module.gradient_checkpointing = value - - def forward( - self, - hidden_states: torch.Tensor, - timesteps: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> Union[MagiTransformerOutput, Tuple]: - """ - Forward pass of the model. - - Args: - hidden_states (`torch.Tensor`): - Input tensor of shape `(batch_size, in_channels, frames, height, width)`. - timesteps (`torch.Tensor`, *optional*): - Timesteps tensor of shape `(batch_size,)`. - encoder_hidden_states (`torch.Tensor`, *optional*): - Encoder hidden states for cross-attention. - attention_mask (`torch.Tensor`, *optional*): - Attention mask for cross-attention. - cross_attention_kwargs (`dict`, *optional*): - Additional arguments for cross-attention. - return_dict (`bool`, *optional*, defaults to `True`): - Whether to return a dictionary. - - Returns: - `MagiTransformerOutput` or `tuple`: - If `return_dict` is `True`, a `MagiTransformerOutput` is returned, otherwise a tuple `(sample,)` is - returned where `sample` is the output tensor. - """ - # 1. Input processing - batch_size, channels, frames, height, width = hidden_states.shape - residual = hidden_states - - # 2. Time embedding - if timesteps is not None: - timesteps = timesteps.to(hidden_states.device) - time_embeds = self.time_proj(timesteps) - time_embeds = self.time_embedding(time_embeds) - else: - time_embeds = None - - # 3. Project input - hidden_states = self.input_proj(hidden_states) - - # Get patched dimensions - p_t, p_h, p_w = self.patch_size - patched_frames = frames // p_t - patched_height = height // p_h - patched_width = width // p_w - - # 4. Reshape for transformer blocks - hidden_states = hidden_states.permute(0, 2, 3, 4, 1) # [B, C, F, H, W] -> [B, F, H, W, C] - hidden_states = hidden_states.reshape( - batch_size, patched_frames * patched_height * patched_width, -1 - ) # [B, F*H*W, C] - - # 5. Add time embeddings if provided - if time_embeds is not None: - time_embeds = time_embeds.unsqueeze(1) # [B, 1, C] - hidden_states = hidden_states + time_embeds - - # 6. 
Generate rotary position embeddings - rotary_pos_emb = self.rotary_embedding.get_embed([patched_frames, patched_height, patched_width]) - rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) - - # Convert to complex representation for the attention mechanism - # This matches MAGI-1's approach to applying rotary embeddings - cos_emb = rotary_pos_emb[..., 0] - sin_emb = rotary_pos_emb[..., 1] - rotary_pos_emb = torch.complex(cos_emb, sin_emb).unsqueeze(0) # [1, seq_len, dim//2] - - # 7. Process with transformer blocks - for block in self.transformer_blocks: - if self.gradient_checkpointing and self.training: - hidden_states = torch.utils.checkpoint.checkpoint( - block, - hidden_states, - encoder_hidden_states, - timesteps, - attention_mask, - None, # cross_attention_kwargs - rotary_pos_emb, - ) - else: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timesteps, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - rotary_pos_emb=rotary_pos_emb, - ) - - # 8. Reshape back to video format - hidden_states = hidden_states.reshape(batch_size, patched_frames, patched_height, patched_width, -1) - hidden_states = hidden_states.permute(0, 4, 1, 2, 3) # [B, F, H, W, C] -> [B, C, F, H, W] - - # 9. Project output - hidden_states = self.output_proj(hidden_states) - - # 10. Add residual connection - hidden_states = hidden_states + residual - - if not return_dict: - return (hidden_states,) - - return MagiTransformerOutput(sample=hidden_states) diff --git a/src/diffusers/models/transformers/transformer_magi1.py b/src/diffusers/models/transformers/transformer_magi1.py new file mode 100644 index 000000000000..0bcaa779a7c2 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_magi1.py @@ -0,0 +1,485 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..attention import FeedForward +from ..attention_processor import Attention +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Magi1AttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("Magi1AttnProcessor2_0 requires PyTorch 2.0. 
To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if rotary_emb is not None: + + def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): + dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64 + x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2))) + x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) + return x_out.type_as(hidden_states) + + query = apply_rotary_emb(query, rotary_emb) + key = apply_rotary_emb(key, rotary_emb) + + # I2V task + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img = attn.add_k_proj(encoder_hidden_states_img) + key_img = attn.norm_added_k(key_img) + value_img = attn.add_v_proj(encoder_hidden_states_img) + + key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + hidden_states_img = F.scaled_dot_product_attention( + query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False + ) + hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) + hidden_states_img = hidden_states_img.type_as(query) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class Magi1ImageEmbedding(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None): + super().__init__() + + self.norm1 = FP32LayerNorm(in_features) + self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu") + self.norm2 = FP32LayerNorm(out_features) + if pos_embed_seq_len is not None: + self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features)) + else: + self.pos_embed = None + + def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: + if self.pos_embed is not None: + batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape + encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim) + encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed + + hidden_states = 
self.norm1(encoder_hidden_states_image) + hidden_states = self.ff(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states + + +class Magi1TimeTextImageEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + image_embed_dim: Optional[int] = None, + pos_embed_seq_len: Optional[int] = None, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + + self.image_embedder = None + if image_embed_dim is not None: + self.image_embedder = Magi1ImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len) + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + ): + timestep = self.timesteps_proj(timestep) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + + +class Magi1RotaryPosEmbed(nn.Module): + def __init__( + self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0 + ): + super().__init__() + + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.max_seq_len = max_seq_len + + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + + freqs = [] + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + for dim in [t_dim, h_dim, w_dim]: + freq = get_1d_rotary_pos_embed( + dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype + ) + freqs.append(freq) + self.freqs = torch.cat(freqs, dim=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + + freqs = self.freqs.to(hidden_states.device) + freqs = freqs.split_with_sizes( + [ + self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), + self.attention_head_dim // 6, + self.attention_head_dim // 6, + ], + dim=1, + ) + + freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) + return freqs + + +class Magi1TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: Optional[int] = None, + ): + 
super().__init__()
+
+        # 1. Self-attention
+        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_heads,
+            kv_heads=num_heads,
+            dim_head=dim // num_heads,
+            qk_norm=qk_norm,
+            eps=eps,
+            bias=True,
+            cross_attention_dim=None,
+            out_bias=True,
+            processor=Magi1AttnProcessor2_0(),
+        )
+
+        # 2. Cross-attention
+        self.attn2 = Attention(
+            query_dim=dim,
+            heads=num_heads,
+            kv_heads=num_heads,
+            dim_head=dim // num_heads,
+            qk_norm=qk_norm,
+            eps=eps,
+            bias=True,
+            cross_attention_dim=None,
+            out_bias=True,
+            added_kv_proj_dim=added_kv_proj_dim,
+            added_proj_bias=True,
+            processor=Magi1AttnProcessor2_0(),
+        )
+        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+        # 3. Feed-forward
+        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        rotary_emb: torch.Tensor,
+    ) -> torch.Tensor:
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+            self.scale_shift_table + temb.float()
+        ).chunk(6, dim=1)
+
+        # 1. Self-attention
+        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+
+        # 2. Cross-attention
+        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        hidden_states = hidden_states + attn_output
+
+        # 3. Feed-forward
+        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+            hidden_states
+        )
+        ff_output = self.ffn(norm_hidden_states)
+        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+
+        return hidden_states
+
+
+class Magi1Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+    r"""
+    A Transformer model for video-like data used in the Magi1 model.
+
+    Args:
+        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+        num_attention_heads (`int`, defaults to `40`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each head.
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        out_channels (`int`, defaults to `16`):
+            The number of channels in the output.
+        text_dim (`int`, defaults to `4096`):
+            Input dimension for text embeddings.
+        freq_dim (`int`, defaults to `256`):
+            Dimension for sinusoidal time embeddings.
+        ffn_dim (`int`, defaults to `13824`):
+            Intermediate dimension in feed-forward network.
+        num_layers (`int`, defaults to `40`):
+            The number of layers of transformer blocks to use.
+        cross_attn_norm (`bool`, defaults to `True`):
+            Enable cross-attention normalization.
+        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
+            Enable query/key normalization.
+        eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        image_dim (`int`, *optional*, defaults to `None`):
+            Input dimension for image embeddings, used by the image-to-video variant. If `None`, no image
+            conditioning is used.
+        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the added key and value projections. If `None`, no projection is used.
+        rope_max_seq_len (`int`, defaults to `1024`):
+            Maximum sequence length used to precompute the rotary position embeddings.
+        pos_embed_seq_len (`int`, *optional*, defaults to `None`):
+            Sequence length of the learned positional embedding added to the image conditioning, if used.
+    """
+
+    _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+    _no_split_modules = ["Magi1TransformerBlock"]
+    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+
+    @register_to_config
+    def __init__(
+        self,
+        patch_size: Tuple[int] = (1, 2, 2),
+        num_attention_heads: int = 40,
+        attention_head_dim: int = 128,
+        in_channels: int = 16,
+        out_channels: int = 16,
+        text_dim: int = 4096,
+        freq_dim: int = 256,
+        ffn_dim: int = 13824,
+        num_layers: int = 40,
+        cross_attn_norm: bool = True,
+        qk_norm: Optional[str] = "rms_norm_across_heads",
+        eps: float = 1e-6,
+        image_dim: Optional[int] = None,
+        added_kv_proj_dim: Optional[int] = None,
+        rope_max_seq_len: int = 1024,
+        pos_embed_seq_len: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+
+        inner_dim = num_attention_heads * attention_head_dim
+        out_channels = out_channels or in_channels
+
+        # 1. Patch & position embedding
+        self.rope = Magi1RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+        # 2. Condition embeddings
+        # image_embedding_dim=1280 for I2V model
+        self.condition_embedder = Magi1TimeTextImageEmbedding(
+            dim=inner_dim,
+            time_freq_dim=freq_dim,
+            time_proj_dim=inner_dim * 6,
+            text_embed_dim=text_dim,
+            image_embed_dim=image_dim,
+            pos_embed_seq_len=pos_embed_seq_len,
+        )
+
+        # 3. Transformer blocks
+        self.blocks = nn.ModuleList(
+            [
+                Magi1TransformerBlock(
+                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        # 4. Output norm & projection
+        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_hidden_states_image: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + rotary_emb = self.rope(hidden_states) + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image + ) + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + ) + else: + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + + # 5. Output norm, projection & unpatchify + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) From 878488128c1952cea13ae96066f438d41d2bae6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 14 Jun 2025 15:42:17 +0300 Subject: [PATCH 7/8] up --- src/diffusers/models/autoencoders/autoencoder_kl_magi1.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py b/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py index cdd05298dcb2..fe701cd59209 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# Copyright 2025 The Sand AI Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -326,7 +326,7 @@ def forward(self, x): class Magi1MidBlock(nn.Module): """ - Middle block for WanVAE encoder and decoder. + Middle block for Magi1VAE encoder and decoder. Args: dim (int): Number of input/output channels. @@ -472,7 +472,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): class Magi1UpBlock(nn.Module): """ - A block that handles upsampling for the WanVAE decoder. + A block that handles upsampling for the Magi1VAE decoder. 
Args: in_dim (int): Input dimension @@ -659,7 +659,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): class AutoencoderKLMagi1(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. - Introduced in [Wan 2.1]. + Introduced in [Magi1](https://arxiv.org/abs/2505.13211). This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). From ae03b7d69d9017d628020d95378f87d7e6973e25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 15 Jun 2025 16:22:18 +0300 Subject: [PATCH 8/8] Refactor Magi1AttentionBlock to support rotary embeddings and integrate attention mechanism accordingly. Updated initialization parameters and reshaping logic. --- .../autoencoders/autoencoder_kl_magi1.py | 56 ++++++++++++------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py b/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py index fe701cd59209..f70a8709d6c6 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_magi1.py @@ -27,6 +27,8 @@ from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin from .vae import DecoderOutput, DiagonalGaussianDistribution +from ..normalization import FP32LayerNorm +from ..embeddings import apply_rotary_emb logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -281,45 +283,57 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): class Magi1AttentionBlock(nn.Module): r""" - Causal self-attention with a single head. Args: - dim (int): The number of channels in the input tensor. """ - def __init__(self, dim): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, ln_in_attn=False, use_rope=False): super().__init__() - self.dim = dim - + self.use_rope = use_rope + self.num_heads = num_heads # layers - self.norm = Magi1RMS_norm(dim) - self.to_qkv = nn.Conv2d(dim, dim * 3, 1) - self.proj = nn.Conv2d(dim, dim, 1) + self.to_qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop_rate = attn_drop + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + if ln_in_attn: + # TODO: ManualLayerNorm at original repo? 
+            self.qkv_norm = FP32LayerNorm(dim // num_heads, elementwise_affine=False)
+        else:
+            self.qkv_norm = nn.Identity()

-    def forward(self, x):
+    def forward(self, x, feat_shape=None):
         identity = x
         batch_size, channels, time, height, width = x.size()
+        seq_len = time * height * width

-        x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
-        x = self.norm(x)
+        # Flatten the video into one token sequence: [b, c, t, h, w] -> [b, t*h*w, c]
+        x = x.permute(0, 2, 3, 4, 1).reshape(batch_size, seq_len, channels)

         # compute query, key, value
         qkv = self.to_qkv(x)
-        qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
-        qkv = qkv.permute(0, 1, 3, 2).contiguous()
-        q, k, v = qkv.chunk(3, dim=-1)
-
-        # apply attention
-        x = F.scaled_dot_product_attention(q, k, v)
+        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, channels // self.num_heads)
+        qkv = self.qkv_norm(qkv)
+
+        if self.use_rope:
+            # [b, s, 3, heads, head_dim] -> three [b, heads, s, head_dim] tensors
+            q, k, v = qkv.unbind(dim=2)
+            q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
+            # cache_rotary_emb is expected to return the (cos, sin) tables for the 3D feature grid
+            cos_emb, sin_emb = cache_rotary_emb(
+                feat_shape=feat_shape, dim=channels // self.num_heads, device=x.device, dtype=x.dtype
+            )
+            # rotate every position except the leading token
+            q[:, :, 1:] = apply_rotary_emb(q[:, :, 1:], (cos_emb, sin_emb)).type_as(q)
+            k[:, :, 1:] = apply_rotary_emb(k[:, :, 1:], (cos_emb, sin_emb)).type_as(k)
+            x = F.scaled_dot_product_attention(
+                q, k, v, dropout_p=self.attn_drop_rate if self.training else 0.0
+            )
+            # sdpa output: [b, heads, s, head_dim] -> [b, s, c]
+            x = x.permute(0, 2, 1, 3).reshape(batch_size, seq_len, channels)
+        else:
+            # flash_attn_qkvpacked_func (optional `flash_attn` dependency) takes [b, s, 3, heads, head_dim]
+            # and returns [b, s, heads, head_dim]
+            x = flash_attn_qkvpacked_func(
+                qkv=qkv.bfloat16(), dropout_p=self.attn_drop_rate if self.training else 0.0
+            )
+            x = x.reshape(batch_size, seq_len, channels).type_as(identity)

-        x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)

         # output projection
         x = self.proj(x)
+        x = self.proj_drop(x)

-        # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
-        x = x.view(batch_size, time, channels, height, width)
-        x = x.permute(0, 2, 1, 3, 4)
+        # Reshape back: [b, t*h*w, c] -> [b, c, t, h, w]
+        x = x.permute(0, 2, 1).reshape(batch_size, channels, time, height, width)

         return x + identity
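
As a quick smoke test of the modules introduced in this series, the sketch below instantiates a tiny `Magi1Transformer3DModel` and runs a single forward pass. It is illustrative only: it assumes the patch has been applied to a diffusers source tree (the class is imported from its module path since it may not be exported from the top-level package yet), and all sizes are made-up toy values.

```python
import torch

# Assumes this patch series is applied; the class lives in the new transformer_magi1 module.
from diffusers.models.transformers.transformer_magi1 import Magi1Transformer3DModel

# Toy-sized config so the forward pass runs quickly on CPU.
model = Magi1Transformer3DModel(
    patch_size=(1, 2, 2),
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    out_channels=4,
    text_dim=16,
    freq_dim=32,
    ffn_dim=32,
    num_layers=1,
)

video_latents = torch.randn(1, 4, 2, 8, 8)  # (batch, channels, frames, height, width)
text_embeds = torch.randn(1, 77, 16)        # (batch, tokens, text_dim)
timestep = torch.tensor([10], dtype=torch.long)

with torch.no_grad():
    out = model(hidden_states=video_latents, timestep=timestep, encoder_hidden_states=text_embeds).sample

print(out.shape)  # torch.Size([1, 4, 2, 8, 8]) — same shape as the input latents
```

The unpatchify step at the end of the forward pass should return a tensor with the same spatial and temporal extent as the input latents, which makes this a cheap shape-consistency check for the conversion and test scripts in this series.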