Skip to content

Commit 3a80241

Browse files
committed
Fix some more Wan Animate pipeline shape errors
1 parent e2846f6 commit 3a80241

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

src/diffusers/pipelines/wan/pipeline_wan_animate.py

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -989,9 +989,17 @@ def __call__(
989989
pose_video = self.video_processor.preprocess_video(pose_video, height=height, width=width).to(
990990
device, dtype=torch.float32
991991
)
992-
face_video = self.video_processor.preprocess_video(face_video, height=height, width=width).to(
993-
device, dtype=torch.float32
994-
)
992+
993+
face_video_width, face_video_height = face_video[0].size
994+
expected_face_size = self.transformer.config.motion_encoder_size
995+
if face_video_width != expected_face_size or face_video_height != expected_face_size:
996+
logger.warning(
997+
f"Reshaping face video from ({face_video_width}, {face_video_height}) to ({expected_face_size},"
998+
f" {expected_face_size})"
999+
)
1000+
face_video = self.video_processor.preprocess_video(
1001+
face_video, height=expected_face_size, width=expected_face_size
1002+
).to(device, dtype=torch.float32)
9951003

9961004
if mode == "replace":
9971005
background_video = self.pad_video_frames(background_video, num_target_frames)
@@ -1040,8 +1048,8 @@ def __call__(
10401048
# while start + prev_segment_conditioning_frames < len(pose_video):
10411049
for _ in range(num_segments):
10421050
assert start + prev_segment_conditioning_frames < cond_video_frames
1043-
pose_video_segment = pose_video[start:end]
1044-
face_video_segment = face_video[start:end]
1051+
pose_video_segment = pose_video[:, :, start:end]
1052+
face_video_segment = face_video[:, :, start:end]
10451053

10461054
face_video_segment = face_video_segment.expand(batch_size * num_videos_per_prompt, -1, -1, -1, -1)
10471055
face_video_segment = face_video_segment.to(dtype=transformer_dtype)
@@ -1052,8 +1060,8 @@ def __call__(
10521060
prev_segment_cond_video = None
10531061

10541062
if mode == "replace":
1055-
background_video_segment = background_video[start:end]
1056-
mask_video_segment = mask_video[start:end]
1063+
background_video_segment = background_video[:, :, start:end]
1064+
mask_video_segment = mask_video[:, :, start:end]
10571065

10581066
background_video_segment = background_video_segment.expand(
10591067
batch_size * num_videos_per_prompt, -1, -1, -1, -1

0 commit comments

Comments (0)