
Commit 2d38089

enable 7 cases on XPU (#11503)
* enable 7 cases on XPU

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* calibrate A100 expectations

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
1 parent 0c47c95 commit 2d38089

File tree: 7 files changed (+104, −51)

- tests/pipelines/consisid/test_consisid.py
- tests/pipelines/easyanimate/test_easyanimate.py
- tests/pipelines/mochi/test_mochi.py
- tests/pipelines/omnigen/test_pipeline_omnigen.py
- tests/pipelines/paint_by_example/test_paint_by_example.py
- tests/pipelines/stable_audio/test_stable_audio.py
- tests/quantization/bnb/test_4bit.py

tests/pipelines/consisid/test_consisid.py

Lines changed: 8 additions & 7 deletions
```diff
@@ -24,9 +24,10 @@
 from diffusers import AutoencoderKLCogVideoX, ConsisIDPipeline, ConsisIDTransformer3DModel, DDIMScheduler
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -316,19 +317,19 @@ def test_vae_tiling(self, expected_diff_max: float = 0.4):


 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class ConsisIDPipelineIntegrationTests(unittest.TestCase):
     prompt = "A painting of a squirrel eating a burger."

     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_consisid(self):
         generator = torch.Generator("cpu").manual_seed(0)
@@ -338,8 +339,8 @@ def test_consisid(self):

         prompt = self.prompt
         image = load_image("https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true")
-        id_vit_hidden = [torch.ones([1, 2, 2])] * 1
-        id_cond = torch.ones(1, 2)
+        id_vit_hidden = [torch.ones([1, 577, 1024])] * 5
+        id_cond = torch.ones(1, 1280)

         videos = pipe(
             image=image,
@@ -357,5 +358,5 @@ def test_consisid(self):
         video = videos[0]
         expected_video = torch.randn(1, 16, 480, 720, 3).numpy()

-        max_diff = numpy_cosine_similarity_distance(video, expected_video)
+        max_diff = numpy_cosine_similarity_distance(video.cpu(), expected_video)
         assert max_diff < 1e-3, f"Max diff is too high. got {video}"
```
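The recurring swap in these hunks, `torch.cuda.empty_cache()` to `backend_empty_cache(torch_device)`, routes cache clearing through a backend-aware helper so the same setup/teardown works on CUDA, XPU, and MPS. A minimal sketch of what such a dispatcher can look like; `backend_empty_cache` and `torch_device` are real names from `diffusers.utils.testing_utils`, but this body is an illustration, not the library's implementation:

```python
import torch

def backend_empty_cache(device: str) -> None:
    # Illustrative dispatch only; the actual diffusers helper may differ.
    if device.startswith("cuda"):
        torch.cuda.empty_cache()
    elif device.startswith("xpu"):
        torch.xpu.empty_cache()  # Intel GPU backend (PyTorch >= 2.4)
    elif device.startswith("mps"):
        torch.mps.empty_cache()
    # On CPU there is no device cache to clear, so do nothing.
```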

tests/pipelines/easyanimate/test_easyanimate.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -27,9 +27,10 @@
     FlowMatchEulerDiscreteScheduler,
 )
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -256,19 +257,19 @@ def test_encode_prompt_works_in_isolation(self):


 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class EasyAnimatePipelineIntegrationTests(unittest.TestCase):
     prompt = "A painting of a squirrel eating a burger."

     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_EasyAnimate(self):
         generator = torch.Generator("cpu").manual_seed(0)
```
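`require_torch_gpu` gates a test on CUDA specifically, while `require_torch_accelerator` admits any supported accelerator. A hedged sketch of how such a decorator can be built on `unittest.skipUnless`; the real helper in `diffusers.utils.testing_utils` may probe devices differently:

```python
import unittest
import torch

def _accelerator_available() -> bool:
    # torch.cuda also covers ROCm builds; torch.xpu is the Intel GPU backend.
    if torch.cuda.is_available():
        return True
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return True
    return torch.backends.mps.is_available()

def require_torch_accelerator(test_case):
    # Skips the decorated test (or test class) when no accelerator is present.
    return unittest.skipUnless(_accelerator_available(), "test requires an accelerator")(test_case)
```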

tests/pipelines/mochi/test_mochi.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -27,8 +27,8 @@
     enable_full_determinism,
     nightly,
     numpy_cosine_similarity_distance,
-    require_big_gpu_with_torch_cuda,
-    require_torch_gpu,
+    require_big_accelerator,
+    require_torch_accelerator,
     torch_device,
 )

@@ -266,9 +266,9 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):


 @nightly
-@require_torch_gpu
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
+@require_torch_accelerator
+@require_big_accelerator
+@pytest.mark.big_accelerator
 class MochiPipelineIntegrationTests(unittest.TestCase):
     prompt = "A painting of a squirrel eating a burger."

@@ -302,5 +302,5 @@ def test_mochi(self):
         video = videos[0]
         expected_video = torch.randn(1, 19, 480, 848, 3).numpy()

-        max_diff = numpy_cosine_similarity_distance(video, expected_video)
+        max_diff = numpy_cosine_similarity_distance(video.cpu(), expected_video)
         assert max_diff < 1e-3, f"Max diff is too high. got {video}"
```
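The added `.cpu()` matters because NumPy cannot read CUDA or XPU tensors directly; the frames must be moved to host memory before the comparison. For reference, a hedged sketch of a cosine-similarity distance between two arrays; the actual `numpy_cosine_similarity_distance` in `diffusers.utils.testing_utils` may differ in detail:

```python
import numpy as np

def numpy_cosine_similarity_distance(a, b) -> float:
    # Flatten both inputs and return 1 - cosine similarity: identical
    # directions give 0.0, orthogonal ones give 1.0.
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(abs(1.0 - cos))
```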

tests/pipelines/omnigen/test_pipeline_omnigen.py

Lines changed: 55 additions & 18 deletions
```diff
@@ -7,8 +7,10 @@

 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
 from diffusers.utils.testing_utils import (
+    Expectations,
+    backend_empty_cache,
     numpy_cosine_similarity_distance,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -87,20 +89,20 @@ def test_inference(self):


 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class OmniGenPipelineSlowTests(unittest.TestCase):
     pipeline_class = OmniGenPipeline
     repo_id = "shitao/OmniGen-v1-diffusers"

     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_inputs(self, device, seed=0):
         if str(device).startswith("mps"):
@@ -125,21 +127,56 @@ def test_omnigen_inference(self):
         image = pipe(**inputs).images[0]
         image_slice = image[0, :10, :10]

-        expected_slice = np.array(
-            [
-                [0.1783447, 0.16772744, 0.14339337],
-                [0.17066911, 0.15521264, 0.13757327],
-                [0.17072496, 0.15531206, 0.13524258],
-                [0.16746324, 0.1564025, 0.13794944],
-                [0.16490817, 0.15258026, 0.13697758],
-                [0.16971767, 0.15826806, 0.13928896],
-                [0.16782972, 0.15547255, 0.13783783],
-                [0.16464645, 0.15281534, 0.13522372],
-                [0.16535294, 0.15301755, 0.13526791],
-                [0.16365296, 0.15092957, 0.13443318],
-            ],
-            dtype=np.float32,
+        expected_slices = Expectations(
+            {
+                ("xpu", 3): np.array(
+                    [
+                        [0.05859375, 0.05859375, 0.04492188],
+                        [0.04882812, 0.04101562, 0.03320312],
+                        [0.04882812, 0.04296875, 0.03125],
+                        [0.04296875, 0.0390625, 0.03320312],
+                        [0.04296875, 0.03710938, 0.03125],
+                        [0.04492188, 0.0390625, 0.03320312],
+                        [0.04296875, 0.03710938, 0.03125],
+                        [0.04101562, 0.03710938, 0.02734375],
+                        [0.04101562, 0.03515625, 0.02734375],
+                        [0.04101562, 0.03515625, 0.02929688],
+                    ],
+                    dtype=np.float32,
+                ),
+                ("cuda", 7): np.array(
+                    [
+                        [0.1783447, 0.16772744, 0.14339337],
+                        [0.17066911, 0.15521264, 0.13757327],
+                        [0.17072496, 0.15531206, 0.13524258],
+                        [0.16746324, 0.1564025, 0.13794944],
+                        [0.16490817, 0.15258026, 0.13697758],
+                        [0.16971767, 0.15826806, 0.13928896],
+                        [0.16782972, 0.15547255, 0.13783783],
+                        [0.16464645, 0.15281534, 0.13522372],
+                        [0.16535294, 0.15301755, 0.13526791],
+                        [0.16365296, 0.15092957, 0.13443318],
+                    ],
+                    dtype=np.float32,
+                ),
+                ("cuda", 8): np.array(
+                    [
+                        [0.0546875, 0.05664062, 0.04296875],
+                        [0.046875, 0.04101562, 0.03320312],
+                        [0.05078125, 0.04296875, 0.03125],
+                        [0.04296875, 0.04101562, 0.03320312],
+                        [0.0390625, 0.03710938, 0.02929688],
+                        [0.04296875, 0.03710938, 0.03125],
+                        [0.0390625, 0.03710938, 0.02929688],
+                        [0.0390625, 0.03710938, 0.02734375],
+                        [0.0390625, 0.03320312, 0.02734375],
+                        [0.0390625, 0.03320312, 0.02734375],
+                    ],
+                    dtype=np.float32,
+                ),
+            }
         )
+        expected_slice = expected_slices.get_expectation()

         max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
```
tests/pipelines/paint_by_example/test_paint_by_example.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -25,11 +25,12 @@
 from diffusers import AutoencoderKL, PaintByExamplePipeline, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
     load_image,
     nightly,
-    require_torch_gpu,
+    require_torch_accelerator,
     torch_device,
 )
@@ -174,19 +175,19 @@ def test_inference_batch_single_identical(self):


 @nightly
-@require_torch_gpu
+@require_torch_accelerator
 class PaintByExamplePipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         # clean up the VRAM before each test
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_paint_by_example(self):
         # make sure here that pndm scheduler skips prk
```

tests/pipelines/stable_audio/test_stable_audio.py

Lines changed: 20 additions & 7 deletions
```diff
@@ -32,7 +32,14 @@
     StableAudioProjectionModel,
 )
 from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import (
+    Expectations,
+    backend_empty_cache,
+    enable_full_determinism,
+    nightly,
+    require_torch_accelerator,
+    torch_device,
+)

 from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin
@@ -419,17 +426,17 @@ def test_encode_prompt_works_in_isolation(self):


 @nightly
-@require_torch_gpu
+@require_torch_accelerator
 class StableAudioPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
         generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -459,9 +466,15 @@ def test_stable_audio(self):
         # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
         audio_slice = audio[0, 447590:447600]
         # fmt: off
-        expected_slice = np.array(
-            [-0.0278, 0.1096, 0.1877, 0.3178, 0.5329, 0.6990, 0.6972, 0.6186, 0.5608, 0.5060]
+        expected_slices = Expectations(
+            {
+                ("xpu", 3): np.array([-0.0285, 0.1083, 0.1863, 0.3165, 0.5312, 0.6971, 0.6958, 0.6177, 0.5598, 0.5048]),
+                ("cuda", 7): np.array([-0.0278, 0.1096, 0.1877, 0.3178, 0.5329, 0.6990, 0.6972, 0.6186, 0.5608, 0.5060]),
+                ("cuda", 8): np.array([-0.0285, 0.1082, 0.1862, 0.3163, 0.5306, 0.6964, 0.6953, 0.6172, 0.5593, 0.5044]),
+            }
         )
-        # fmt: one
+        # fmt: on
+
+        expected_slice = expected_slices.get_expectation()
         max_diff = np.abs(expected_slice - audio_slice.detach().cpu().numpy()).max()
         assert max_diff < 1.5e-3
```
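The hard-coded window `audio[0, 447590:447600]` follows the comment about checking the portion with the largest dynamic range. A hedged sketch of how such a window could be located offline; this helper is hypothetical and not part of the test or of diffusers:

```python
import numpy as np

def loudest_window(audio: np.ndarray, width: int = 10) -> slice:
    # Slide a window across the waveform and keep the one with the widest
    # peak-to-peak span; quiet regions amplify relative numerical noise.
    best_start, best_range = 0, -1.0
    for start in range(0, audio.shape[-1] - width + 1, width):
        window = audio[..., start : start + width]
        ptp = float(window.max() - window.min())
        if ptp > best_range:
            best_start, best_range = start, ptp
    return slice(best_start, best_start + width)
```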

tests/quantization/bnb/test_4bit.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -389,7 +389,7 @@ def test_bnb_4bit_logs_warning_for_no_quantization(self):
 class BnB4BitTrainingTests(Base4bitTests):
     def setUp(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

         nf4_config = BitsAndBytesConfig(
             load_in_4bit=True,
@@ -657,7 +657,7 @@ def get_dummy_tensor_inputs(device=None, seed: int = 0):
 class SlowBnb4BitFluxTests(Base4bitTests):
     def setUp(self) -> None:
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

         model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
         t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
@@ -674,7 +674,7 @@ def tearDown(self):
         del self.pipeline_4bit

         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_quality(self):
         # keep the resolution and max tokens to a lower number for faster execution.
@@ -722,7 +722,7 @@ def test_lora_loading(self):
 class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
     def setUp(self) -> None:
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

         self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
         self.pipeline_4bit.enable_model_cpu_offload()
@@ -731,7 +731,7 @@ def tearDown(self):
         del self.pipeline_4bit

         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_lora_loading(self):
         self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
```
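The first hunk truncates `nf4_config` right after `load_in_4bit=True`. For orientation, a hedged sketch of a typical NF4 configuration; the remaining kwargs here are illustrative and not necessarily the ones this test uses:

```python
import torch
from diffusers import BitsAndBytesConfig

# Typical NF4 setup: 4-bit weights in the "nf4" data type with fp16 compute.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
```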
