From 89f39bec13b8e027e11d8aecbe77b6f13f1ba771 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 2 Jul 2025 18:55:17 +0300 Subject: [PATCH 1/7] Refactor output norm to use AdaLayerNorm in Wan transformers Replace the final `FP32LayerNorm` and manual shift/scale application with a single `AdaLayerNorm` module in both the `WanTransformer3DModel` and `WanVACETransformer3DModel`. This change simplifies the forward pass by encapsulating the adaptive normalization logic within the `AdaLayerNorm` layer, removing the need for a separate `scale_shift_table`. The `_no_split_modules` list is also updated to include `norm_out` for compatibility with model parallelism. --- .../models/transformers/transformer_wan.py | 24 ++++++++----------- .../transformers/transformer_wan_vace.py | 24 ++++++++----------- 2 files changed, 20 insertions(+), 28 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 5fb71b69f7ac..9702d47a9cf5 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -28,7 +28,7 @@ from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import FP32LayerNorm +from ..normalization import AdaLayerNorm, FP32LayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -370,7 +370,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] - _no_split_modules = ["WanTransformerBlock"] + _no_split_modules = ["WanTransformerBlock", "norm_out"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] _repeated_blocks = ["WanTransformerBlock"] @@ -426,9 +426,14 @@ def __init__( ) # 4. Output norm & projection - self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.norm_out = AdaLayerNorm( + embedding_dim=inner_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=False, + norm_eps=eps, + chunk_dim=1, + ) self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) - self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) self.gradient_checkpointing = False @@ -486,16 +491,7 @@ def forward( hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) # 5. Output norm, projection & unpatchify - shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) - - # Move the shift and scale tensors to the same device as hidden_states. - # When using multi-GPU inference via accelerate these will be on the - # first device rather than the last device, which hidden_states ends up - # on. 
- shift = shift.to(hidden_states.device) - scale = scale.to(hidden_states.device) - - hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.norm_out(hidden_states, temb=temb) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape( diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py index 1a6f2af59a87..0813051cf5fc 100644 --- a/src/diffusers/models/transformers/transformer_wan_vace.py +++ b/src/diffusers/models/transformers/transformer_wan_vace.py @@ -26,7 +26,7 @@ from ..cache_utils import CacheMixin from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import FP32LayerNorm +from ..normalization import AdaLayerNorm, FP32LayerNorm from .transformer_wan import WanAttnProcessor2_0, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock @@ -179,7 +179,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"] - _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"] + _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock", "norm_out"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] @@ -259,9 +259,14 @@ def __init__( ) # 4. Output norm & projection - self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.norm_out = AdaLayerNorm( + embedding_dim=inner_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=False, + norm_eps=eps, + chunk_dim=1, + ) self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) - self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) self.gradient_checkpointing = False @@ -365,16 +370,7 @@ def forward( hidden_states = hidden_states + control_hint * scale # 6. Output norm, projection & unpatchify - shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) - - # Move the shift and scale tensors to the same device as hidden_states. - # When using multi-GPU inference via accelerate these will be on the - # first device rather than the last device, which hidden_states ends up - # on. 
- shift = shift.to(hidden_states.device) - scale = scale.to(hidden_states.device) - - hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.norm_out(hidden_states, temb=temb) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape( From e4b30b88beb170fdcdcd36d6a5c1fe76ba1cb2fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 2 Jul 2025 19:51:34 +0300 Subject: [PATCH 2/7] fix: remove scale_shift_table from _keep_in_fp32_modules in Wan and WanVACE transformers --- src/diffusers/models/transformers/transformer_wan.py | 2 +- src/diffusers/models/transformers/transformer_wan_vace.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 9702d47a9cf5..d5e2fe2aea31 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -371,7 +371,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] _no_split_modules = ["WanTransformerBlock", "norm_out"] - _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + _keep_in_fp32_modules = ["time_embedder", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] _repeated_blocks = ["WanTransformerBlock"] diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py index 0813051cf5fc..2a6a64032f5a 100644 --- a/src/diffusers/models/transformers/transformer_wan_vace.py +++ b/src/diffusers/models/transformers/transformer_wan_vace.py @@ -180,7 +180,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"] _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock", "norm_out"] - _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + _keep_in_fp32_modules = ["time_embedder", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] @register_to_config From 92f8237638bc18a6539cbcb6c72461ef3dac1728 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 2 Jul 2025 20:21:05 +0300 Subject: [PATCH 3/7] Fixes transformed head modulation layer mapping Updates the key mapping for the `head.modulation` layer to `norm_out.linear` in the model conversion script. This correction ensures that weights are loaded correctly for both standard and VACE transformer models. 
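For reference, the rename dictionaries are applied as plain substring substitutions over every
checkpoint key in the script's key-renaming pass. A minimal self-contained sketch of that flow
(the toy state dict and the driver loop are illustrative assumptions; only `update_state_dict_`
and the rename entry themselves come from the script):

    from typing import Any, Dict

    import torch

    # Excerpt of the rename table; the full dict lives in convert_wan_to_diffusers.py.
    TRANSFORMER_KEYS_RENAME_DICT = {"head.modulation": "norm_out.linear"}

    def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
        # Body assumed for this sketch: move the tensor to its new key in place.
        state_dict[new_key] = state_dict.pop(old_key)
        return state_dict

    # Toy checkpoint holding only the remapped parameter, shaped like the
    # removed `scale_shift_table`.
    original_state_dict = {"head.modulation": torch.randn(1, 2, 8)}

    for key in list(original_state_dict.keys()):
        new_key = key
        for old, new in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(old, new)
        update_state_dict_(original_state_dict, key, new_key)

    assert "norm_out.linear" in original_state_dict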
--- scripts/convert_wan_to_diffusers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 6d25cde071b1..24cb798cc198 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -25,7 +25,7 @@ "text_embedding.0": "condition_embedder.text_embedder.linear_1", "text_embedding.2": "condition_embedder.text_embedder.linear_2", "time_projection.1": "condition_embedder.time_proj", - "head.modulation": "scale_shift_table", + "head.modulation": "norm_out.linear", "head.head": "proj_out", "modulation": "scale_shift_table", "ffn.0": "ffn.net.0.proj", @@ -67,7 +67,7 @@ "text_embedding.0": "condition_embedder.text_embedder.linear_1", "text_embedding.2": "condition_embedder.text_embedder.linear_2", "time_projection.1": "condition_embedder.time_proj", - "head.modulation": "scale_shift_table", + "head.modulation": "norm_out.linear", "head.head": "proj_out", "modulation": "scale_shift_table", "ffn.0": "ffn.net.0.proj", From df07b88b0fac437dd1320ea2d67988017940dad0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 2 Jul 2025 20:32:49 +0300 Subject: [PATCH 4/7] Fix: Revert removing `scale_shift_table` from `_keep_in_fp32_modules` in Wan and WanVACE transformers --- src/diffusers/models/transformers/transformer_wan.py | 2 +- src/diffusers/models/transformers/transformer_wan_vace.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index d5e2fe2aea31..9702d47a9cf5 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -371,7 +371,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] _no_split_modules = ["WanTransformerBlock", "norm_out"] - _keep_in_fp32_modules = ["time_embedder", "norm1", "norm2", "norm3"] + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] _repeated_blocks = ["WanTransformerBlock"] diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py index 2a6a64032f5a..0813051cf5fc 100644 --- a/src/diffusers/models/transformers/transformer_wan_vace.py +++ b/src/diffusers/models/transformers/transformer_wan_vace.py @@ -180,7 +180,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"] _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock", "norm_out"] - _keep_in_fp32_modules = ["time_embedder", "norm1", "norm2", "norm3"] + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] @register_to_config From e555903067c43d13fbafa84d0f39d918a408e7d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 2 Jul 2025 21:22:17 +0300 Subject: [PATCH 5/7] Refactors transformer output blocks to use AdaLayerNorm Replaces the manual implementation of adaptive layer normalization, which used a separate `scale_shift_table` and `nn.LayerNorm`, with the unified 
`AdaLayerNorm` module. This change simplifies the forward pass logic in several transformer models by encapsulating the normalization and modulation steps into a single component. It also adds `norm_out` to `_no_split_modules` for model parallelism compatibility. --- .../transformers/latte_transformer_3d.py | 17 +++++++++------- .../transformers/pixart_transformer_2d.py | 20 +++++++++---------- .../transformers/transformer_allegro.py | 18 +++++++++-------- .../models/transformers/transformer_ltx.py | 18 +++++++++-------- 4 files changed, 40 insertions(+), 33 deletions(-) diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 990c90512e39..486969047d6a 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -23,11 +23,12 @@ from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle +from ..normalization import AdaLayerNorm, AdaLayerNormSingle class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): _supports_gradient_checkpointing = True + _no_split_modules = ["norm_out"] """ A 3D Transformer model for video-like data, paper: https://huggingface.co/papers/2401.03048, official code: @@ -149,8 +150,13 @@ def __init__( # 4. Define output layers self.out_channels = in_channels if out_channels is None else out_channels - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.norm_out = AdaLayerNorm( + embedding_dim=inner_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=False, + norm_eps=1e-6, + chunk_dim=1, + ) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) # 5. Latte other blocks. 
@@ -305,10 +311,7 @@ def forward( embedded_timestep = embedded_timestep.repeat_interleave( num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame ).view(-1, embedded_timestep.shape[-1]) - shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) - # Modulation - hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.norm_out(hidden_states, temb=embedded_timestep) hidden_states = self.proj_out(hidden_states) # unpatchify diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index 40a14bfd9b27..f05da448367d 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -23,7 +23,7 @@ from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle +from ..normalization import AdaLayerNorm, AdaLayerNormSingle logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -78,7 +78,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True - _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] + _no_split_modules = ["BasicTransformerBlock", "PatchEmbed", "norm_out"] _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"] @register_to_config @@ -171,8 +171,13 @@ def __init__( ) # 3. Output blocks. - self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.norm_out = AdaLayerNorm( + embedding_dim=self.inner_dim, + output_dim=2 * self.inner_dim, + norm_elementwise_affine=False, + norm_eps=1e-6, + chunk_dim=1, + ) self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels) self.adaln_single = AdaLayerNormSingle( @@ -406,12 +411,7 @@ def forward( ) # 3. Output - shift, scale = ( - self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device) - ).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) - # Modulation - hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device) + hidden_states = self.norm_out(hidden_states, temb=embedded_timestep) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.squeeze(1) diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 5fa59a71d977..a20c64801de5 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -28,7 +28,7 @@ from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle +from ..normalization import AdaLayerNorm, AdaLayerNormSingle logger = logging.get_logger(__name__) @@ -175,6 +175,7 @@ def forward( class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): _supports_gradient_checkpointing = True + _no_split_modules = ["norm_out"] """ A 3D Transformer model for video-like data. @@ -292,8 +293,13 @@ def __init__( ) # 3. 
Output projection & norm - self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.norm_out = AdaLayerNorm( + embedding_dim=self.inner_dim, + output_dim=2 * self.inner_dim, + norm_elementwise_affine=False, + norm_eps=1e-6, + chunk_dim=1, + ) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels) # 4. Timestep embeddings @@ -393,11 +399,7 @@ def forward( ) # 4. Output normalization & projection - shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) - - # Modulation - hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.norm_out(hidden_states, temb=embedded_timestep) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.squeeze(1) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 2d06124282d1..91e40c8f3471 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -30,7 +30,7 @@ from ..embeddings import PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle, RMSNorm +from ..normalization import AdaLayerNorm, AdaLayerNormSingle, RMSNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -328,6 +328,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["norm"] + _no_split_modules = ["norm_out"] _repeated_blocks = ["LTXVideoTransformerBlock"] @register_to_config @@ -356,7 +357,6 @@ def __init__( self.proj_in = nn.Linear(in_channels, inner_dim) - self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) @@ -389,7 +389,13 @@ def __init__( ] ) - self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) + self.norm_out = AdaLayerNorm( + embedding_dim=inner_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=False, + norm_eps=1e-6, + chunk_dim=1, + ) self.proj_out = nn.Linear(inner_dim, out_channels) self.gradient_checkpointing = False @@ -464,11 +470,7 @@ def forward( encoder_attention_mask=encoder_attention_mask, ) - scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] - shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] - - hidden_states = self.norm_out(hidden_states) - hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.norm_out(hidden_states, temb=embedded_timestep.squeeze(1)) output = self.proj_out(hidden_states) if USE_PEFT_BACKEND: From 921396a147868486a5662e10e3285bdd5ee35550 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 3 Jul 2025 18:37:20 +0300 Subject: [PATCH 6/7] Fix `head.modulation` mapping in conversion script Corrects the target key for `head.modulation` to `norm_out.linear.weight`. This ensures the weights are correctly mapped to the weight parameter of the output normalization layer during model conversion for both transformer types. 
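For context, the parameters that the remapped key now has to populate can be inspected with a
short standalone sketch (the `AdaLayerNorm` arguments mirror the ones introduced earlier in this
series; `inner_dim` is a toy value):

    import torch

    from diffusers.models.normalization import AdaLayerNorm

    inner_dim = 8  # toy size; the real models use much larger hidden dims
    norm_out = AdaLayerNorm(
        embedding_dim=inner_dim,
        output_dim=2 * inner_dim,
        norm_elementwise_affine=False,
        norm_eps=1e-6,
        chunk_dim=1,
    )

    # Original checkpoints store `head.modulation` with the same shape as the
    # removed `scale_shift_table`, i.e. (1, 2, inner_dim).
    head_modulation = torch.randn(1, 2, inner_dim) / inner_dim**0.5

    print(head_modulation.shape)         # torch.Size([1, 2, 8])
    print(norm_out.linear.weight.shape)  # torch.Size([16, 8])
    print(norm_out.linear.bias.shape)    # torch.Size([16])

Since the original checkpoints ship only this single modulation table, the new layer's bias has
no source key; the follow-up patch below backfills it with zeros during conversion.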
--- scripts/convert_wan_to_diffusers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 24cb798cc198..cf287bad0149 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -25,7 +25,7 @@ "text_embedding.0": "condition_embedder.text_embedder.linear_1", "text_embedding.2": "condition_embedder.text_embedder.linear_2", "time_projection.1": "condition_embedder.time_proj", - "head.modulation": "norm_out.linear", + "head.modulation": "norm_out.linear.weight", "head.head": "proj_out", "modulation": "scale_shift_table", "ffn.0": "ffn.net.0.proj", @@ -67,7 +67,7 @@ "text_embedding.0": "condition_embedder.text_embedder.linear_1", "text_embedding.2": "condition_embedder.text_embedder.linear_2", "time_projection.1": "condition_embedder.time_proj", - "head.modulation": "norm_out.linear", + "head.modulation": "norm_out.linear.weight", "head.head": "proj_out", "modulation": "scale_shift_table", "ffn.0": "ffn.net.0.proj", From ff95d5db6badb99242006733008d08eda1050ee7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 3 Jul 2025 19:51:01 +0300 Subject: [PATCH 7/7] Fix handling of missing bias keys in conversion script Adds a default zero-initialized bias tensor for the transformer's output normalization layer if it is missing from the original state dictionary. --- scripts/convert_wan_to_diffusers.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index cf287bad0149..012982fe8f97 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -105,8 +105,12 @@ "after_proj": "proj_out", } -TRANSFORMER_SPECIAL_KEYS_REMAP = {} -VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {} +TRANSFORMER_SPECIAL_KEYS_REMAP = { + "norm_out.linear.bias": lambda key, state_dict: state_dict.setdefault(key, torch.zeros(state_dict["norm_out.linear.weight"].shape[0])) +} +VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = { + "norm_out.linear.bias": lambda key, state_dict: state_dict.setdefault(key, torch.zeros(state_dict["norm_out.linear.weight"].shape[0])) +} def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: @@ -308,6 +312,10 @@ def convert_transformer(model_type: str): continue handler_fn_inplace(key, original_state_dict) + for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items(): + if special_key not in original_state_dict: + handler_fn_inplace(special_key, original_state_dict) + transformer.load_state_dict(original_state_dict, strict=True, assign=True) return transformer
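For quick local verification, the missing-bias handling can be exercised on its own. The sketch
below mirrors the lambda and the missing-key pass added above, with toy shapes standing in for a
real checkpoint:

    import torch

    # Toy stand-in for a converted checkpoint that carries the remapped weight
    # but no bias for the output norm's linear layer.
    state_dict = {"norm_out.linear.weight": torch.randn(16, 8)}

    special_keys_remap = {
        "norm_out.linear.bias": lambda key, sd: sd.setdefault(
            key, torch.zeros(sd["norm_out.linear.weight"].shape[0])
        )
    }

    # Condensed version of the missing-key pass added to convert_transformer().
    for special_key, handler_fn_inplace in special_keys_remap.items():
        if special_key not in state_dict:
            handler_fn_inplace(special_key, state_dict)

    assert state_dict["norm_out.linear.bias"].shape == (16,)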