
Commit 178d32d

[tests] Add test slices for Wan (#11920)

* update
* fix wan vace test slice
* test
* fix

1 parent: ef1e628

File tree: 3 files changed, +73 −19 lines

tests/pipelines/wan/test_wan.py

9 additions, 8 deletions

```diff
@@ -15,7 +15,6 @@
 import gc
 import unittest
 
-import numpy as np
 import torch
 from transformers import AutoTokenizer, T5EncoderModel
 
@@ -29,9 +28,7 @@
 )
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import (
-    PipelineTesterMixin,
-)
+from ..test_pipelines_common import PipelineTesterMixin
 
 
 enable_full_determinism()
@@ -127,11 +124,15 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 16, 16))
-        expected_video = torch.randn(9, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
 
     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
```
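The substance of the change is visible in the last hunk: the old assertion compared the generated video against a fresh `torch.randn` tensor with a tolerance of `1e10`, which no output could ever exceed, so the test could not catch numerical regressions. The new assertion pins the first and last eight values of the flattened video to a hard-coded slice at `atol=1e-3`. Below is a minimal sketch of that slice-check pattern outside the test harness; the `video_slice` helper and the dummy tensor are illustrative, not part of the commit:

```python
import torch


def video_slice(video: torch.Tensor, n: int = 8) -> torch.Tensor:
    """Flatten a (frames, channels, height, width) video and keep its
    first and last `n` values, mirroring the checks added in this commit."""
    flat = video.flatten()
    return torch.cat([flat[:n], flat[-n:]])


# Dummy tensor shaped like the Wan fast-test output; a real test would use
# the pipeline's frames instead.
torch.manual_seed(0)
video = torch.rand(9, 3, 16, 16)
expected_slice = video_slice(video)

# Deterministic execution reproduces the same slice on every run.
assert torch.allclose(video_slice(video), expected_slice, atol=1e-3)
```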

tests/pipelines/wan/test_wan_image_to_video.py

56 additions, 6 deletions

```diff
@@ -14,7 +14,6 @@
 
 import unittest
 
-import numpy as np
 import torch
 from PIL import Image
 from transformers import (
@@ -147,11 +146,15 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 16, 16))
-        expected_video = torch.randn(9, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4525, 0.4525, 0.4497, 0.4536, 0.452, 0.4529, 0.454, 0.4535, 0.5072, 0.5527, 0.5165, 0.5244, 0.5481, 0.5282, 0.5208, 0.5214])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
 
     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
@@ -162,7 +165,25 @@ def test_inference_batch_single_identical(self):
         pass
 
 
-class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
+class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = WanImageToVideoPipeline
+    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    required_optional_params = frozenset(
+        [
+            "num_inference_steps",
+            "generator",
+            "latents",
+            "return_dict",
+            "callback_on_step_end",
+            "callback_on_step_end_tensor_inputs",
+        ]
+    )
+    test_xformers_attention = False
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         vae = AutoencoderKLWan(
@@ -247,3 +268,32 @@ def get_dummy_inputs(self, device, seed=0):
             "output_type": "pt",
         }
         return inputs
+
+    def test_inference(self):
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        video = pipe(**inputs).frames
+        generated_video = video[0]
+        self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4531, 0.4527, 0.4498, 0.4542, 0.4526, 0.4527, 0.4534, 0.4534, 0.5061, 0.5185, 0.5283, 0.5181, 0.5309, 0.5365, 0.5113, 0.5244])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+    @unittest.skip("Test not supported")
+    def test_attention_slicing_forward_pass(self):
+        pass
+
+    @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
+    def test_inference_batch_single_identical(self):
+        pass
```
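This file also rewrites `WanFLFToVideoPipelineFastTests` (the first-last-frame variant) from a subclass of `WanImageToVideoPipelineFastTests` into a standalone `PipelineTesterMixin` tester with its own configuration, skips, and `test_inference`, so its expected slice is pinned independently. When a legitimate pipeline change shifts outputs, slices like these must be regenerated rather than loosened. Below is a sketch of how one might print a fresh `expected_slice` line to paste between the `# fmt: off` / `# fmt: on` markers; the helper and the four-decimal rounding (matching the values above) are assumptions about the workflow, not something the commit documents:

```python
import torch


def format_expected_slice(video: torch.Tensor, n: int = 8) -> str:
    # Mirror the test: flatten, keep the first/last n values, and round to
    # four decimals as in the committed slices.
    flat = video.flatten()
    values = [round(v, 4) for v in torch.cat([flat[:n], flat[-n:]]).tolist()]
    return f"expected_slice = torch.tensor({values})"


# Hypothetical usage: run the pipeline once on CPU with the dummy inputs,
# then print the line and paste it into the test:
#   print(format_expected_slice(pipe(**inputs).frames[0]))
```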

tests/pipelines/wan/test_wan_video_to_video.py

8 additions, 5 deletions

```diff
@@ -14,7 +14,6 @@
 
 import unittest
 
-import numpy as np
 import torch
 from PIL import Image
 from transformers import AutoTokenizer, T5EncoderModel
@@ -123,11 +122,15 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (17, 3, 16, 16))
-        expected_video = torch.randn(17, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4522, 0.4534, 0.4532, 0.4553, 0.4526, 0.4538, 0.4533, 0.4547, 0.513, 0.5176, 0.5286, 0.4958, 0.4955, 0.5381, 0.5154, 0.5195])
+        # fmt:on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
 
     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
```
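All three `test_inference` rewrites depend on deterministic execution: the hard-coded slices only stay stable because the module calls `enable_full_determinism()` (visible at the top of the first diff) and the dummy components and inputs are built from fixed seeds such as `torch.manual_seed(0)`. A small illustration of why the seeding matters, using plain `torch.randn` rather than the actual pipelines:

```python
import torch

# Unseeded draws differ between runs, so a hard-coded expected_slice
# would fail spuriously.
a = torch.randn(4)
b = torch.randn(4)
assert not torch.equal(a, b)

# With a fixed seed the draw is reproducible, so a pinned slice is stable.
torch.manual_seed(0)
c = torch.randn(4)
torch.manual_seed(0)
d = torch.randn(4)
assert torch.equal(c, d)
```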
