@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import gc
 import unittest

 import numpy as np
@@ -32,7 +33,13 @@
     WanAnimateTransformer3DModel,
 )

-from ...testing_utils import enable_full_determinism
+from ...testing_utils import (
+    backend_empty_cache,
+    enable_full_determinism,
+    require_torch_accelerator,
+    slow,
+    torch_device,
+)
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin

@@ -75,21 +82,30 @@ def get_dummy_components(self):
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

         torch.manual_seed(0)
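+        # Tiny motion-encoder channel sizes keep the dummy transformer small and fast.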
+        channel_sizes = {"4": 16, "8": 16, "16": 16}
         transformer = WanAnimateTransformer3DModel(
             patch_size=(1, 2, 2),
             num_attention_heads=2,
             attention_head_dim=12,
             in_channels=36,
+            latent_channels=16,
             out_channels=16,
             text_dim=32,
             freq_dim=256,
             ffn_dim=32,
             num_layers=2,
             cross_attn_norm=True,
             qk_norm="rms_norm_across_heads",
-            rope_max_seq_len=32,
             image_dim=4,
-            pos_embed_seq_len=2 * (4 * 4 + 1),
+            rope_max_seq_len=32,
+            motion_encoder_channel_sizes=channel_sizes,
+            motion_encoder_size=16,
+            motion_style_dim=8,
+            motion_dim=4,
+            motion_encoder_dim=16,
+            face_encoder_hidden_dim=16,
+            face_encoder_num_heads=2,
+            inject_face_latents_blocks=2,
         )

         torch.manual_seed(0)
@@ -127,27 +143,29 @@ def get_dummy_inputs(self, device, seed=0):
         num_frames = 17
         height = 16
         width = 16
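+        # The face video gets its own spatial size, independent of the output resolution.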
+        face_height = 16
+        face_width = 16

-        pose_video = [Image.new("RGB", (height, width))] * num_frames
-        face_video = [Image.new("RGB", (height, width))] * num_frames
         image = Image.new("RGB", (height, width))
+        pose_video = [Image.new("RGB", (height, width))] * num_frames
+        face_video = [Image.new("RGB", (face_height, face_width))] * num_frames

         inputs = {
             "image": image,
             "pose_video": pose_video,
             "face_video": face_video,
             "prompt": "dance monkey",
             "negative_prompt": "negative",
-            "generator": generator,
-            "num_inference_steps": 2,
-            "guidance_scale": 1.0,
             "height": height,
             "width": width,
-            "num_frames": num_frames,
-            "mode": "animation",
-            "num_frames_for_temporal_guidance": 1,
-            "max_sequence_length": 16,
+            "segment_frame_length": 77,  # TODO: can we set this to num_frames?
+            "num_inference_steps": 2,
+            "mode": "animate",
+            "prev_segment_conditioning_frames": 1,
+            "generator": generator,
+            "guidance_scale": 1.0,
             "output_type": "pt",
+            "max_sequence_length": 16,
         }
         return inputs

@@ -168,6 +186,26 @@ def test_inference(self):
         max_diff = np.abs(video - expected_video).max()
         self.assertLessEqual(max_diff, 1e10)

+    def test_inference_replacement(self):
+        """Test the pipeline in replacement mode with background and mask videos."""
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["mode"] = "replace"
+        num_frames = 17
+        height = 16
+        width = 16
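+        # "replace" mode additionally requires a background video and a mask video.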
+        inputs["background_video"] = [Image.new("RGB", (height, width))] * num_frames
+        inputs["mask_video"] = [Image.new("RGB", (height, width))] * num_frames
+
+        video = pipe(**inputs).frames[0]
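+        # frames[0] is (num_frames, channels, height, width) for "pt" output.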
+        self.assertEqual(video.shape, (17, 3, 16, 16))
+
     def test_inference_with_single_reference_image(self):
         """Test inference with a single reference image for additional context."""
         device = "cpu"
@@ -200,46 +238,22 @@ def test_inference_with_multiple_reference_image(self):
     def test_attention_slicing_forward_pass(self):
         pass

-    @unittest.skip("Errors out because passing multiple prompts at once is not yet supported by this pipeline.")
-    def test_encode_prompt_works_in_isolation(self):
-        pass

-    @unittest.skip("Batching is not yet supported with this pipeline")
-    def test_inference_batch_consistent(self):
-        pass
+@slow
+@require_torch_accelerator
+class WanAnimatePipelineIntegrationTests(unittest.TestCase):
+    prompt = "A painting of a squirrel eating a burger."

-    @unittest.skip("Batching is not yet supported with this pipeline")
-    def test_inference_batch_single_identical(self):
-        return super().test_inference_batch_single_identical()
+    def setUp(self):
+        super().setUp()
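+        # Release leftover host and accelerator memory before the test runs.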
+        gc.collect()
+        backend_empty_cache(torch_device)

-    @unittest.skip(
-        "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs"
-    )
-    def test_float16_inference(self):
-        pass
+    def tearDown(self):
+        super().tearDown()
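+        # Clean up again afterwards so subsequent tests start from an empty cache.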
+        gc.collect()
+        backend_empty_cache(torch_device)

-    @unittest.skip(
-        "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs"
-    )
-    def test_save_load_float16(self):
+    @unittest.skip("TODO: test needs to be implemented")
+    def test_wan_animate(self):
         pass
-
-    def test_inference_replacement_mode(self):
-        """Test the pipeline in replacement mode with background and mask videos."""
-        device = "cpu"
-
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe.to(device)
-        pipe.set_progress_bar_config(disable=None)
-
-        inputs = self.get_dummy_inputs(device)
-        inputs["mode"] = "replacement"
-        num_frames = 17
-        height = 16
-        width = 16
-        inputs["background_video"] = [Image.new("RGB", (height, width))] * num_frames
-        inputs["mask_video"] = [Image.new("RGB", (height, width))] * num_frames
-
-        video = pipe(**inputs).frames[0]
-        self.assertEqual(video.shape, (17, 3, 16, 16))