@@ -501,12 +501,15 @@ def get_i2v_mask(
         latent_w: int,
         mask_len: int = 1,
         mask_pixel_values: Optional[torch.Tensor] = None,
+        dtype: Optional[torch.dtype] = None,
         device: Union[str, torch.device] = "cuda",
     ) -> torch.Tensor:
         if mask_pixel_values is None:
-            mask_lat_size = torch.zeros(batch_size, 1, (latent_t - 1) * 4 + 1, latent_h, latent_w, device=device)
+            mask_lat_size = torch.zeros(
+                batch_size, 1, (latent_t - 1) * 4 + 1, latent_h, latent_w, dtype=dtype, device=device
+            )
         else:
-            mask_lat_size = mask_pixel_values.clone()
+            mask_lat_size = mask_pixel_values.clone().to(device=device, dtype=dtype)
         mask_lat_size[:, :, :mask_len] = 1
         first_frame_mask = mask_lat_size[:, :, 0:1]
         first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
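Worth noting for reviewers: the new `dtype` argument stays backward compatible on both branches, since `torch.zeros(..., dtype=None)` falls back to the default dtype and `Tensor.to(device=..., dtype=None)` keeps the source dtype. A minimal sketch of that behavior (shapes are arbitrary example values, not the pipeline's):

```python
import torch

# dtype=None -> torch.get_default_dtype() (float32 unless changed);
# an explicit dtype materializes the mask directly in the target
# precision, so no separate cast is needed downstream.
mask_default = torch.zeros(1, 1, (13 - 1) * 4 + 1, 60, 104, dtype=None)
mask_bf16 = torch.zeros(1, 1, (13 - 1) * 4 + 1, 60, 104, dtype=torch.bfloat16)

assert mask_default.dtype == torch.get_default_dtype()
assert mask_bf16.dtype == torch.bfloat16
```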
@@ -527,6 +530,7 @@ def prepare_reference_image_latents(
         device: Optional[torch.device] = None,
     ) -> torch.Tensor:
         # image shape: (B, C, H, W) or (B, C, T, H, W)
+        dtype = dtype or self.vae.dtype
         if image.ndim == 4:
             # Add a singleton frame dimension after the channels dimension
             image = image.unsqueeze(2)
@@ -536,7 +540,7 @@ def prepare_reference_image_latents(
         latent_width = width // self.vae_scale_factor_spatial

         # Encode image to latents using VAE
-        image = image.to(device=device, dtype=dtype if dtype is not None else self.vae.dtype)
+        image = image.to(device=device, dtype=dtype)
         if isinstance(generator, list):
             # Like in prepare_latents, assume len(generator) == batch_size
             ref_image_latents = [
@@ -552,7 +556,7 @@ def prepare_reference_image_latents(
         ref_image_latents = ref_image_latents.expand(batch_size, -1, -1, -1, -1)

         # Prepare I2V mask in latent space and prepend to the reference image latents along channel dim
-        reference_image_mask = self.get_i2v_mask(batch_size, 1, latent_height, latent_width, 1, None, device)
+        reference_image_mask = self.get_i2v_mask(batch_size, 1, latent_height, latent_width, 1, None, dtype, device)
         reference_image_latents = torch.cat([reference_image_mask, ref_image_latents], dim=1)

         return reference_image_latents
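The `dtype = dtype or self.vae.dtype` fallback relies on `torch.dtype` objects always being truthy, so `or` behaves as a plain None-check here. A small illustration with a hypothetical `resolve_dtype` helper (not part of this diff):

```python
import torch

def resolve_dtype(dtype, fallback):
    # torch.dtype objects are always truthy, so `x or y` is a safe
    # None-fallback here (unlike e.g. ints, where 0 is falsy).
    return dtype or fallback

assert resolve_dtype(None, torch.float16) is torch.float16
assert resolve_dtype(torch.bfloat16, torch.float16) is torch.bfloat16
```

Hoisting the fallback to the top of the method also lets the resolved `dtype` be forwarded to `get_i2v_mask`, rather than re-deriving it at each call site.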
@@ -575,12 +579,13 @@ def prepare_prev_segment_cond_latents(
         device: Optional[torch.device] = None,
     ) -> torch.Tensor:
         # prev_segment_cond_video shape: (B, C, T, H, W) in pixel space if supplied
+        dtype = dtype or self.vae.dtype
         if prev_segment_cond_video is None:
             if task == "replace":
-                prev_segment_cond_video = background_video[:, :, :prev_segment_cond_frames]
+                prev_segment_cond_video = background_video[:, :, :prev_segment_cond_frames].to(dtype)
             else:
                 cond_frames_shape = (batch_size, 3, prev_segment_cond_frames, height, width)  # In pixel space
-                prev_segment_cond_video = torch.zeros(cond_frames_shape, device=device)
+                prev_segment_cond_video = torch.zeros(cond_frames_shape, dtype=dtype, device=device)

         data_batch_size, channels, _, segment_height, segment_width = prev_segment_cond_video.shape
         num_latent_frames = (segment_frame_length - 1) // self.vae_scale_factor_temporal + 1
@@ -596,14 +601,15 @@ def prepare_prev_segment_cond_latents(
         # replacing).
         # TODO: check shapes here
         if task == "replace":
-            remaining_segment = background_video[:, :, prev_segment_cond_frames:]
+            remaining_segment = background_video[:, :, prev_segment_cond_frames:].to(dtype)
         else:
             remaining_segment_frames = segment_frame_length - prev_segment_cond_frames
             remaining_segment = torch.zeros(
-                batch_size, channels, remaining_segment_frames, height, width, device=device
+                batch_size, channels, remaining_segment_frames, height, width, dtype=dtype, device=device
             )

         # Prepend the conditioning frames from the previous segment to the remaining segment video in the frame dim
+        prev_segment_cond_video = prev_segment_cond_video.to(dtype=dtype)
        full_segment_cond_video = torch.cat([prev_segment_cond_video, remaining_segment], dim=2)

         if isinstance(generator, list):
@@ -643,6 +649,7 @@ def prepare_prev_segment_cond_latents(
             latent_width,
             mask_len=prev_segment_cond_frames,
             mask_pixel_values=mask_pixel_values,
+            dtype=dtype,
             device=device,
         )
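These casts matter because `torch.cat` type-promotes mixed floating dtypes rather than raising, so an unconverted float32 operand would silently upcast `full_segment_cond_video` and only fail later, inside a half-precision VAE. A rough repro sketch, with a toy `Conv3d` standing in for the VAE encoder:

```python
import torch

half_video = torch.zeros(1, 3, 2, 8, 8, dtype=torch.float16)
fp32_zeros = torch.zeros(1, 3, 2, 8, 8)  # default float32

# torch.cat promotes instead of erroring: one float32 operand
# upcasts the whole concatenated video to float32 ...
full_video = torch.cat([half_video, fp32_zeros], dim=2)
assert full_video.dtype == torch.float32

# ... and the float32 input then trips the dtype check inside a
# half-precision convolution.
conv = torch.nn.Conv3d(3, 3, kernel_size=1).half()
try:
    conv(full_video)
except RuntimeError as exc:
    print(f"dtype mismatch: {exc}")
```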
@@ -1031,6 +1038,7 @@ def __call__(
            face_video_segment = face_video[start:end]

            face_video_segment = face_video_segment.expand(batch_size * num_videos_per_prompt, -1, -1, -1, -1)
+            face_video_segment = face_video_segment.to(dtype=transformer_dtype)

            if start > 0:
                # TODO: check shapes here, why do we take index 0 in the first dim.?
@@ -1053,6 +1061,7 @@ def __call__(
            pose_latents = self.prepare_pose_latents(
                pose_video_segment, batch_size * num_videos_per_prompt, generator=generator, device=device
            )
+            pose_latents = pose_latents.to(dtype=transformer_dtype)

            prev_segment_cond_latents = self.prepare_prev_segment_cond_latents(
                prev_segment_cond_video,
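Both new casts in `__call__` follow the same rule: every conditioning tensor handed to the transformer should already match its parameter dtype. A self-contained sketch of the pattern, using a `Linear` layer as a stand-in for the pipeline's transformer:

```python
import torch

# Stand-in module: the real pipeline derives transformer_dtype from
# its DiT; any nn.Module with parameters shows the same pattern.
transformer = torch.nn.Linear(8, 8).to(torch.bfloat16)
transformer_dtype = next(transformer.parameters()).dtype

pose_latents = torch.randn(2, 8)                     # produced in float32
pose_latents = pose_latents.to(dtype=transformer_dtype)

out = transformer(pose_latents)                      # no dtype mismatch at matmul time
assert out.dtype == torch.bfloat16
```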