@@ -989,9 +989,17 @@ def __call__(
989989 pose_video = self .video_processor .preprocess_video (pose_video , height = height , width = width ).to (
990990 device , dtype = torch .float32
991991 )
992- face_video = self .video_processor .preprocess_video (face_video , height = height , width = width ).to (
993- device , dtype = torch .float32
994- )
992+
993+ face_video_width , face_video_height = face_video [0 ].size
994+ expected_face_size = self .transformer .config .motion_encoder_size
995+ if face_video_width != expected_face_size or face_video_height != expected_face_size :
996+ logger .warning (
997+ f"Reshaping face video from ({ face_video_width } , { face_video_height } ) to ({ expected_face_size } ,"
998+ f" { expected_face_size } )"
999+ )
1000+ face_video = self .video_processor .preprocess_video (
1001+ face_video , height = expected_face_size , width = expected_face_size
1002+ ).to (device , dtype = torch .float32 )
9951003
9961004 if mode == "replace" :
9971005 background_video = self .pad_video_frames (background_video , num_target_frames )
@@ -1040,8 +1048,8 @@ def __call__(
10401048 # while start + prev_segment_conditioning_frames < len(pose_video):
10411049 for _ in range (num_segments ):
10421050 assert start + prev_segment_conditioning_frames < cond_video_frames
1043- pose_video_segment = pose_video [start :end ]
1044- face_video_segment = face_video [start :end ]
1051+ pose_video_segment = pose_video [:, :, start :end ]
1052+ face_video_segment = face_video [:, :, start :end ]
10451053
10461054 face_video_segment = face_video_segment .expand (batch_size * num_videos_per_prompt , - 1 , - 1 , - 1 , - 1 )
10471055 face_video_segment = face_video_segment .to (dtype = transformer_dtype )
@@ -1052,8 +1060,8 @@ def __call__(
10521060 prev_segment_cond_video = None
10531061
10541062 if mode == "replace" :
1055- background_video_segment = background_video [start :end ]
1056- mask_video_segment = mask_video [start :end ]
1063+ background_video_segment = background_video [:, :, start :end ]
1064+ mask_video_segment = mask_video [:, :, start :end ]
10571065
10581066 background_video_segment = background_video_segment .expand (
10591067 batch_size * num_videos_per_prompt , - 1 , - 1 , - 1 , - 1
0 commit comments