enable 7 cases on XPU #11503

Merged: 10 commits, May 9, 2025
15 changes: 8 additions & 7 deletions tests/pipelines/consisid/test_consisid.py
@@ -24,9 +24,10 @@
from diffusers import AutoencoderKLCogVideoX, ConsisIDPipeline, ConsisIDTransformer3DModel, DDIMScheduler
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -316,19 +317,19 @@ def test_vae_tiling(self, expected_diff_max: float = 0.4):


@slow
@require_torch_gpu
@require_torch_accelerator
class ConsisIDPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."

def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_consisid(self):
generator = torch.Generator("cpu").manual_seed(0)
@@ -338,8 +339,8 @@ def test_consisid(self):

prompt = self.prompt
image = load_image("https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true")
id_vit_hidden = [torch.ones([1, 2, 2])] * 1
id_cond = torch.ones(1, 2)
id_vit_hidden = [torch.ones([1, 577, 1024])] * 5
id_cond = torch.ones(1, 1280)
Contributor Author:

The sizes of id_vit_hidden and id_cond were wrong; both A100 and XPU report the error below:

def forward(self, input: Tensor) -> Tensor:
    return F.linear(input, self.weight, self.bias)
E   RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x2 and 1280x1024)

I checked the sizes these two tensors actually need and switched the test to correctly shaped tensors.
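
As a minimal sketch of the mismatch (the Linear dimensions below are inferred from the error message, not taken from the ConsisID source):

import torch
import torch.nn.functional as F

proj = torch.nn.Linear(1280, 1024)  # in/out features inferred from "1280x1024" in the error

bad_id_cond = torch.ones(1, 2)      # the old test input: mat1 is 1x2
good_id_cond = torch.ones(1, 1280)  # last dim matches in_features

try:
    F.linear(bad_id_cond, proj.weight, proj.bias)
except RuntimeError as err:
    print(err)  # mat1 and mat2 shapes cannot be multiplied (1x2 and 1280x1024)

print(F.linear(good_id_cond, proj.weight, proj.bias).shape)  # torch.Size([1, 1024])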


videos = pipe(
image=image,
@@ -357,5 +358,5 @@ def test_consisid(self):
video = videos[0]
expected_video = torch.randn(1, 16, 480, 720, 3).numpy()
Contributor Author:

expected_video here is generated by RNG, so this is presumably an incomplete test case. For this PR I only enable the case to run on XPU; the numerical correctness needs to be re-checked once expected_video is set properly.


max_diff = numpy_cosine_similarity_distance(video, expected_video)
max_diff = numpy_cosine_similarity_distance(video.cpu(), expected_video)
Contributor Author:

The tensor needs to be moved to CPU: expected_video lives on the CPU, and numpy_cosine_similarity_distance only supports two tensors on the same device.
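
For context, an approximation of what the helper computes (the actual implementation in diffusers.utils.testing_utils may differ); it works on NumPy arrays, so accelerator tensors must be moved to CPU first:

import numpy as np

def numpy_cosine_similarity_distance(a, b):
    # Inputs must be NumPy-convertible, i.e. CPU-resident arrays or tensors.
    a = np.asarray(a, dtype=np.float32).flatten()
    b = np.asarray(b, dtype=np.float32).flatten()
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return 1.0 - similarity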

assert max_diff < 1e-3, f"Max diff is too high. got {video}"
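
Note on the pattern used throughout this PR: @require_torch_gpu and torch.cuda.empty_cache() are CUDA-only, so the tests swap them for @require_torch_accelerator and backend_empty_cache(torch_device). A rough sketch of the assumed dispatch (the actual helper in diffusers.utils.testing_utils may differ):

import torch

def backend_empty_cache(device: str) -> None:
    # Route the cache flush to whichever accelerator backend is in use.
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "xpu":
        torch.xpu.empty_cache()  # Intel GPU backend
    elif device == "mps":
        torch.mps.empty_cache()
    # "cpu" needs no cache management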
9 changes: 5 additions & 4 deletions tests/pipelines/easyanimate/test_easyanimate.py
@@ -27,9 +27,10 @@
FlowMatchEulerDiscreteScheduler,
)
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -256,19 +257,19 @@ def test_encode_prompt_works_in_isolation(self):


@slow
@require_torch_gpu
@require_torch_accelerator
class EasyAnimatePipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."

def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_EasyAnimate(self):
generator = torch.Generator("cpu").manual_seed(0)
12 changes: 6 additions & 6 deletions tests/pipelines/mochi/test_mochi.py
@@ -27,8 +27,8 @@
enable_full_determinism,
nightly,
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
require_torch_gpu,
require_big_accelerator,
require_torch_accelerator,
torch_device,
)

@@ -266,9 +266,9 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):


@nightly
@require_torch_gpu
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
@require_torch_accelerator
@require_big_accelerator
@pytest.mark.big_accelerator
class MochiPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."

@@ -302,5 +302,5 @@ def test_mochi(self):
video = videos[0]
expected_video = torch.randn(1, 19, 480, 848, 3).numpy()

max_diff = numpy_cosine_similarity_distance(video, expected_video)
max_diff = numpy_cosine_similarity_distance(video.cpu(), expected_video)
assert max_diff < 1e-3, f"Max diff is too high. got {video}"
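
The mochi test also swaps the CUDA-specific big-GPU gate for a backend-neutral one. A sketch of the assumed behavior, with an illustrative memory cutoff (the real require_big_accelerator in diffusers may use a different check and threshold):

import unittest
import torch

def require_big_accelerator(test_case):
    # Skip unless the available accelerator has enough device memory.
    if torch.cuda.is_available():
        total = torch.cuda.get_device_properties(0).total_memory
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        total = torch.xpu.get_device_properties(0).total_memory
    else:
        total = 0
    enough = total >= 70 * 1024**3  # illustrative ~70 GB threshold, not the real value
    return unittest.skipUnless(enough, "test requires a big accelerator")(test_case)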
73 changes: 55 additions & 18 deletions tests/pipelines/omnigen/test_pipeline_omnigen.py
@@ -7,8 +7,10 @@

from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
from diffusers.utils.testing_utils import (
Expectations,
backend_empty_cache,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -87,20 +89,20 @@ def test_inference(self):


@slow
@require_torch_gpu
@require_torch_accelerator
class OmniGenPipelineSlowTests(unittest.TestCase):
pipeline_class = OmniGenPipeline
repo_id = "shitao/OmniGen-v1-diffusers"

def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def get_inputs(self, device, seed=0):
if str(device).startswith("mps"):
@@ -125,21 +127,56 @@ def test_omnigen_inference(self):
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]

expected_slice = np.array(
[
[0.1783447, 0.16772744, 0.14339337],
[0.17066911, 0.15521264, 0.13757327],
[0.17072496, 0.15531206, 0.13524258],
[0.16746324, 0.1564025, 0.13794944],
[0.16490817, 0.15258026, 0.13697758],
[0.16971767, 0.15826806, 0.13928896],
[0.16782972, 0.15547255, 0.13783783],
[0.16464645, 0.15281534, 0.13522372],
[0.16535294, 0.15301755, 0.13526791],
[0.16365296, 0.15092957, 0.13443318],
],
dtype=np.float32,
expected_slices = Expectations(
{
("xpu", 3): np.array(
[
[0.05859375, 0.05859375, 0.04492188],
[0.04882812, 0.04101562, 0.03320312],
[0.04882812, 0.04296875, 0.03125],
[0.04296875, 0.0390625, 0.03320312],
[0.04296875, 0.03710938, 0.03125],
[0.04492188, 0.0390625, 0.03320312],
[0.04296875, 0.03710938, 0.03125],
[0.04101562, 0.03710938, 0.02734375],
[0.04101562, 0.03515625, 0.02734375],
[0.04101562, 0.03515625, 0.02929688],
],
dtype=np.float32,
),
("cuda", 7): np.array(
[
[0.1783447, 0.16772744, 0.14339337],
[0.17066911, 0.15521264, 0.13757327],
[0.17072496, 0.15531206, 0.13524258],
[0.16746324, 0.1564025, 0.13794944],
[0.16490817, 0.15258026, 0.13697758],
[0.16971767, 0.15826806, 0.13928896],
[0.16782972, 0.15547255, 0.13783783],
[0.16464645, 0.15281534, 0.13522372],
[0.16535294, 0.15301755, 0.13526791],
[0.16365296, 0.15092957, 0.13443318],
],
dtype=np.float32,
),
("cuda", 8): np.array(
[
[0.0546875, 0.05664062, 0.04296875],
[0.046875, 0.04101562, 0.03320312],
[0.05078125, 0.04296875, 0.03125],
[0.04296875, 0.04101562, 0.03320312],
[0.0390625, 0.03710938, 0.02929688],
[0.04296875, 0.03710938, 0.03125],
[0.0390625, 0.03710938, 0.02929688],
[0.0390625, 0.03710938, 0.02734375],
[0.0390625, 0.03320312, 0.02734375],
[0.0390625, 0.03320312, 0.02734375],
],
dtype=np.float32,
),
}
)
expected_slice = expected_slices.get_expectation()

max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

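The Expectations table introduced above keys reference outputs by backend so that each device family can carry its own expected slice. A rough sketch of the assumed semantics (the real class in the testing utils may detect versions and fall back differently); the ("cuda", 7) and ("cuda", 8) keys appear to distinguish GPU compute-capability generations, which is why the test carries two CUDA slices:

import torch

class Expectations(dict):
    """Sketch: map (device_type, major_version) keys to expected values."""

    def get_expectation(self):
        if torch.cuda.is_available():
            key = ("cuda", torch.cuda.get_device_capability()[0])
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            key = ("xpu", 3)  # placeholder; XPU version detection is backend-specific
        else:
            key = ("cpu", None)
        return self[key]
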
9 changes: 5 additions & 4 deletions tests/pipelines/paint_by_example/test_paint_by_example.py
@@ -25,11 +25,12 @@
from diffusers import AutoencoderKL, PaintByExamplePipeline, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
nightly,
require_torch_gpu,
require_torch_accelerator,
torch_device,
)

@@ -174,19 +175,19 @@ def test_inference_batch_single_identical(self):


@nightly
@require_torch_gpu
@require_torch_accelerator
class PaintByExamplePipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_paint_by_example(self):
# make sure here that pndm scheduler skips prk
27 changes: 20 additions & 7 deletions tests/pipelines/stable_audio/test_stable_audio.py
@@ -32,7 +32,14 @@
StableAudioProjectionModel,
)
from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
from diffusers.utils.testing_utils import (
Expectations,
backend_empty_cache,
enable_full_determinism,
nightly,
require_torch_accelerator,
torch_device,
)

from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -419,17 +426,17 @@ def test_encode_prompt_works_in_isolation(self):


@nightly
@require_torch_gpu
@require_torch_accelerator
class StableAudioPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -459,9 +466,15 @@ def test_stable_audio(self):
# check the portion of the generated audio with the largest dynamic range (reduces flakiness)
audio_slice = audio[0, 447590:447600]
# fmt: off
expected_slice = np.array(
[-0.0278, 0.1096, 0.1877, 0.3178, 0.5329, 0.6990, 0.6972, 0.6186, 0.5608, 0.5060]
expected_slices = Expectations(
{
("xpu", 3): np.array([-0.0285, 0.1083, 0.1863, 0.3165, 0.5312, 0.6971, 0.6958, 0.6177, 0.5598, 0.5048]),
("cuda", 7): np.array([-0.0278, 0.1096, 0.1877, 0.3178, 0.5329, 0.6990, 0.6972, 0.6186, 0.5608, 0.5060]),
("cuda", 8): np.array([-0.0285, 0.1082, 0.1862, 0.3163, 0.5306, 0.6964, 0.6953, 0.6172, 0.5593, 0.5044]),
}
)
# fmt: one
# fmt: on

expected_slice = expected_slices.get_expectation()
max_diff = np.abs(expected_slice - audio_slice.detach().cpu().numpy()).max()
assert max_diff < 1.5e-3
10 changes: 5 additions & 5 deletions tests/quantization/bnb/test_4bit.py
@@ -389,7 +389,7 @@ def test_bnb_4bit_logs_warning_for_no_quantization(self):
class BnB4BitTrainingTests(Base4bitTests):
def setUp(self):
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
@@ -657,7 +657,7 @@ def get_dummy_tensor_inputs(device=None, seed: int = 0):
class SlowBnb4BitFluxTests(Base4bitTests):
def setUp(self) -> None:
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
@@ -674,7 +674,7 @@ def tearDown(self):
del self.pipeline_4bit

gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_quality(self):
# keep the resolution and max tokens to a lower number for faster execution.
@@ -722,7 +722,7 @@ def test_lora_loading(self):
class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
def setUp(self) -> None:
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
self.pipeline_4bit.enable_model_cpu_offload()
@@ -731,7 +731,7 @@ def tearDown(self):
del self.pipeline_4bit

gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_lora_loading(self):
self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")