add ChronoEdit #12593
@@ -0,0 +1,329 @@
# Copyright 2025 The ChronoEdit Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin
from ..cache_utils import CacheMixin
from ..embeddings import get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
from .transformer_wan import WanTimeTextImageEmbedding, WanTransformerBlock


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
class ChronoEditRotaryPosEmbed(nn.Module):
    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int,
        theta: float = 10000.0,
        temporal_skip_len: int = 8,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len
        self.temporal_skip_len = temporal_skip_len

        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64

        freqs_cos = []
        freqs_sin = []

        for dim in [t_dim, h_dim, w_dim]:
            freq_cos, freq_sin = get_1d_rotary_pos_embed(
                dim,
                max_seq_len,
                theta,
                use_real=True,
                repeat_interleave_real=True,
                freqs_dtype=freqs_dtype,
            )
            freqs_cos.append(freq_cos)
            freqs_sin.append(freq_sin)

        self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
        self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

        split_sizes = [
            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
            self.attention_head_dim // 3,
            self.attention_head_dim // 3,
        ]

        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)

        assert num_frames == 2 or num_frames == self.temporal_skip_len, (
            f"num_frames must be 2 or {self.temporal_skip_len}, but got {num_frames}"
        )
        if num_frames == 2:
            freqs_cos_f = freqs_cos[0][: self.temporal_skip_len][[0, -1]].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        else:
            freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        if num_frames == 2:
            freqs_sin_f = freqs_sin[0][: self.temporal_skip_len][[0, -1]].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        else:
            freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)

        return freqs_cos, freqs_sin
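    # NOTE: in the two-frame editing case the temporal frequencies are taken at positions
    # 0 and `temporal_skip_len - 1` (via `[: temporal_skip_len][[0, -1]]`), so the edited frame
    # is placed `temporal_skip_len - 1` rotary steps after the input frame rather than directly
    # adjacent to it; the height and width axes are always indexed contiguously.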
class ChronoEditTransformer3DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
):
    r"""
    A Transformer model for video-like data used in the ChronoEdit model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of attention heads.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in feed-forward network.
        num_layers (`int`, defaults to `40`):
            The number of layers of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            Type of query/key normalization to use.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        image_dim (`int`, *optional*, defaults to `None`):
            Input dimension for image embeddings (e.g. 1280 for the I2V variant). If `None`, no image embedder is
            used.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        rope_max_seq_len (`int`, defaults to `1024`):
            Maximum sequence length for the precomputed rotary position embeddings.
        pos_embed_seq_len (`int`, *optional*, defaults to `None`):
            Optional sequence length for the image embedder's positional embedding.
        rope_temporal_skip_len (`int`, defaults to `8`):
            In the two-frame editing case, the edited frame is assigned temporal rotary position
            `rope_temporal_skip_len - 1` while the input frame gets position 0.
    """
    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["WanTransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
    _repeated_blocks = ["WanTransformerBlock"]
    _cp_plan = {
        "rope": {
            0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
            1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
        },
        "blocks.0": {
            "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        },
        "blocks.*": {
            "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        },
        "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
    }
    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 40,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: Optional[int] = None,
        rope_temporal_skip_len: int = 8,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Patch & position embedding
        self.rope = ChronoEditRotaryPosEmbed(
            attention_head_dim, patch_size, rope_max_seq_len, temporal_skip_len=rope_temporal_skip_len
        )
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Condition embeddings
        # image_embedding_dim=1280 for I2V model
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

        self.gradient_checkpointing = False
    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        rotary_emb = self.rope(hidden_states)

        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
        if timestep.ndim == 2:
            ts_seq_len = timestep.shape[1]
            timestep = timestep.flatten()  # batch_size * seq_len
        else:
            ts_seq_len = None

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
        if ts_seq_len is not None:
            # batch_size, seq_len, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
            # batch_size, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
        else:
            for block in self.blocks:
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

        # 5. Output norm, projection & unpatchify
        if temb.ndim == 3:
            # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
            shift = shift.squeeze(2)
            scale = scale.squeeze(2)
        else:
            # batch_size, inner_dim
            shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

        # Move the shift and scale tensors to the same device as hidden_states.
        # When using multi-GPU inference via accelerate these will be on the
        # first device rather than the last device, which hidden_states ends up
        # on.
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
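        # Unpatchify: the reshape/permute/flatten above maps
        # (b, f*h*w, p_t*p_h*p_w*c) -> (b, c, f*p_t, h*p_h, w*p_w),
        # recovering a latent video with `out_channels` channels at the pre-patch frame count and spatial size.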
        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
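For reviewers who want to quickly smoke-test the block, here is a minimal forward-pass sketch. The import path is an assumption (the final module location and any top-level export are up to this PR), and the tiny config values are illustrative rather than the released ChronoEdit sizes. The head dimension is kept at a multiple of 12 so the rotary split sizes computed in `forward` line up with the precomputed frequency buffers, and the frame count must be either 2 or `rope_temporal_skip_len`.

```python
import torch

# Assumed import path for the new module added in this PR; adjust to wherever it lands.
from diffusers.models.transformers.transformer_chronoedit import ChronoEditTransformer3DModel

# Tiny, illustrative config (not the released ChronoEdit sizes).
model = ChronoEditTransformer3DModel(
    patch_size=(1, 2, 2),
    num_attention_heads=2,
    attention_head_dim=24,  # multiple of 12 keeps the RoPE split sizes consistent with the buffers
    in_channels=16,
    out_channels=16,
    text_dim=32,
    freq_dim=32,
    ffn_dim=128,
    num_layers=2,
    rope_max_seq_len=32,
    rope_temporal_skip_len=8,
)

# Two latent frames (input image + edited image); height/width must be divisible by the spatial patch size.
hidden_states = torch.randn(1, 16, 2, 16, 16)   # (batch, channels, frames, height, width)
encoder_hidden_states = torch.randn(1, 16, 32)  # (batch, text_seq_len, text_dim)
timestep = torch.tensor([500])

with torch.no_grad():
    sample = model(hidden_states, timestep, encoder_hidden_states).sample

print(sample.shape)  # torch.Size([1, 16, 2, 16, 16])
```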
Review comment: can we copy over these 2 things and add a `# Copied from`, instead of importing from wan?
Reply: yep, that makes sense. So we'll need to copy all the modules in `transformer_wan` here.
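For context, the `# Copied from` convention being requested works roughly as sketched below: the Wan modules are duplicated verbatim into this file, and the comment lets the repo's copy checker (`make fix-copies`) keep the duplicated code in sync with `transformer_wan`. The class name and the rename clause here are hypothetical; which modules to copy is for the PR to settle.

```python
# Hypothetical illustration of the `# Copied from` pattern; the class name and
# rename clause are placeholders, not part of this PR.
import torch.nn as nn


# Copied from diffusers.models.transformers.transformer_wan.WanTransformerBlock with WanTransformerBlock->ChronoEditTransformerBlock
class ChronoEditTransformerBlock(nn.Module):
    ...  # body duplicated verbatim from WanTransformerBlock, kept in sync by the copy checker
```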