
Commit a56bee1

Get Wan Animate pipeline fp16 inference tests working
1 parent 1e1e706 commit a56bee1
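
What the commit does: the Wan Animate latent-preparation helpers (get_i2v_mask, prepare_reference_image_latents, prepare_prev_segment_cond_latents) created masks and zero tensors without an explicit dtype, so they came out float32 even when the pipeline's weights were loaded in fp16. The diff threads an optional dtype argument through these helpers, defaulting to self.vae.dtype, and casts the face and pose conditioning tensors to the transformer's dtype in __call__. A minimal illustration of the failure mode the casts avoid (not pipeline code, just the generic PyTorch behavior):

    import torch

    # Feeding a float32 tensor to fp16 weights raises a dtype-mismatch error;
    # freshly created float32 masks/latents trigger exactly this in an fp16 run.
    layer = torch.nn.Linear(8, 8).half()
    x = torch.randn(2, 8)  # float32 by default
    try:
        layer(x)
    except RuntimeError as err:
        print(err)  # e.g. "mat1 and mat2 must have the same dtype"
    layer(x.half())  # fine once input and weight dtypes agree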

File tree: 1 file changed (+17, −8)

src/diffusers/pipelines/wan/pipeline_wan_animate.py

Lines changed: 17 additions & 8 deletions
@@ -501,12 +501,15 @@ def get_i2v_mask(
         latent_w: int,
         mask_len: int = 1,
         mask_pixel_values: Optional[torch.Tensor] = None,
+        dtype: Optional[torch.dtype] = None,
         device: Union[str, torch.device] = "cuda",
     ) -> torch.Tensor:
         if mask_pixel_values is None:
-            mask_lat_size = torch.zeros(batch_size, 1, (latent_t - 1) * 4 + 1, latent_h, latent_w, device=device)
+            mask_lat_size = torch.zeros(
+                batch_size, 1, (latent_t - 1) * 4 + 1, latent_h, latent_w, dtype=dtype, device=device
+            )
         else:
-            mask_lat_size = mask_pixel_values.clone()
+            mask_lat_size = mask_pixel_values.clone().to(device=device, dtype=dtype)
         mask_lat_size[:, :, :mask_len] = 1
         first_frame_mask = mask_lat_size[:, :, 0:1]
         first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
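
Why the new dtype= on torch.zeros matters: without it, torch.zeros follows the global default dtype (float32), so the mask would stay float32 no matter what precision the pipeline runs in. A standalone sketch of that behavior:

    import torch

    mask = torch.zeros(1, 1, 5, 8, 8)  # no dtype given -> global default
    print(mask.dtype)  # torch.float32

    mask_fp16 = torch.zeros(1, 1, 5, 8, 8, dtype=torch.float16)
    print(mask_fp16.dtype)  # torch.float16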
@@ -527,6 +530,7 @@ def prepare_reference_image_latents(
         device: Optional[torch.device] = None,
     ) -> torch.Tensor:
         # image shape: (B, C, H, W) or (B, C, T, H, W)
+        dtype = dtype or self.vae.dtype
         if image.ndim == 4:
             # Add a singleton frame dimension after the channels dimension
             image = image.unsqueeze(2)
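
The `dtype = dtype or self.vae.dtype` fallback works because torch.dtype objects are always truthy (they define no __bool__), so only an explicit None falls through to the VAE's dtype. A tiny sketch of the idiom with a hypothetical helper:

    import torch

    def pick_dtype(dtype=None, fallback=torch.float32):
        # any real torch.dtype is truthy; only None takes the fallback
        return dtype or fallback

    print(pick_dtype())               # torch.float32
    print(pick_dtype(torch.float16))  # torch.float16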
@@ -536,7 +540,7 @@ def prepare_reference_image_latents(
         latent_width = width // self.vae_scale_factor_spatial
 
         # Encode image to latents using VAE
-        image = image.to(device=device, dtype=dtype if dtype is not None else self.vae.dtype)
+        image = image.to(device=device, dtype=dtype)
         if isinstance(generator, list):
             # Like in prepare_latents, assume len(generator) == batch_size
             ref_image_latents = [
@@ -552,7 +556,7 @@ def prepare_reference_image_latents(
         ref_image_latents = ref_image_latents.expand(batch_size, -1, -1, -1, -1)
 
         # Prepare I2V mask in latent space and prepend to the reference image latents along channel dim
-        reference_image_mask = self.get_i2v_mask(batch_size, 1, latent_height, latent_width, 1, None, device)
+        reference_image_mask = self.get_i2v_mask(batch_size, 1, latent_height, latent_width, 1, None, dtype, device)
         reference_image_latents = torch.cat([reference_image_mask, ref_image_latents], dim=1)
 
         return reference_image_latents
@@ -575,12 +579,13 @@ def prepare_prev_segment_cond_latents(
         device: Optional[torch.device] = None,
     ) -> torch.Tensor:
         # prev_segment_cond_video shape: (B, C, T, H, W) in pixel space if supplied
+        dtype = dtype or self.vae.dtype
         if prev_segment_cond_video is None:
             if task == "replace":
-                prev_segment_cond_video = background_video[:, :, :prev_segment_cond_frames]
+                prev_segment_cond_video = background_video[:, :, :prev_segment_cond_frames].to(dtype)
             else:
                 cond_frames_shape = (batch_size, 3, prev_segment_cond_frames, height, width)  # In pixel space
-                prev_segment_cond_video = torch.zeros(cond_frames_shape, device=device)
+                prev_segment_cond_video = torch.zeros(cond_frames_shape, dtype=dtype, device=device)
 
         data_batch_size, channels, _, segment_height, segment_width = prev_segment_cond_video.shape
         num_latent_frames = (segment_frame_length - 1) // self.vae_scale_factor_temporal + 1
@@ -596,14 +601,15 @@ def prepare_prev_segment_cond_latents(
         # replacing).
         # TODO: check shapes here
         if task == "replace":
-            remaining_segment = background_video[:, :, prev_segment_cond_frames:]
+            remaining_segment = background_video[:, :, prev_segment_cond_frames:].to(dtype)
         else:
             remaining_segment_frames = segment_frame_length - prev_segment_cond_frames
             remaining_segment = torch.zeros(
-                batch_size, channels, remaining_segment_frames, height, width, device=device
+                batch_size, channels, remaining_segment_frames, height, width, dtype=dtype, device=device
             )
 
         # Prepend the conditioning frames from the previous segment to the remaining segment video in the frame dim
+        prev_segment_cond_video = prev_segment_cond_video.to(dtype=dtype)
         full_segment_cond_video = torch.cat([prev_segment_cond_video, remaining_segment], dim=2)
 
         if isinstance(generator, list):
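
Casting prev_segment_cond_video before the torch.cat matters because, on recent PyTorch versions, cat silently promotes mixed float16/float32 inputs to float32 instead of erroring, so the mismatch would only surface later at the VAE encode. Illustration (not pipeline code):

    import torch

    a = torch.zeros(1, 3, 2, 4, 4)  # float32 by default
    b = torch.zeros(1, 3, 2, 4, 4, dtype=torch.float16)
    print(torch.cat([a, b], dim=2).dtype)  # torch.float32 -- silent upcast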
@@ -643,6 +649,7 @@ def prepare_prev_segment_cond_latents(
             latent_width,
             mask_len=prev_segment_cond_frames,
             mask_pixel_values=mask_pixel_values,
+            dtype=dtype,
             device=device,
         )
 
@@ -1031,6 +1038,7 @@ def __call__(
                 face_video_segment = face_video[start:end]
 
                 face_video_segment = face_video_segment.expand(batch_size * num_videos_per_prompt, -1, -1, -1, -1)
+                face_video_segment = face_video_segment.to(dtype=transformer_dtype)
 
                 if start > 0:
                     # TODO: check shapes here, why do we take index 0 in the first dim.?
@@ -1053,6 +1061,7 @@ def __call__(
                 pose_latents = self.prepare_pose_latents(
                     pose_video_segment, batch_size * num_videos_per_prompt, generator=generator, device=device
                 )
+                pose_latents = pose_latents.to(dtype=transformer_dtype)
 
                 prev_segment_cond_latents = self.prepare_prev_segment_cond_latents(
                     prev_segment_cond_video,
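
With the conditioning tensors cast to transformer_dtype, an fp16 run should no longer trip dtype checks inside the transformer. A hedged sketch of the kind of inference the fixed tests exercise (the class name is assumed from the file path, and the checkpoint id is a placeholder, not from this commit):

    import torch
    from diffusers import WanAnimatePipeline  # assumed export for pipeline_wan_animate.py

    # Placeholder checkpoint id; fp16 weights requested via torch_dtype.
    pipe = WanAnimatePipeline.from_pretrained("<wan-animate-checkpoint>", torch_dtype=torch.float16)
    pipe.to("cuda")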
