Commit 15ad97f

[tests] make cuda only tests device-agnostic (#11058)
* enable bnb on xpu
* add 2 more cases
* add missing change
* add missing change
* add one more
* enable cuda only tests on xpu
* enable big gpu cases
1 parent 9f2d5c9 commit 15ad97f

17 files changed: +93 −44 lines
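The recurring change across these files is replacing hard-coded CUDA calls (`torch.cuda.empty_cache()`, implicit `device="cuda"` defaults) with backend-dispatching helpers from `diffusers.utils.testing_utils`. As a rough sketch of the dispatch idea (illustrative only; `backend_empty_cache` already exists in `testing_utils.py` and its actual implementation may differ):

    import torch

    def empty_cache_sketch(device: str) -> None:
        # Hypothetical helper for illustration; diffusers ships its own
        # backend_empty_cache in diffusers.utils.testing_utils.
        if device.startswith("cuda"):
            torch.cuda.empty_cache()
        elif device.startswith("xpu"):
            torch.xpu.empty_cache()
        # cpu/mps runs have no per-backend cache to clear here.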

src/diffusers/loaders/textual_inversion.py

Lines changed: 2 additions & 2 deletions

@@ -449,9 +449,9 @@ def load_textual_inversion(
 
         # 7.5 Offload the model again
         if is_model_cpu_offload:
-            self.enable_model_cpu_offload()
+            self.enable_model_cpu_offload(device=device)
         elif is_sequential_cpu_offload:
-            self.enable_sequential_cpu_offload()
+            self.enable_sequential_cpu_offload(device=device)
 
         # / Unsafe Code >
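`load_textual_inversion` removes any CPU-offload hooks before mutating the text encoder, then re-enables them in step 7.5 above; passing `device` through means the hooks are restored on whichever accelerator the pipeline was using rather than assuming CUDA. A hedged usage sketch (model and embedding ids are illustrative):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    # Offload hooks bound to an XPU instead of the CUDA default.
    pipe.enable_model_cpu_offload(device="xpu")

    # With this change, the offload re-enabled in step 7.5 targets the same
    # device the pipeline was offloaded from.
    pipe.load_textual_inversion("sd-concepts-library/cat-toy")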

src/diffusers/utils/testing_utils.py

Lines changed: 40 additions & 0 deletions

@@ -320,6 +320,21 @@ def require_torch_multi_gpu(test_case):
     return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
 
 
+def require_torch_multi_accelerator(test_case):
+    """
+    Decorator marking a test that requires a multi-accelerator setup (in PyTorch). These tests are skipped on a machine
+    without multiple hardware accelerators.
+    """
+    if not is_torch_available():
+        return unittest.skip("test requires PyTorch")(test_case)
+
+    import torch
+
+    return unittest.skipUnless(
+        torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1, "test requires multiple hardware accelerators"
+    )(test_case)
+
+
 def require_torch_accelerator_with_fp16(test_case):
     """Decorator marking a test that requires an accelerator with support for the FP16 data type."""
     return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
@@ -354,6 +369,31 @@ def require_big_gpu_with_torch_cuda(test_case):
     )(test_case)
 
 
+def require_big_accelerator(test_case):
+    """
+    Decorator marking a test that requires a bigger hardware accelerator (24GB) for execution. Some example pipelines:
+    Flux, SD3, Cog, etc.
+    """
+    if not is_torch_available():
+        return unittest.skip("test requires PyTorch")(test_case)
+
+    import torch
+
+    if not (torch.cuda.is_available() or torch.xpu.is_available()):
+        return unittest.skip("test requires PyTorch CUDA")(test_case)
+
+    if torch.xpu.is_available():
+        device_properties = torch.xpu.get_device_properties(0)
+    else:
+        device_properties = torch.cuda.get_device_properties(0)
+
+    total_memory = device_properties.total_memory / (1024**3)
+    return unittest.skipUnless(
+        total_memory >= BIG_GPU_MEMORY,
+        f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
+    )(test_case)
+
+
 def require_torch_accelerator_with_training(test_case):
     """Decorator marking a test that requires an accelerator with support for training."""
     return unittest.skipUnless(
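Both new decorators follow the same `unittest.skipUnless` pattern as the CUDA-only ones they generalize. A sketch of how a test module might adopt them (class and test names are hypothetical):

    import unittest

    from diffusers.utils.testing_utils import (
        require_big_accelerator,
        require_torch_multi_accelerator,
    )

    class DeviceAgnosticTests(unittest.TestCase):
        @require_torch_multi_accelerator
        def test_needs_two_devices(self):
            # Collected only when more than one CUDA or XPU device is visible.
            ...

        @require_big_accelerator
        def test_needs_big_memory(self):
            # Collected only when device 0 reports at least BIG_GPU_MEMORY GB.
            ...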

tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py

Lines changed: 1 addition & 1 deletion

@@ -124,7 +124,7 @@ def get_sd_vae_model(self, model_id="cross-attention/asymmetric-autoencoder-kl-x
         return model
 
     def get_generator(self, seed=0):
-        generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
+        generator_device = "cpu" if not torch_device.startswith(torch_device) else torch_device
         if torch_device != "mps":
             return torch.Generator(device=generator_device).manual_seed(seed)
         return torch.manual_seed(seed)
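These generator helpers keep seeding reproducible across backends: `torch.Generator` accepts a device string, while MPS goes through the global `torch.manual_seed` path. The effective pattern, sketched standalone (`torch_device` is the string exported by `diffusers.utils.testing_utils`):

    import torch

    from diffusers.utils.testing_utils import torch_device

    def make_generator(seed: int = 0) -> torch.Generator:
        if torch_device == "mps":
            # torch.manual_seed returns the default (CPU) generator.
            return torch.manual_seed(seed)
        # cuda/xpu/cpu: a per-device generator keeps seeding local to the test.
        return torch.Generator(device=torch_device).manual_seed(seed)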

tests/models/autoencoders/test_models_autoencoder_kl.py

Lines changed: 2 additions & 2 deletions

@@ -165,7 +165,7 @@ def test_output_pretrained(self):
         model.eval()
 
         # Keep generator on CPU for non-CUDA devices to compare outputs with CPU result tensors
-        generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
+        generator_device = "cpu" if not torch_device.startswith(torch_device) else torch_device
         if torch_device != "mps":
             generator = torch.Generator(device=generator_device).manual_seed(0)
         else:
@@ -263,7 +263,7 @@ def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False)
         return model
 
     def get_generator(self, seed=0):
-        generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
+        generator_device = "cpu" if not torch_device.startswith(torch_device) else torch_device
         if torch_device != "mps":
             return torch.Generator(device=generator_device).manual_seed(seed)
         return torch.manual_seed(seed)

tests/models/autoencoders/test_models_autoencoder_oobleck.py

Lines changed: 1 addition & 1 deletion

@@ -183,7 +183,7 @@ def get_oobleck_vae_model(self, model_id="stabilityai/stable-audio-open-1.0", fp
         return model
 
     def get_generator(self, seed=0):
-        generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
+        generator_device = "cpu" if not torch_device.startswith(torch_device) else torch_device
         if torch_device != "mps":
             return torch.Generator(device=generator_device).manual_seed(seed)
         return torch.manual_seed(seed)

tests/models/test_modeling_common.py

Lines changed: 2 additions & 2 deletions

@@ -63,7 +63,7 @@
     require_torch_accelerator,
     require_torch_accelerator_with_training,
     require_torch_gpu,
-    require_torch_multi_gpu,
+    require_torch_multi_accelerator,
     run_test_in_subprocess,
     torch_all_close,
     torch_device,
@@ -1227,7 +1227,7 @@ def test_disk_offload_with_safetensors(self):
 
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
 
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     def test_model_parallelism(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
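`test_model_parallelism` exercises loading the model with an `accelerate`-style device map so its weights spread across the visible accelerators, which is the setup the renamed decorator now guards on XPU hosts too. A rough sketch of that mechanism using `accelerate` directly (not the test's literal code):

    import torch
    from accelerate import dispatch_model, infer_auto_device_map

    def parallelize(model: torch.nn.Module) -> torch.nn.Module:
        # Assign submodules to the visible accelerators, then attach hooks
        # that move activations between devices during forward.
        device_map = infer_auto_device_map(model)
        return dispatch_model(model, device_map=device_map)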

tests/pipelines/controlnet_sd3/test_controlnet_sd3.py

Lines changed: 8 additions & 7 deletions

@@ -31,9 +31,10 @@
 from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
-    require_big_gpu_with_torch_cuda,
+    require_big_accelerator,
     slow,
     torch_device,
 )
@@ -219,20 +220,20 @@ def test_xformers_attention_forwardGenerator_pass(self):
 
 
 @slow
-@require_big_gpu_with_torch_cuda
+@require_big_accelerator
 @pytest.mark.big_gpu_with_torch_cuda
 class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
     pipeline_class = StableDiffusion3ControlNetPipeline
 
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_canny(self):
         controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
@@ -272,7 +273,7 @@ def test_pose(self):
         pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
             "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -304,7 +305,7 @@ def test_tile(self):
         pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
             "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -338,7 +339,7 @@ def test_multi_controlnet(self):
         pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
             "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         generator = torch.Generator(device="cpu").manual_seed(0)
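Each slow-test class now clears the active backend's cache between tests instead of calling `torch.cuda.empty_cache()` unconditionally. The shared pattern, sketched in isolation:

    import gc
    import unittest

    from diffusers.utils.testing_utils import backend_empty_cache, torch_device

    class SlowTestsSketch(unittest.TestCase):
        def setUp(self):
            super().setUp()
            gc.collect()
            backend_empty_cache(torch_device)  # dispatches per backend

        def tearDown(self):
            super().tearDown()
            gc.collect()
            backend_empty_cache(torch_device)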

tests/pipelines/flux/test_pipeline_flux.py

Lines changed: 5 additions & 5 deletions

@@ -12,7 +12,7 @@
     backend_empty_cache,
     nightly,
     numpy_cosine_similarity_distance,
-    require_big_gpu_with_torch_cuda,
+    require_big_accelerator,
     slow,
     torch_device,
 )
@@ -204,7 +204,7 @@ def test_flux_true_cfg(self):
 
 
 @nightly
-@require_big_gpu_with_torch_cuda
+@require_big_accelerator
 @pytest.mark.big_gpu_with_torch_cuda
 class FluxPipelineSlowTests(unittest.TestCase):
     pipeline_class = FluxPipeline
@@ -292,7 +292,7 @@ def test_flux_inference(self):
 
 
 @slow
-@require_big_gpu_with_torch_cuda
+@require_big_accelerator
 @pytest.mark.big_gpu_with_torch_cuda
 class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
     pipeline_class = FluxPipeline
@@ -304,12 +304,12 @@ class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_inputs(self, device, seed=0):
         if str(device).startswith("mps"):

tests/pipelines/flux/test_pipeline_flux_redux.py

Lines changed: 6 additions & 5 deletions

@@ -8,15 +8,16 @@
 from diffusers import FluxPipeline, FluxPriorReduxPipeline
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     numpy_cosine_similarity_distance,
-    require_big_gpu_with_torch_cuda,
+    require_big_accelerator,
     slow,
     torch_device,
 )
 
 
 @slow
-@require_big_gpu_with_torch_cuda
+@require_big_accelerator
 @pytest.mark.big_gpu_with_torch_cuda
 class FluxReduxSlowTests(unittest.TestCase):
     pipeline_class = FluxPriorReduxPipeline
@@ -27,12 +28,12 @@ class FluxReduxSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_inputs(self, device, seed=0):
         init_image = load_image(
@@ -59,7 +60,7 @@ def test_flux_redux_inference(self):
             self.base_repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
         )
         pipe_redux.to(torch_device)
-        pipe_base.enable_model_cpu_offload()
+        pipe_base.enable_model_cpu_offload(device=torch_device)
 
         inputs = self.get_inputs(torch_device)
         base_pipeline_inputs = self.get_base_pipeline_inputs(torch_device)
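The Redux test keeps two pipelines on one accelerator by making only the prior fully resident: `pipe_redux.to(torch_device)` pins it, while the base pipeline streams modules in and out through offload hooks now bound to the same device. Sketched with illustrative repo ids (treat them as assumptions, not the test's exact attributes):

    import torch
    from diffusers import FluxPipeline, FluxPriorReduxPipeline
    from diffusers.utils.testing_utils import torch_device

    pipe_redux = FluxPriorReduxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
    )
    pipe_base = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
        text_encoder=None,  # Redux supplies the conditioning embeddings
        text_encoder_2=None,
    )
    pipe_redux.to(torch_device)
    pipe_base.enable_model_cpu_offload(device=torch_device)  # same accelerator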

tests/pipelines/pag/test_pag_sd3_img2img.py

Lines changed: 1 addition & 1 deletion

@@ -262,7 +262,7 @@ def test_pag_uncond(self):
         pipeline = AutoPipelineForImage2Image.from_pretrained(
             self.repo_id, enable_pag=True, torch_dtype=torch.float16, pag_applied_layers=["blocks.(4|17)"]
         )
-        pipeline.enable_model_cpu_offload()
+        pipeline.enable_model_cpu_offload(device=torch_device)
         pipeline.set_progress_bar_config(disable=None)
 
         inputs = self.get_inputs(torch_device, guidance_scale=0.0, pag_scale=1.8)
