
Commit 332d3c2

Update Wan Animate pipeline tests after transformer and pipeline changes
1 parent 2537133 commit 332d3c2

File tree

1 file changed (+64, -50 lines)


tests/pipelines/wan/test_wan_animate.py

Lines changed: 64 additions & 50 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import gc
 import unittest

 import numpy as np
@@ -32,7 +33,13 @@
     WanAnimateTransformer3DModel,
 )

-from ...testing_utils import enable_full_determinism
+from ...testing_utils import (
+    backend_empty_cache,
+    enable_full_determinism,
+    require_torch_accelerator,
+    slow,
+    torch_device,
+)
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin

@@ -75,21 +82,30 @@ def get_dummy_components(self):
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

         torch.manual_seed(0)
+        channel_sizes = {"4": 16, "8": 16, "16": 16}
         transformer = WanAnimateTransformer3DModel(
             patch_size=(1, 2, 2),
             num_attention_heads=2,
             attention_head_dim=12,
             in_channels=36,
+            latent_channels=16,
             out_channels=16,
             text_dim=32,
             freq_dim=256,
             ffn_dim=32,
             num_layers=2,
             cross_attn_norm=True,
             qk_norm="rms_norm_across_heads",
-            rope_max_seq_len=32,
             image_dim=4,
-            pos_embed_seq_len=2 * (4 * 4 + 1),
+            rope_max_seq_len=32,
+            motion_encoder_channel_sizes=channel_sizes,
+            motion_encoder_size=16,
+            motion_style_dim=8,
+            motion_dim=4,
+            motion_encoder_dim=16,
+            face_encoder_hidden_dim=16,
+            face_encoder_num_heads=2,
+            inject_face_latents_blocks=2,
         )

         torch.manual_seed(0)
@@ -127,27 +143,29 @@ def get_dummy_inputs(self, device, seed=0):
         num_frames = 17
         height = 16
         width = 16
+        face_height = 16
+        face_width = 16

-        pose_video = [Image.new("RGB", (height, width))] * num_frames
-        face_video = [Image.new("RGB", (height, width))] * num_frames
         image = Image.new("RGB", (height, width))
+        pose_video = [Image.new("RGB", (height, width))] * num_frames
+        face_video = [Image.new("RGB", (face_height, face_width))] * num_frames

         inputs = {
             "image": image,
             "pose_video": pose_video,
             "face_video": face_video,
             "prompt": "dance monkey",
             "negative_prompt": "negative",
-            "generator": generator,
-            "num_inference_steps": 2,
-            "guidance_scale": 1.0,
             "height": height,
             "width": width,
-            "num_frames": num_frames,
-            "mode": "animation",
-            "num_frames_for_temporal_guidance": 1,
-            "max_sequence_length": 16,
+            "segment_frame_length": 77,  # TODO: can we set this to num_frames?
+            "num_inference_steps": 2,
+            "mode": "animate",
+            "prev_segment_conditioning_frames": 1,
+            "generator": generator,
+            "guidance_scale": 1.0,
             "output_type": "pt",
+            "max_sequence_length": 16,
         }
         return inputs
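Note on the call-signature change reflected above: "mode" now takes "animate" / "replace" instead of "animation" / "replacement", the per-segment length is passed as "segment_frame_length" rather than "num_frames", and the temporal-guidance argument is now "prev_segment_conditioning_frames". A minimal sketch of the new-style call arguments, mirroring the dummy test inputs; the final call is commented out because it assumes an already-assembled Wan Animate pipeline object named "pipe":

from PIL import Image

# Inputs shaped like the dummy test inputs above; only the argument names
# matter here -- "pipe" stands in for any assembled Wan Animate pipeline.
num_frames, height, width = 17, 16, 16
call_kwargs = {
    "image": Image.new("RGB", (height, width)),
    "pose_video": [Image.new("RGB", (height, width))] * num_frames,
    "face_video": [Image.new("RGB", (height, width))] * num_frames,
    "prompt": "dance monkey",
    "mode": "animate",                      # was "animation"
    "segment_frame_length": 77,             # replaces "num_frames"
    "prev_segment_conditioning_frames": 1,  # was "num_frames_for_temporal_guidance"
    "num_inference_steps": 2,
    "guidance_scale": 1.0,
    "output_type": "pt",
}
# frames = pipe(**call_kwargs).frames[0]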

@@ -168,6 +186,26 @@ def test_inference(self):
         max_diff = np.abs(video - expected_video).max()
         self.assertLessEqual(max_diff, 1e10)

+    def test_inference_replacement(self):
+        """Test the pipeline in replacement mode with background and mask videos."""
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["mode"] = "replace"
+        num_frames = 17
+        height = 16
+        width = 16
+        inputs["background_video"] = [Image.new("RGB", (height, width))] * num_frames
+        inputs["mask_video"] = [Image.new("RGB", (height, width))] * num_frames
+
+        video = pipe(**inputs).frames[0]
+        self.assertEqual(video.shape, (17, 3, 16, 16))
+
     def test_inference_with_single_reference_image(self):
         """Test inference with a single reference image for additional context."""
         device = "cpu"
@@ -200,46 +238,22 @@ def test_inference_with_multiple_reference_image(self):
     def test_attention_slicing_forward_pass(self):
         pass

-    @unittest.skip("Errors out because passing multiple prompts at once is not yet supported by this pipeline.")
-    def test_encode_prompt_works_in_isolation(self):
-        pass

-    @unittest.skip("Batching is not yet supported with this pipeline")
-    def test_inference_batch_consistent(self):
-        pass
+@slow
+@require_torch_accelerator
+class WanAnimatePipelineIntegrationTests(unittest.TestCase):
+    prompt = "A painting of a squirrel eating a burger."

-    @unittest.skip("Batching is not yet supported with this pipeline")
-    def test_inference_batch_single_identical(self):
-        return super().test_inference_batch_single_identical()
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        backend_empty_cache(torch_device)

-    @unittest.skip(
-        "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs"
-    )
-    def test_float16_inference(self):
-        pass
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        backend_empty_cache(torch_device)

-    @unittest.skip(
-        "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs"
-    )
-    def test_save_load_float16(self):
+    @unittest.skip("TODO: test needs to be implemented")
+    def test_wan_animate(self):
         pass
-
-    def test_inference_replacement_mode(self):
-        """Test the pipeline in replacement mode with background and mask videos."""
-        device = "cpu"
-
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe.to(device)
-        pipe.set_progress_bar_config(disable=None)
-
-        inputs = self.get_dummy_inputs(device)
-        inputs["mode"] = "replacement"
-        num_frames = 17
-        height = 16
-        width = 16
-        inputs["background_video"] = [Image.new("RGB", (height, width))] * num_frames
-        inputs["mask_video"] = [Image.new("RGB", (height, width))] * num_frames
-
-        video = pipe(**inputs).frames[0]
-        self.assertEqual(video.shape, (17, 3, 16, 16))
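The new WanAnimatePipelineIntegrationTests class currently stubs out test_wan_animate. A minimal sketch of what that slow test might eventually exercise, assuming the pipeline class is WanAnimatePipeline and using a placeholder checkpoint id (neither name is confirmed by this commit; the call arguments mirror the fast tests above):

import torch
from PIL import Image

from diffusers import WanAnimatePipeline  # assumed class name


def run_wan_animate_smoke_test():
    # Placeholder checkpoint id -- substitute the real hosted repo once published.
    pipe = WanAnimatePipeline.from_pretrained(
        "org/wan-animate-checkpoint", torch_dtype=torch.bfloat16
    )
    pipe.to("cuda")

    # Illustrative sizes only; one pose/face frame per output frame, as in the fast tests.
    num_frames, height, width = 77, 512, 512
    image = Image.new("RGB", (width, height))
    pose_video = [Image.new("RGB", (width, height))] * num_frames
    face_video = [Image.new("RGB", (width, height))] * num_frames

    video = pipe(
        image=image,
        pose_video=pose_video,
        face_video=face_video,
        prompt="A painting of a squirrel eating a burger.",
        mode="animate",
        height=height,
        width=width,
        segment_frame_length=77,
        num_inference_steps=2,
        guidance_scale=1.0,
        output_type="np",
    ).frames[0]

    # The fast tests above return one output frame per input pose frame.
    assert video.shape[0] == num_frames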
