
Commit ac2962d (parent 275d324)

Fix issues (such as device placement issues) to get remaining transformer tests passing

2 files changed: +36 −16 lines

src/diffusers/models/transformers/transformer_wan_animate.py (27 additions, 12 deletions)
```diff
@@ -40,7 +40,7 @@


 WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES = {
-    4: 512, 8: 512, 16: 512, 32: 512, 64: 256, 128: 128, 256: 64, 512: 32, 1024: 16
+    "4": 512, "8": 512, "16": 512, "32": 512, "64": 256, "128": 128, "256": 64, "512": 32, "1024": 16
 }


```
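A likely motivation for the string keys (inferred, not stated in the diff): diffusers model configs are persisted as JSON (`config.json`), and JSON object keys are always strings, so an int-keyed dict would not survive a save/load round trip. A minimal sketch of the failure mode:

```python
import json

# int-keyed dicts silently come back with string keys after a JSON round trip
channels = {4: 512, 8: 512, 16: 512}
restored = json.loads(json.dumps(channels))
print(restored)  # {'4': 512, '8': 512, '16': 512}
assert 4 not in restored and "4" in restored
```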

```diff
@@ -162,7 +162,7 @@ def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
     def __repr__(self):
         return (
             f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
-            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
+            f' kernel_size={self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
         )


```
```diff
@@ -199,7 +199,10 @@ def forward(self, input: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
         return out

     def __repr__(self):
-        return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
+        return (
+            f'{self.__class__.__name__}(in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]},'
+            f' bias={self.bias is not None})'
+        )


 class MotionEncoderResBlock(nn.Module):
```
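The new `__repr__` format appears to mirror how torch.nn's own modules print themselves, which makes the motion-encoder layers self-describing in model printouts:

```python
import torch.nn as nn

# nn.Linear's built-in repr, which the new format matches
print(nn.Linear(512, 20))
# Linear(in_features=512, out_features=20, bias=True)
```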
```diff
@@ -266,21 +269,22 @@ def __init__(
         motion_dim: int = 20,
         out_dim: int = 512,
         motion_blocks: int = 5,
-        channels: Optional[Dict[int, int]] = None,
+        channels: Optional[Dict[str, int]] = None,
     ):
         super().__init__()
+        self.size = size

         # Appearance encoder: conv layers
         if channels is None:
             channels = WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES

-        self.conv_in = MotionConv2d(3, channels[size], 1, use_activation=True)
+        self.conv_in = MotionConv2d(3, channels[str(size)], 1, use_activation=True)

         self.res_blocks = nn.ModuleList()
-        in_channels = channels[size]
+        in_channels = channels[str(size)]
         log_size = int(math.log(size, 2))
         for i in range(log_size, 2, -1):
-            out_channels = channels[2 ** (i - 1)]
+            out_channels = channels[str(2 ** (i - 1))]
             self.res_blocks.append(MotionEncoderResBlock(in_channels, out_channels))
             in_channels = out_channels

```
```diff
@@ -296,6 +300,12 @@ def __init__(
         self.motion_synthesis_weight = nn.Parameter(torch.randn(out_dim, motion_dim))

     def forward(self, face_image: torch.Tensor, channel_dim: int = 1, upcast_to_fp32: bool = True) -> torch.Tensor:
+        if (face_image.shape[-2] != self.size) or (face_image.shape[-1] != self.size):
+            raise ValueError(
+                f"Face pixel values has resolution ({face_image.shape[-1]}, {face_image.shape[-2]}) but is expected"
+                f" to have resolution ({self.size}, {self.size})"
+            )
+
         # Appearance encoding through convs
         face_image = self.conv_in(face_image, channel_dim)
         for block in self.res_blocks:
```
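A standalone sketch of the new guard, with the surrounding encoder class elided (the helper name here is made up for illustration):

```python
import torch

def check_face_resolution(face_image: torch.Tensor, size: int) -> None:
    # Mirrors the new check: both spatial dims must equal the encoder's size
    if (face_image.shape[-2] != size) or (face_image.shape[-1] != size):
        raise ValueError(
            f"Face pixel values has resolution ({face_image.shape[-1]},"
            f" {face_image.shape[-2]}) but is expected to have resolution"
            f" ({size}, {size})"
        )

check_face_resolution(torch.randn(1, 3, 512, 512), 512)  # passes silently
# check_face_resolution(torch.randn(1, 3, 256, 512), 512)  # raises ValueError
```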
```diff
@@ -314,7 +324,7 @@ def forward(self, face_image: torch.Tensor, channel_dim: int = 1, upcast_to_fp32
             motion_feat = motion_feat.to(torch.float32)
             weight = weight.to(torch.float32)

-        Q = torch.linalg.qr(weight)[0]
+        Q = torch.linalg.qr(weight)[0].to(device=motion_feat.device)

         motion_feat_diag = torch.diag_embed(motion_feat)  # Alpha, diagonal matrix
         motion_decomposition = torch.matmul(motion_feat_diag, Q.T)
```
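The `.to(device=...)` matters under model parallelism (e.g. an accelerate device map): `weight` may sit on a different GPU than `motion_feat`, `torch.linalg.qr` produces its output on `weight`'s device, and PyTorch raises a RuntimeError when a matmul mixes devices. A minimal, CPU-safe sketch of the fixed pattern (shapes chosen to match the defaults above):

```python
import torch

weight = torch.randn(512, 20)     # parameter; may live on another device
motion_feat = torch.randn(2, 20)  # activations

# Move the orthogonal basis onto the activations' device before mixing them
# (a no-op when both already share a device, as on CPU here).
Q = torch.linalg.qr(weight)[0].to(device=motion_feat.device)
motion_feat_diag = torch.diag_embed(motion_feat)            # (2, 20, 20)
motion_decomposition = torch.matmul(motion_feat_diag, Q.T)  # (2, 20, 512)
```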
```diff
@@ -384,7 +394,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.out_proj(x)
         x = x.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3)  # [B * N, T, C_out] --> [B, T, N, C_out]

-        padding = self.padding_tokens.repeat(batch_size, x.shape[1], 1, 1)
+        padding = self.padding_tokens.repeat(batch_size, x.shape[1], 1, 1).to(device=x.device)
         x = torch.cat([x, padding], dim=-2)  # [B, T, N, C_out] --> [B, T, N + 1, C_out]

         return x
```
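Same device-alignment pattern as the QR fix above: `padding_tokens` is a learned parameter that can land on a different device than the activations `x` under model parallelism, so it is moved to `x.device` before the concatenation.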
```diff
@@ -552,7 +562,7 @@ class WanAnimateTransformer3DModel(

     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
-    _no_split_modules = ["WanAnimateTransformerBlock"]
+    _no_split_modules = ["WanTransformerBlock", "MotionEncoderResBlock"]
     _keep_in_fp32_modules = [
         "time_embedder",
         "scale_shift_table",
```
583593
added_kv_proj_dim: Optional[int] = None,
584594
rope_max_seq_len: int = 1024,
585595
pos_embed_seq_len: Optional[int] = None,
586-
motion_encoder_channel_sizes: Optional[Dict[int, int]] = None, # Start of Wan Animate-specific args
596+
motion_encoder_channel_sizes: Optional[Dict[str, int]] = None, # Start of Wan Animate-specific args
587597
motion_encoder_size: int = 512,
588598
motion_style_dim: int = 512,
589599
motion_dim: int = 20,
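On the `_no_split_modules` change two hunks up: that attribute feeds accelerate's device-map planning, and classes listed there are kept whole on a single device rather than sharded across devices (presumably the list now names the block classes the model actually instantiates). A rough illustration with accelerate's public API; the toy `Block` model and memory figures are made up, and diffusers wires this internally rather than calling it like this:

```python
import torch.nn as nn
from accelerate import infer_auto_device_map

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, 64)

model = nn.Sequential(Block(), Block(), Block())
# Each Block is placed as an indivisible unit: its fc1/fc2 never end up on
# different devices. Real use would list GPU budgets alongside "cpu".
device_map = infer_auto_device_map(
    model,
    max_memory={"cpu": "1GB"},
    no_split_module_classes=["Block"],
)
print(device_map)
```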
```diff
@@ -822,6 +832,9 @@ def forward(
             if block_idx % self.config.inject_face_latents_blocks == 0:
                 face_adapter_block_idx = block_idx // self.config.inject_face_latents_blocks
                 face_adapter_output = self.face_adapter[face_adapter_block_idx](hidden_states, motion_vec)
+                # In case the face adapter and main transformer blocks are on different devices, which can happen
+                # when using model parallelism
+                face_adapter_output = face_adapter_output.to(device=hidden_states.device)
                 hidden_states = face_adapter_output + hidden_states

         # 6. Output norm, projection & unpatchify
```
```diff
@@ -834,14 +847,16 @@ def forward(
         # batch_size, inner_dim
         shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

+        hidden_states_original_dtype = hidden_states.dtype
+        hidden_states = self.norm_out(hidden_states.float())
         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the
         # first device rather than the last device, which hidden_states ends up
         # on.
         shift = shift.to(hidden_states.device)
         scale = scale.to(hidden_states.device)
+        hidden_states = (hidden_states * (1 + scale) + shift).to(dtype=hidden_states_original_dtype)

-        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
         hidden_states = self.proj_out(hidden_states)

         hidden_states = hidden_states.reshape(
```
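The refactor keeps the math the same but makes the dtype handling explicit: normalize in float32 (presumably for numerical stability), apply the shift/scale modulation, then cast back to the compute dtype before the output projection. A standalone sketch of the pattern, with an ordinary `nn.LayerNorm` standing in for the model's `norm_out` and arbitrary shapes:

```python
import torch
import torch.nn as nn

norm_out = nn.LayerNorm(16, elementwise_affine=False, eps=1e-6)
hidden_states = torch.randn(1, 4, 16, dtype=torch.bfloat16)
shift = torch.zeros(1, 1, 16)
scale = torch.zeros(1, 1, 16)

hidden_states_original_dtype = hidden_states.dtype
hidden_states = norm_out(hidden_states.float())  # normalize in fp32
hidden_states = (hidden_states * (1 + scale) + shift).to(
    dtype=hidden_states_original_dtype           # back to the compute dtype
)
assert hidden_states.dtype == torch.bfloat16
```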

tests/models/transformers/test_models_transformer_wan_animate.py (9 additions, 4 deletions)
```diff
@@ -47,8 +47,8 @@ def dummy_input(self):
         clip_dim = 16

         inference_segment_length = 77  # The inference segment length in the full Wan2.2-Animate-14B model
-        face_height = 8
-        face_width = 8
+        face_height = 16  # Should be square and match `motion_encoder_size` below
+        face_width = 16

         hidden_states = torch.randn((batch_size, 2 * num_channels + 4, num_frames + 1, height, width)).to(torch_device)
         timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
```
```diff
@@ -70,13 +70,17 @@ def dummy_input(self):

     @property
     def input_shape(self):
-        return (4, 1, 16, 16)
+        return (12, 1, 16, 16)

     @property
     def output_shape(self):
         return (4, 1, 16, 16)

     def prepare_init_args_and_inputs_for_common(self):
+        # Use custom channel sizes since the default Wan Animate channel sizes will cause the motion encoder to
+        # contain the vast majority of the parameters in the test model
+        channel_sizes = {"4": 16, "8": 16, "16": 16}
+
         init_dict = {
             "patch_size": (1, 2, 2),
             "num_attention_heads": 2,
```
```diff
@@ -92,7 +96,8 @@ def prepare_init_args_and_inputs_for_common(self):
             "qk_norm": "rms_norm_across_heads",
             "image_dim": 16,
             "rope_max_seq_len": 32,
-            "motion_encoder_size": 8,  # Start of Wan Animate-specific config
+            "motion_encoder_channel_sizes": channel_sizes,  # Start of Wan Animate-specific config
+            "motion_encoder_size": 16,  # Ensures that there will be 2 motion encoder resblocks
             "motion_style_dim": 8,
             "motion_dim": 4,
             "motion_encoder_dim": 16,
```

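Two bits of arithmetic behind the new test values (both inferred from the code above, not stated in the commit): with `num_channels = 4` in `dummy_input` (consistent with the 4-channel `output_shape`), the packed latents carry `2 * num_channels + 4 = 12` channels, matching the new `input_shape`; and `motion_encoder_size = 16` yields exactly two res blocks, since the encoder loop in `__init__` runs over `range(int(math.log(size, 2)), 2, -1)`:

```python
import math

# New input_shape channel count (assuming num_channels = 4 in dummy_input)
num_channels = 4
assert 2 * num_channels + 4 == 12

# motion_encoder_size = 16 -> log_size = 4 -> loop runs for i = 4, 3,
# i.e. two MotionEncoderResBlocks
log_size = int(math.log(16, 2))
assert list(range(log_size, 2, -1)) == [4, 3]
```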