@@ -40,7 +40,7 @@
 
 
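+# Keys are strings rather than ints, presumably so the mapping round-trips through JSON config
+# serialization (JSON object keys are always strings).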
 WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES = {
-    4: 512, 8: 512, 16: 512, 32: 512, 64: 256, 128: 128, 256: 64, 512: 32, 1024: 16
+    "4": 512, "8": 512, "16": 512, "32": 512, "64": 256, "128": 128, "256": 64, "512": 32, "1024": 16
 }
 
 
@@ -162,7 +162,7 @@ def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
     def __repr__(self):
         return (
             f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
-            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
+            f' kernel_size={self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
         )
 
 
@@ -199,7 +199,10 @@ def forward(self, input: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
         return out
 
     def __repr__(self):
-        return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
+        return (
+            f'{self.__class__.__name__}(in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]},'
+            f' bias={self.bias is not None})'
+        )
 
 
 class MotionEncoderResBlock(nn.Module):
@@ -266,21 +269,22 @@ def __init__(
         motion_dim: int = 20,
         out_dim: int = 512,
         motion_blocks: int = 5,
-        channels: Optional[Dict[int, int]] = None,
+        channels: Optional[Dict[str, int]] = None,
     ):
         super().__init__()
+        self.size = size
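+        # Stored so forward() can validate the incoming face image resolution against the encoder's expected size.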
 
         # Appearance encoder: conv layers
         if channels is None:
             channels = WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES
 
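+        # Lookups go through str() because the channel table is keyed by strings (see
+        # WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES above).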
-        self.conv_in = MotionConv2d(3, channels[size], 1, use_activation=True)
+        self.conv_in = MotionConv2d(3, channels[str(size)], 1, use_activation=True)
 
         self.res_blocks = nn.ModuleList()
-        in_channels = channels[size]
+        in_channels = channels[str(size)]
         log_size = int(math.log(size, 2))
         for i in range(log_size, 2, -1):
-            out_channels = channels[2 ** (i - 1)]
+            out_channels = channels[str(2 ** (i - 1))]
             self.res_blocks.append(MotionEncoderResBlock(in_channels, out_channels))
             in_channels = out_channels
 
@@ -296,6 +300,12 @@ def __init__(
         self.motion_synthesis_weight = nn.Parameter(torch.randn(out_dim, motion_dim))
 
     def forward(self, face_image: torch.Tensor, channel_dim: int = 1, upcast_to_fp32: bool = True) -> torch.Tensor:
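+        # The conv encoder is built for a fixed square input, so mismatched resolutions are rejected up front.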
+        if (face_image.shape[-2] != self.size) or (face_image.shape[-1] != self.size):
+            raise ValueError(
+                f"Face pixel values have resolution ({face_image.shape[-1]}, {face_image.shape[-2]}) but are"
+                f" expected to have resolution ({self.size}, {self.size})"
+            )
+
         # Appearance encoding through convs
         face_image = self.conv_in(face_image, channel_dim)
         for block in self.res_blocks:
@@ -314,7 +324,7 @@ def forward(self, face_image: torch.Tensor, channel_dim: int = 1, upcast_to_fp32
             motion_feat = motion_feat.to(torch.float32)
             weight = weight.to(torch.float32)
 
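+        # torch.linalg.qr returns (Q, R); Q is an orthonormal basis for the learned motion directions.
+        # The explicit device move guards against weight sitting on a different device than motion_feat,
+        # which can happen under model parallelism.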
-        Q = torch.linalg.qr(weight)[0]
+        Q = torch.linalg.qr(weight)[0].to(device=motion_feat.device)
 
         motion_feat_diag = torch.diag_embed(motion_feat)  # Alpha, diagonal matrix
         motion_decomposition = torch.matmul(motion_feat_diag, Q.T)
@@ -384,7 +394,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.out_proj(x)
         x = x.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3)  # [B * N, T, C_out] --> [B, T, N, C_out]
 
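+        # padding_tokens is a parameter and may live on a different device than x when the model is sharded;
+        # the repeated copy is moved onto x's device before concatenation.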
-        padding = self.padding_tokens.repeat(batch_size, x.shape[1], 1, 1)
+        padding = self.padding_tokens.repeat(batch_size, x.shape[1], 1, 1).to(device=x.device)
         x = torch.cat([x, padding], dim=-2)  # [B, T, N, C_out] --> [B, T, N + 1, C_out]
 
         return x
@@ -552,7 +562,7 @@ class WanAnimateTransformer3DModel(
 
     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
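+    # Module classes that accelerate's device_map must never split across devices; presumably updated here to
+    # match the block classes this model actually instantiates.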
-    _no_split_modules = ["WanAnimateTransformerBlock"]
+    _no_split_modules = ["WanTransformerBlock", "MotionEncoderResBlock"]
     _keep_in_fp32_modules = [
         "time_embedder",
         "scale_shift_table",
@@ -583,7 +593,7 @@ def __init__(
         added_kv_proj_dim: Optional[int] = None,
         rope_max_seq_len: int = 1024,
         pos_embed_seq_len: Optional[int] = None,
-        motion_encoder_channel_sizes: Optional[Dict[int, int]] = None,  # Start of Wan Animate-specific args
+        motion_encoder_channel_sizes: Optional[Dict[str, int]] = None,  # Start of Wan Animate-specific args
         motion_encoder_size: int = 512,
         motion_style_dim: int = 512,
         motion_dim: int = 20,
@@ -822,6 +832,9 @@ def forward(
             if block_idx % self.config.inject_face_latents_blocks == 0:
                 face_adapter_block_idx = block_idx // self.config.inject_face_latents_blocks
                 face_adapter_output = self.face_adapter[face_adapter_block_idx](hidden_states, motion_vec)
+                # In case the face adapter and main transformer blocks are on different devices, which can
+                # happen when using model parallelism
+                face_adapter_output = face_adapter_output.to(device=hidden_states.device)
                 hidden_states = face_adapter_output + hidden_states
 
         # 6. Output norm, projection & unpatchify
@@ -834,14 +847,16 @@ def forward(
         # batch_size, inner_dim
         shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
 
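+        # The output norm runs in fp32 for numerical stability; the original dtype is restored after the
+        # scale/shift modulation below.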
+        hidden_states_original_dtype = hidden_states.dtype
+        hidden_states = self.norm_out(hidden_states.float())
         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the
         # first device rather than the last device, which hidden_states ends up
         # on.
         shift = shift.to(hidden_states.device)
         scale = scale.to(hidden_states.device)
+        hidden_states = (hidden_states * (1 + scale) + shift).to(dtype=hidden_states_original_dtype)
 
-        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
         hidden_states = self.proj_out(hidden_states)
 
         hidden_states = hidden_states.reshape(