Commit 0145135

Update Wan Animate conversion script to reflect changes to transformer

1 parent ac2962d commit 0145135
File tree

1 file changed: +119 −114 lines changed

scripts/convert_wan_to_diffusers.py
@@ -157,59 +157,23 @@
     "attn2.to_k_img": "attn2.add_k_proj",
     "attn2.to_v_img": "attn2.add_v_proj",
     "attn2.norm_k_img": "attn2.norm_added_k",
+    # Wan Animate-specific mappings (motion encoder, face encoder, face adapter)
     # Motion encoder mappings
-    "motion_encoder.enc.net_app.convs": "condition_embedder.motion_embedder.convs",
-    "motion_encoder.enc.fc": "condition_embedder.motion_embedder.linears",
-    "motion_encoder.dec.direction.weight": "condition_embedder.motion_embedder.motion_synthesis_weight",
+    # The name mapping is complicated for the convolutional part so we handle that in its own function
+    "motion_encoder.enc.fc": "motion_encoder.motion_network",
+    "motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
     # Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten
-    "face_encoder.conv1_local.conv": "condition_embedder.face_embedder.conv1_local",
-    "face_encoder.conv2.conv": "condition_embedder.face_embedder.conv2",
-    "face_encoder.conv3.conv": "condition_embedder.face_embedder.conv3",
-    "face_encoder.out_proj": "condition_embedder.face_embedder.out_proj",
-    "face_encoder.norm1": "condition_embedder.face_embedder.norm1",
-    # Return to the original order for face_embedder norms
-    "face_encoder.norm2": "face_embedder_norm__placeholder",
-    "face_encoder.norm3": "condition_embedder.face_embedder.norm2",
-    "face_embedder_norm__placeholder": "condition_embedder.face_embedder.norm3",
-    "face_encoder.padding_tokens": "condition_embedder.face_embedder.padding_tokens",
-    # Face adapter mappings
-    "face_adapter.fuser_blocks": "face_adapter",
+    "face_encoder.conv1_local.conv": "face_encoder.conv1_local",
+    "face_encoder.conv2.conv": "face_encoder.conv2",
+    "face_encoder.conv3.conv": "face_encoder.conv3",
+    # Face adapter mappings are handled in a separate function
 }
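
For context: diffusers conversion scripts conventionally apply a rename dict like the one above by substring replacement over every checkpoint key. A minimal sketch of that pattern (the helper name rename_key is hypothetical, not part of this script):

def rename_key(key: str, rename_dict: dict) -> str:
    # Substring replacement; dict order matters if one mapping rewrites
    # a prefix that a later mapping expects to find.
    for old, new in rename_dict.items():
        key = key.replace(old, new)
    return key

# e.g. "blocks.0.attn2.to_k_img.weight" -> "blocks.0.attn2.add_k_proj.weight"
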
 
 
-def convert_equal_linear_weight(key: str, state_dict: Dict[str, Any]) -> None:
-    """
-    Convert EqualLinear weights to standard Linear weights by applying the scale factor.
-    EqualLinear uses: F.linear(input, self.weight * self.scale, bias=self.bias)
-    where scale = (1 / sqrt(in_dim))
-    """
-    if ".weight" not in key:
-        return
-
-    in_dim = state_dict[key].shape[1]
-    scale = 1.0 / math.sqrt(in_dim)
-    state_dict[key] = state_dict[key] * scale
-
-
-def convert_equal_conv2d_weight(key: str, state_dict: Dict[str, Any]) -> None:
-    """
-    Convert EqualConv2d weights to standard Conv2d weights by applying the scale factor.
-    EqualConv2d uses: F.conv2d(input, self.weight * self.scale, bias=self.bias, ...)
-    where scale = 1 / sqrt(in_channel * kernel_size^2)
-    """
-    if ".weight" not in key or len(state_dict[key].shape) != 4:
-        return
-
-    out_channel, in_channel, kernel_size, kernel_size = state_dict[key].shape
-    scale = 1.0 / math.sqrt(in_channel * kernel_size**2)
-    state_dict[key] = state_dict[key] * scale
-
-
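
The two helpers deleted above folded the StyleGAN-style equalized-learning-rate scale into the stored weight once, so that a plain nn.Linear or nn.Conv2d reproduces EqualLinear or EqualConv2d at inference. A minimal standalone illustration of that equivalence (not from the script):

import math

import torch
import torch.nn.functional as F

x = torch.randn(2, 16)
w = torch.randn(8, 16)               # weight as stored in the original checkpoint
scale = 1.0 / math.sqrt(w.shape[1])  # the scale EqualLinear applies on every forward

w_folded = w * scale                 # what convert_equal_linear_weight wrote back
assert torch.allclose(F.linear(x, w * scale), F.linear(x, w_folded))
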
 # TODO: Verify this and simplify if possible.
-def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any]) -> None:
+def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None:
     """
     Convert all motion encoder weights for Animate model.
-    This handles both EqualLinear (in linears) and EqualConv2d (in convs).
 
     In the original model:
     - All Linear layers in fc use EqualLinear
@@ -220,89 +184,135 @@ def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any])
     Conversion strategy:
     1. Drop .kernel buffers (blur kernels)
     2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu)
-    3. Scale EqualLinear and EqualConv2d weights
     """
     # Skip if not a weight, bias, or kernel
     if ".weight" not in key and ".bias" not in key and ".kernel" not in key:
         return
 
     # Handle Blur kernel buffers from original implementation.
-    # After renaming, these appear under: condition_embedder.motion_embedder.convs.*.conv{1,2}.0.kernel
-    # Diffusers constructs blur kernels procedurally (ConvLayer.blur_conv) so we must drop these keys
-    if ".kernel" in key and "condition_embedder.motion_embedder.convs" in key:
+    # After renaming, these appear under: motion_encoder.res_blocks.*.conv{2,skip}.blur_kernel
+    # Diffusers constructs blur kernels as a non-persistent buffer so we must drop these keys
+    if ".kernel" in key and "motion_encoder" in key:
         # Remove unexpected blur kernel buffers to avoid strict load errors
         state_dict.pop(key, None)
         return
 
     # Rename Sequential indices to named components in ConvLayer and ResBlock
-    # This must happen BEFORE weight scaling because we need to rename the keys first
-    # Original: convs.X.Y.weight/bias or convs.X.conv1/conv2/skip.Y.weight/bias
-    # Target: convs.X.conv2d.weight or convs.X.conv1/conv2/skip.conv2d.weight or .bias_leaky_relu
-    if ".convs." in key and (".weight" in key or ".bias" in key):
+    if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key):
         parts = key.split(".")
 
         # Find the sequential index (digit) after convs or after conv1/conv2/skip
         # Examples:
-        # - convs.0.0.weight -> convs.0.conv2d.weight (ConvLayer, no blur)
-        # - convs.0.1.weight -> convs.0.conv2d.weight (ConvLayer, with blur at index 0)
-        # - convs.0.1.bias -> convs.0.bias_leaky_relu (FusedLeakyReLU)
-        # - convs.1.conv1.1.weight -> convs.1.conv1.conv2d.weight (ResBlock ConvLayer)
-        # - convs.1.conv1.2.bias -> convs.1.conv1.bias_leaky_relu (ResBlock FusedLeakyReLU)
-        # - convs.8.weight -> unchanged (final Conv2d, not in Sequential)
-
-        # Check if we have a digit as second-to-last part before .weight or .bias
-        # But we need to distinguish between Sequential indices (convs.X.Y.weight)
-        # and ModuleList indices (convs.X.weight)
-        # We only rename if there are at least 3 parts after finding 'convs'
+        # - enc.net_app.convs.0.0.weight -> conv_in.weight (initial conv layer weight)
+        # - enc.net_app.convs.0.1.bias -> conv_in.act_fn.bias (initial conv layer bias)
+        # - enc.net_app.convs.{n:1-7}.conv1.0.weight -> res_blocks.{(n-1):0-6}.conv1.weight (conv1 weight)
+        #   - e.g. enc.net_app.convs.1.conv1.0.weight -> res_blocks.0.conv1.weight
+        # - enc.net_app.convs.{n:1-7}.conv1.1.bias -> res_blocks.{(n-1):0-6}.conv1.act_fn.bias (conv1 bias)
+        #   - e.g. enc.net_app.convs.1.conv1.1.bias -> res_blocks.0.conv1.act_fn.bias
+        # - enc.net_app.convs.{n:1-7}.conv2.1.weight -> res_blocks.{(n-1):0-6}.conv2.weight (conv2 weight)
+        # - enc.net_app.convs.1.conv2.2.bias -> res_blocks.0.conv2.act_fn.bias (conv2 bias)
+        # - enc.net_app.convs.{n:1-7}.skip.1.weight -> res_blocks.{(n-1):0-6}.conv_skip.weight (skip conv weight)
+        # - enc.net_app.convs.8 -> conv_out (final conv layer)
+
         convs_idx = parts.index("convs") if "convs" in parts else -1
-        if (
-            convs_idx >= 0 and len(parts) - convs_idx > 3
-        ):  # e.g., ['convs', '0', '0', 'weight'] has 4 parts after convs
-            if len(parts) >= 2 and parts[-2].isdigit():
+        if convs_idx >= 0 and len(parts) - convs_idx >= 2:
+            bias = False
+            # The nn.Sequential index will always follow convs
+            sequential_idx = int(parts[convs_idx + 1])
+            if sequential_idx == 0:
                 if key.endswith(".weight"):
-                    # Replace digit index with 'conv2d' for EqualConv2d weight parameters
-                    parts[-2] = "conv2d"
-                    new_key = ".".join(parts)
-                    state_dict[new_key] = state_dict.pop(key)
-                    # Update key for subsequent processing
-                    key = new_key
+                    new_key = "motion_encoder.conv_in.weight"
                 elif key.endswith(".bias"):
-                    # Replace digit index + .bias with 'bias_leaky_relu' for FusedLeakyReLU bias
-                    new_key = ".".join(parts[:-2]) + ".bias_leaky_relu"
-                    state_dict[new_key] = state_dict.pop(key)
-                    # Bias doesn't need scaling, we're done
-                    return
-
-        # Skip blur_conv weights that are already initialized in diffusers
-        if "blur_conv.weight" in key:
-            return
+                    new_key = "motion_encoder.conv_in.act_fn.bias"
+                    bias = True
+            elif sequential_idx == final_conv_idx:
+                if key.endswith(".weight"):
+                    new_key = "motion_encoder.conv_out.weight"
+            else:
+                # Intermediate .convs. layers, which get mapped to .res_blocks.
+                prefix = "motion_encoder.res_blocks."
 
-        # Skip bias_leaky_relu as it doesn't need any transformation
-        if "bias_leaky_relu" in key:
+                layer_name = parts[convs_idx + 2]
+                if layer_name == "skip":
+                    layer_name = "conv_skip"
+
+                if key.endswith(".weight"):
+                    param_name = "weight"
+                elif key.endswith(".bias"):
+                    param_name = "act_fn.bias"
+                    bias = True
+
+                suffix_parts = [str(sequential_idx - 1), layer_name, param_name]
+                suffix = ".".join(suffix_parts)
+                new_key = prefix + suffix
+
+            param = state_dict.pop(key)
+            if bias:
+                param = param.squeeze()
+            state_dict[new_key] = param
+            return
         return
+    return
+
 
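
To make the new mapping concrete, an illustrative round trip through the handler above (sample keys and shapes invented for the example):

import torch

sd = {
    "motion_encoder.enc.net_app.convs.0.1.bias": torch.randn(1, 64, 1, 1),
    "motion_encoder.enc.net_app.convs.1.conv1.0.weight": torch.randn(128, 64, 3, 3),
}
for key in list(sd):
    convert_animate_motion_encoder_weights(key, sd)

assert "motion_encoder.conv_in.act_fn.bias" in sd        # bias squeezed to shape (64,)
assert "motion_encoder.res_blocks.0.conv1.weight" in sd  # convs.1 -> res_blocks.0
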
-    # Scale EqualLinear weights in linear layers
-    if ".linears." in key and ".weight" in key:
-        convert_equal_linear_weight(key, state_dict)
+def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None:
+    """
+    Convert face adapter weights for the Animate model.
+
+    The original model uses a fused KV projection but the diffusers models uses separate K and V projections.
+    """
+    # Skip if not a weight or bias
+    if ".weight" not in key and ".bias" not in key:
         return
 
-    # Scale EqualConv2d weights in convolution layers
-    if ".convs." in key and ".weight" in key:
-        # Two cases:
-        # 1. ConvLayer with EqualConv2d: convs.<i>.conv2d.weight (after renaming)
-        # 2. Direct EqualConv2d (last conv): convs.<i>.weight (where <i> is a single digit)
-        if ".conv2d.weight" in key:
-            convert_equal_conv2d_weight(key, state_dict)
-            return
-        elif key.split(".")[-2].isdigit() and key.endswith(".weight"):
-            # This handles keys like "convs.7.weight" where the second-to-last part is a digit
-            convert_equal_conv2d_weight(key, state_dict)
-            return
+    prefix = "face_adapter."
+    if ".fuser_blocks." in key:
+        parts = key.split(".")
+
+        module_list_idx = parts.index("fuser_blocks") if "fuser_blocks" in parts else -1
+        if module_list_idx >= 0 and (len(parts) - 1) - module_list_idx == 3:
+            block_idx = parts[module_list_idx + 1]
+            layer_name = parts[module_list_idx + 2]
+            param_name = parts[module_list_idx + 3]
+
+            if layer_name == "linear1_kv":
+                layer_name_k = "to_k"
+                layer_name_v = "to_v"
+
+                suffix_k = ".".join([block_idx, layer_name_k, param_name])
+                suffix_v = ".".join([block_idx, layer_name_v, param_name])
+                new_key_k = prefix + suffix_k
+                new_key_v = prefix + suffix_v
+
+                kv_proj = state_dict.pop(key)
+                k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0)
+                state_dict[new_key_k] = k_proj
+                state_dict[new_key_v] = v_proj
+                return
+            else:
+                if layer_name == "q_norm":
+                    new_layer_name = "norm_q"
+                elif layer_name == "k_norm":
+                    new_layer_name = "norm_k"
+                elif layer_name == "linear1_q":
+                    new_layer_name = "to_q"
+                elif layer_name == "linear2":
+                    new_layer_name = "to_out"
+
+                suffix_parts = [block_idx, new_layer_name, param_name]
+                suffix = ".".join(suffix_parts)
+                new_key = prefix + suffix
+                state_dict[new_key] = state_dict.pop(key)
+                return
+    return
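
The fused-KV split above relies on the original checkpoint stacking the K and V projections along the output dimension, so chunking dim 0 recovers the two separate projections. A minimal standalone illustration:

import torch

dim = 8
kv_weight = torch.randn(2 * dim, dim)        # fused linear1_kv.weight
k_w, v_w = torch.chunk(kv_weight, 2, dim=0)  # each (dim, dim): the to_k / to_v weights
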
 
 
 TRANSFORMER_SPECIAL_KEYS_REMAP = {}
 VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
-ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {"condition_embedder.motion_embedder": convert_animate_motion_encoder_weights}
+ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
+    "motion_encoder": convert_animate_motion_encoder_weights,
+    "face_adapter": convert_animate_face_adapter_weights,
+}
 
 
 def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
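
For reference, these handlers are dispatched per checkpoint key. The loop below is a sketch written to match the handler_fn_inplace call visible in the convert_transformer hunk further down, not a verbatim excerpt from the script:

for key in list(original_state_dict.keys()):
    for special_key, handler_fn_inplace in ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP.items():
        if special_key not in key:
            continue
        handler_fn_inplace(key, original_state_dict)
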
@@ -580,7 +590,14 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
             "qk_norm": "rms_norm_across_heads",
             "text_dim": 4096,
             "rope_max_seq_len": 1024,
-            "pos_embed_seq_len": 257 * 2,
+            "pos_embed_seq_len": None,
+            "motion_encoder_size": 512,  # Start of Wan Animate-specific configs
+            "motion_style_dim": 512,
+            "motion_dim": 20,
+            "motion_encoder_dim": 512,
+            "face_encoder_hidden_dim": 1024,
+            "face_encoder_num_heads": 4,
+            "inject_face_latents_blocks": 5,
         },
     }
     RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
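
Assuming the usual diffusers ConfigMixin flow, these entries become constructor kwargs of the Animate transformer and are retrievable from its config. A hedged sketch (whether the script still calls from_config after this commit is not shown in this diff):

transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
print(transformer.config.motion_dim)                   # 20
print(transformer.config.inject_face_latents_blocks)   # 5
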
@@ -620,18 +637,6 @@ def convert_transformer(model_type: str, stage: str = None):
                 continue
             handler_fn_inplace(key, original_state_dict)
 
-    # For Animate model, add blur_conv weights from the initialized model
-    # These are procedurally generated in the diffusers ConvLayer and not present in original checkpoint
-    if "Animate" in model_type:
-        # Create a temporary model on CPU to get the blur_conv weights
-        with torch.device("cpu"):
-            temp_transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
-            temp_model_state = temp_transformer.state_dict()
-        for key in temp_model_state.keys():
-            if "blur_conv.weight" in key and "motion_embedder" in key:
-                original_state_dict[key] = temp_model_state[key]
-        del temp_transformer
-
     # Load state dict into the meta model, which will materialize the tensors
     transformer.load_state_dict(original_state_dict, strict=True, assign=True)
 
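
The surviving context line relies on PyTorch's meta-device loading pattern: the model skeleton is built without allocating real memory, then load_state_dict(..., assign=True) swaps the checkpoint tensors in for the meta tensors. A minimal standalone illustration (nn.Linear standing in for the transformer):

import torch
from torch import nn

with torch.device("meta"):
    layer = nn.Linear(4, 4)  # parameters live on the meta device, no real storage

state_dict = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}
# assign=True (PyTorch >= 2.1) assigns the given tensors instead of copying
# into the (unmaterialized) meta parameters.
layer.load_state_dict(state_dict, strict=True, assign=True)
assert layer.weight.device.type == "cpu"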