157157 "attn2.to_k_img" : "attn2.add_k_proj" ,
158158 "attn2.to_v_img" : "attn2.add_v_proj" ,
159159 "attn2.norm_k_img" : "attn2.norm_added_k" ,
160+ # Wan Animate-specific mappings (motion encoder, face encoder, face adapter)
160161 # Motion encoder mappings
161- "motion_encoder.enc.net_app.convs" : "condition_embedder.motion_embedder.convs" ,
162- "motion_encoder.enc.fc" : "condition_embedder.motion_embedder.linears " ,
163- "motion_encoder.dec.direction.weight" : "condition_embedder.motion_embedder .motion_synthesis_weight" ,
162+ # The name mapping is complicated for the convolutional part so we handle that in its own function
163+ "motion_encoder.enc.fc" : "motion_encoder.motion_network " ,
164+ "motion_encoder.dec.direction.weight" : "motion_encoder .motion_synthesis_weight" ,
164165 # Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten
165- "face_encoder.conv1_local.conv" : "condition_embedder.face_embedder.conv1_local" ,
166- "face_encoder.conv2.conv" : "condition_embedder.face_embedder.conv2" ,
167- "face_encoder.conv3.conv" : "condition_embedder.face_embedder.conv3" ,
168- "face_encoder.out_proj" : "condition_embedder.face_embedder.out_proj" ,
169- "face_encoder.norm1" : "condition_embedder.face_embedder.norm1" ,
170- # Return to the original order for face_embedder norms
171- "face_encoder.norm2" : "face_embedder_norm__placeholder" ,
172- "face_encoder.norm3" : "condition_embedder.face_embedder.norm2" ,
173- "face_embedder_norm__placeholder" : "condition_embedder.face_embedder.norm3" ,
174- "face_encoder.padding_tokens" : "condition_embedder.face_embedder.padding_tokens" ,
175- # Face adapter mappings
176- "face_adapter.fuser_blocks" : "face_adapter" ,
166+ "face_encoder.conv1_local.conv" : "face_encoder.conv1_local" ,
167+ "face_encoder.conv2.conv" : "face_encoder.conv2" ,
168+ "face_encoder.conv3.conv" : "face_encoder.conv3" ,
169+ # Face adapter mappings are handled in a separate function
177170}
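The rename dict above is applied by substring replacement over every checkpoint key. A minimal sketch of that application loop, assuming the usual diffusers conversion pattern and the `update_state_dict_` helper defined later in this file:

for key in list(original_state_dict.keys()):
    new_key = key
    for replace_key, rename_key in ANIMATE_TRANSFORMER_KEYS_RENAME_DICT.items():
        new_key = new_key.replace(replace_key, rename_key)
    update_state_dict_(original_state_dict, key, new_key)  # moves the tensor from key to new_key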
 
 
-def convert_equal_linear_weight(key: str, state_dict: Dict[str, Any]) -> None:
-    """
-    Convert EqualLinear weights to standard Linear weights by applying the scale factor.
-    EqualLinear uses: F.linear(input, self.weight * self.scale, bias=self.bias)
-    where scale = (1 / sqrt(in_dim))
-    """
-    if ".weight" not in key:
-        return
-
-    in_dim = state_dict[key].shape[1]
-    scale = 1.0 / math.sqrt(in_dim)
-    state_dict[key] = state_dict[key] * scale
-
-
-def convert_equal_conv2d_weight(key: str, state_dict: Dict[str, Any]) -> None:
-    """
-    Convert EqualConv2d weights to standard Conv2d weights by applying the scale factor.
-    EqualConv2d uses: F.conv2d(input, self.weight * self.scale, bias=self.bias, ...)
-    where scale = 1 / sqrt(in_channel * kernel_size^2)
-    """
-    if ".weight" not in key or len(state_dict[key].shape) != 4:
-        return
-
-    out_channel, in_channel, kernel_size, kernel_size = state_dict[key].shape
-    scale = 1.0 / math.sqrt(in_channel * kernel_size**2)
-    state_dict[key] = state_dict[key] * scale
-
-
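For context on the two deleted helpers: equalized-lr layers keep an unscaled weight and multiply by a fixed scale on every forward pass, so the old conversion baked that scale into the checkpoint once. A self-contained sanity check of that equivalence (illustrative shapes, not checkpoint values):

import math

import torch
import torch.nn.functional as F

w = torch.randn(32, 16)      # EqualLinear stores an unscaled (out_dim, in_dim) weight
x = torch.randn(4, 16)
scale = 1.0 / math.sqrt(16)  # scale = 1 / sqrt(in_dim)

w_folded = w * scale                # what convert_equal_linear_weight wrote into the checkpoint
out_equal = F.linear(x, w * scale)  # EqualLinear's runtime computation
out_plain = F.linear(x, w_folded)   # a standard nn.Linear using the folded weight
assert torch.allclose(out_equal, out_plain)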
 # TODO: Verify this and simplify if possible.
-def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any]) -> None:
+def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None:
     """
     Convert all motion encoder weights for Animate model.
-    This handles both EqualLinear (in linears) and EqualConv2d (in convs).
 
     In the original model:
     - All Linear layers in fc use EqualLinear
@@ -220,89 +184,135 @@ def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any])
     Conversion strategy:
     1. Drop .kernel buffers (blur kernels)
     2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu)
-    3. Scale EqualLinear and EqualConv2d weights
     """
     # Skip if not a weight, bias, or kernel
     if ".weight" not in key and ".bias" not in key and ".kernel" not in key:
         return
 
     # Handle Blur kernel buffers from original implementation.
-    # After renaming, these appear under: condition_embedder.motion_embedder.convs.*.conv{1,2}.0.kernel
-    # Diffusers constructs blur kernels procedurally (ConvLayer.blur_conv) so we must drop these keys
-    if ".kernel" in key and "condition_embedder.motion_embedder.convs" in key:
+    # After renaming, these appear under: motion_encoder.res_blocks.*.conv{2,skip}.blur_kernel
+    # Diffusers constructs blur kernels as a non-persistent buffer so we must drop these keys
+    if ".kernel" in key and "motion_encoder" in key:
         # Remove unexpected blur kernel buffers to avoid strict load errors
         state_dict.pop(key, None)
         return
 
     # Rename Sequential indices to named components in ConvLayer and ResBlock
-    # This must happen BEFORE weight scaling because we need to rename the keys first
-    # Original: convs.X.Y.weight/bias or convs.X.conv1/conv2/skip.Y.weight/bias
-    # Target: convs.X.conv2d.weight or convs.X.conv1/conv2/skip.conv2d.weight or .bias_leaky_relu
-    if ".convs." in key and (".weight" in key or ".bias" in key):
+    if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key):
         parts = key.split(".")
 
         # Find the sequential index (digit) after convs or after conv1/conv2/skip
         # Examples:
-        # - convs.0.0.weight -> convs.0.conv2d.weight (ConvLayer, no blur)
-        # - convs.0.1.weight -> convs.0.conv2d.weight (ConvLayer, with blur at index 0)
-        # - convs.0.1.bias -> convs.0.bias_leaky_relu (FusedLeakyReLU)
-        # - convs.1.conv1.1.weight -> convs.1.conv1.conv2d.weight (ResBlock ConvLayer)
-        # - convs.1.conv1.2.bias -> convs.1.conv1.bias_leaky_relu (ResBlock FusedLeakyReLU)
-        # - convs.8.weight -> unchanged (final Conv2d, not in Sequential)
-
-        # Check if we have a digit as second-to-last part before .weight or .bias
-        # But we need to distinguish between Sequential indices (convs.X.Y.weight)
-        # and ModuleList indices (convs.X.weight)
-        # We only rename if there are at least 3 parts after finding 'convs'
+        # - enc.net_app.convs.0.0.weight -> conv_in.weight (initial conv layer weight)
+        # - enc.net_app.convs.0.1.bias -> conv_in.act_fn.bias (initial conv layer bias)
+        # - enc.net_app.convs.{n:1-7}.conv1.0.weight -> res_blocks.{(n-1):0-6}.conv1.weight (conv1 weight)
+        #   - e.g. enc.net_app.convs.1.conv1.0.weight -> res_blocks.0.conv1.weight
+        # - enc.net_app.convs.{n:1-7}.conv1.1.bias -> res_blocks.{(n-1):0-6}.conv1.act_fn.bias (conv1 bias)
+        #   - e.g. enc.net_app.convs.1.conv1.1.bias -> res_blocks.0.conv1.act_fn.bias
+        # - enc.net_app.convs.{n:1-7}.conv2.1.weight -> res_blocks.{(n-1):0-6}.conv2.weight (conv2 weight)
+        # - enc.net_app.convs.1.conv2.2.bias -> res_blocks.0.conv2.act_fn.bias (conv2 bias)
+        # - enc.net_app.convs.{n:1-7}.skip.1.weight -> res_blocks.{(n-1):0-6}.conv_skip.weight (skip conv weight)
+        # - enc.net_app.convs.8 -> conv_out (final conv layer)
+
         convs_idx = parts.index("convs") if "convs" in parts else -1
-        if (
-            convs_idx >= 0 and len(parts) - convs_idx > 3
-        ):  # e.g., ['convs', '0', '0', 'weight'] has 4 parts after convs
-            if len(parts) >= 2 and parts[-2].isdigit():
+        if convs_idx >= 0 and len(parts) - convs_idx >= 2:
+            bias = False
+            # The nn.Sequential index will always follow convs
+            sequential_idx = int(parts[convs_idx + 1])
+            if sequential_idx == 0:
                 if key.endswith(".weight"):
-                    # Replace digit index with 'conv2d' for EqualConv2d weight parameters
-                    parts[-2] = "conv2d"
-                    new_key = ".".join(parts)
-                    state_dict[new_key] = state_dict.pop(key)
-                    # Update key for subsequent processing
-                    key = new_key
+                    new_key = "motion_encoder.conv_in.weight"
                 elif key.endswith(".bias"):
-                    # Replace digit index + .bias with 'bias_leaky_relu' for FusedLeakyReLU bias
-                    new_key = ".".join(parts[:-2]) + ".bias_leaky_relu"
-                    state_dict[new_key] = state_dict.pop(key)
-                    # Bias doesn't need scaling, we're done
-                    return
-
-        # Skip blur_conv weights that are already initialized in diffusers
-        if "blur_conv.weight" in key:
-            return
+                    new_key = "motion_encoder.conv_in.act_fn.bias"
+                    bias = True
+            elif sequential_idx == final_conv_idx:
+                if key.endswith(".weight"):
+                    new_key = "motion_encoder.conv_out.weight"
+            else:
+                # Intermediate .convs. layers, which get mapped to .res_blocks.
+                prefix = "motion_encoder.res_blocks."
 
-        # Skip bias_leaky_relu as it doesn't need any transformation
-        if "bias_leaky_relu" in key:
+                layer_name = parts[convs_idx + 2]
+                if layer_name == "skip":
+                    layer_name = "conv_skip"
+
+                if key.endswith(".weight"):
+                    param_name = "weight"
+                elif key.endswith(".bias"):
+                    param_name = "act_fn.bias"
+                    bias = True
+
+                suffix_parts = [str(sequential_idx - 1), layer_name, param_name]
+                suffix = ".".join(suffix_parts)
+                new_key = prefix + suffix
+
+            param = state_dict.pop(key)
+            if bias:
+                param = param.squeeze()
+            state_dict[new_key] = param
+            return
         return
+    return
+
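A small end-to-end sketch of the renaming above, using hypothetical keys and illustrative tensor shapes (the real checkpoint's shapes may differ):

import torch

sd = {
    "motion_encoder.enc.net_app.convs.0.0.weight": torch.randn(64, 3, 1, 1),
    "motion_encoder.enc.net_app.convs.0.1.bias": torch.randn(1, 64, 1, 1),       # assumed 4D bias
    "motion_encoder.enc.net_app.convs.1.conv1.1.bias": torch.randn(1, 128, 1, 1),
}
for key in list(sd.keys()):
    convert_animate_motion_encoder_weights(key, sd)

# sd now holds:
#   motion_encoder.conv_in.weight                   (shape unchanged)
#   motion_encoder.conv_in.act_fn.bias              (squeezed to (64,))
#   motion_encoder.res_blocks.0.conv1.act_fn.bias   (squeezed to (128,))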
 
-    # Scale EqualLinear weights in linear layers
-    if ".linears." in key and ".weight" in key:
-        convert_equal_linear_weight(key, state_dict)
+def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None:
+    """
+    Convert face adapter weights for the Animate model.
+
+    The original model uses a fused KV projection but the diffusers model uses separate K and V projections.
+    """
+    # Skip if not a weight or bias
+    if ".weight" not in key and ".bias" not in key:
         return
 
-    # Scale EqualConv2d weights in convolution layers
-    if ".convs." in key and ".weight" in key:
-        # Two cases:
-        # 1. ConvLayer with EqualConv2d: convs.<i>.conv2d.weight (after renaming)
-        # 2. Direct EqualConv2d (last conv): convs.<i>.weight (where <i> is a single digit)
-        if ".conv2d.weight" in key:
-            convert_equal_conv2d_weight(key, state_dict)
-            return
-        elif key.split(".")[-2].isdigit() and key.endswith(".weight"):
-            # This handles keys like "convs.7.weight" where the second-to-last part is a digit
-            convert_equal_conv2d_weight(key, state_dict)
-            return
+    prefix = "face_adapter."
+    if ".fuser_blocks." in key:
+        parts = key.split(".")
+
+        module_list_idx = parts.index("fuser_blocks") if "fuser_blocks" in parts else -1
+        if module_list_idx >= 0 and (len(parts) - 1) - module_list_idx == 3:
+            block_idx = parts[module_list_idx + 1]
+            layer_name = parts[module_list_idx + 2]
+            param_name = parts[module_list_idx + 3]
+
+            if layer_name == "linear1_kv":
+                layer_name_k = "to_k"
+                layer_name_v = "to_v"
+
+                suffix_k = ".".join([block_idx, layer_name_k, param_name])
+                suffix_v = ".".join([block_idx, layer_name_v, param_name])
+                new_key_k = prefix + suffix_k
+                new_key_v = prefix + suffix_v
+
+                kv_proj = state_dict.pop(key)
+                k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0)
+                state_dict[new_key_k] = k_proj
+                state_dict[new_key_v] = v_proj
+                return
+            else:
+                if layer_name == "q_norm":
+                    new_layer_name = "norm_q"
+                elif layer_name == "k_norm":
+                    new_layer_name = "norm_k"
+                elif layer_name == "linear1_q":
+                    new_layer_name = "to_q"
+                elif layer_name == "linear2":
+                    new_layer_name = "to_out"
+
+                suffix_parts = [block_idx, new_layer_name, param_name]
+                suffix = ".".join(suffix_parts)
+                new_key = prefix + suffix
+                state_dict[new_key] = state_dict.pop(key)
+                return
+        return
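The split relies on the original checkpoint stacking the K rows before the V rows along dim 0 of the fused projection. A toy run with hypothetical dimensions:

import torch

sd = {"face_adapter.fuser_blocks.0.linear1_kv.weight": torch.randn(2 * 1024, 1024)}
convert_animate_face_adapter_weights("face_adapter.fuser_blocks.0.linear1_kv.weight", sd)

assert sd["face_adapter.0.to_k.weight"].shape == (1024, 1024)  # first half of the rows
assert sd["face_adapter.0.to_v.weight"].shape == (1024, 1024)  # second half of the rows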
 
 
 TRANSFORMER_SPECIAL_KEYS_REMAP = {}
 VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
-ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {"condition_embedder.motion_embedder": convert_animate_motion_encoder_weights}
+ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
+    "motion_encoder": convert_animate_motion_encoder_weights,
+    "face_adapter": convert_animate_face_adapter_weights,
+}
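These tables are consumed by a substring-dispatch loop inside convert_transformer; the `continue` / `handler_fn_inplace` context lines in the last hunk below are part of it. Roughly, with SPECIAL_KEYS_REMAP standing for whichever table was selected for the model type:

for key in list(original_state_dict.keys()):
    for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
        if special_key not in key:
            continue
        handler_fn_inplace(key, original_state_dict)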
 
 
 def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
@@ -580,7 +590,14 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
                 "qk_norm": "rms_norm_across_heads",
                 "text_dim": 4096,
                 "rope_max_seq_len": 1024,
-                "pos_embed_seq_len": 257 * 2,
+                "pos_embed_seq_len": None,
+                "motion_encoder_size": 512,  # Start of Wan Animate-specific configs
+                "motion_style_dim": 512,
+                "motion_dim": 20,
+                "motion_encoder_dim": 512,
+                "face_encoder_hidden_dim": 1024,
+                "face_encoder_num_heads": 4,
+                "inject_face_latents_blocks": 5,
             },
         }
         RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
@@ -620,18 +637,6 @@ def convert_transformer(model_type: str, stage: str = None):
                 continue
             handler_fn_inplace(key, original_state_dict)
 
-    # For Animate model, add blur_conv weights from the initialized model
-    # These are procedurally generated in the diffusers ConvLayer and not present in original checkpoint
-    if "Animate" in model_type:
-        # Create a temporary model on CPU to get the blur_conv weights
-        with torch.device("cpu"):
-            temp_transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
-            temp_model_state = temp_transformer.state_dict()
-            for key in temp_model_state.keys():
-                if "blur_conv.weight" in key and "motion_embedder" in key:
-                    original_state_dict[key] = temp_model_state[key]
-            del temp_transformer
-
     # Load state dict into the meta model, which will materialize the tensors
     transformer.load_state_dict(original_state_dict, strict=True, assign=True)
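For reference, the loading pattern this hunk feeds into, per the context lines above: the transformer is built on the meta device, and `assign=True` adopts the checkpoint tensors directly rather than copying into (nonexistent) meta storage. A condensed sketch assuming the script's `diffusers_config` and `original_state_dict` variables:

with torch.device("meta"):
    transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
# assign=True is required here: meta parameters have no storage to copy into,
# so the checkpoint tensors themselves become the module's parameters
transformer.load_state_dict(original_state_dict, strict=True, assign=True)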