 import unittest

-import numpy as np
 import torch
 from PIL import Image
 from transformers import (
@@ -147,11 +146,15 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 16, 16))
-        expected_video = torch.randn(9, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4525, 0.4525, 0.4497, 0.4536, 0.452, 0.4529, 0.454, 0.4535, 0.5072, 0.5527, 0.5165, 0.5244, 0.5481, 0.5282, 0.5208, 0.5214])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
@@ -162,7 +165,25 @@ def test_inference_batch_single_identical(self):
         pass


-class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
+class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = WanImageToVideoPipeline
+    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    required_optional_params = frozenset(
+        [
+            "num_inference_steps",
+            "generator",
+            "latents",
+            "return_dict",
+            "callback_on_step_end",
+            "callback_on_step_end_tensor_inputs",
+        ]
+    )
+    test_xformers_attention = False
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         vae = AutoencoderKLWan(
@@ -247,3 +268,32 @@ def get_dummy_inputs(self, device, seed=0):
             "output_type": "pt",
         }
         return inputs
+
+    def test_inference(self):
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        video = pipe(**inputs).frames
+        generated_video = video[0]
+        self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4531, 0.4527, 0.4498, 0.4542, 0.4526, 0.4527, 0.4534, 0.4534, 0.5061, 0.5185, 0.5283, 0.5181, 0.5309, 0.5365, 0.5113, 0.5244])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+    @unittest.skip("Test not supported")
+    def test_attention_slicing_forward_pass(self):
+        pass
+
+    @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
+    def test_inference_batch_single_identical(self):
+        pass