
Commit a7aa8bf

enable group_offloading and PipelineDeviceAndDtypeStabilityTests on XPU, all passed (huggingface#11620)
* enable group_offloading and PipelineDeviceAndDtypeStabilityTests on XPU, all passed
  Signed-off-by: Matrix YAO <matrix.yao@intel.com>
* fix style
  Signed-off-by: Matrix YAO <matrix.yao@intel.com>
* fix
  Signed-off-by: Matrix YAO <matrix.yao@intel.com>

---------

Signed-off-by: Matrix YAO <matrix.yao@intel.com>
Co-authored-by: Aryan <aryan@huggingface.co>
1 parent 3651bdb commit a7aa8bf

2 files changed: +31 additions, -24 deletions

2 files changed

+31
-24
lines changed

tests/hooks/test_group_offloading.py

Lines changed: 26 additions & 19 deletions
@@ -22,7 +22,13 @@
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.utils import get_logger
 from diffusers.utils.import_utils import compare_versions
-from diffusers.utils.testing_utils import require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_peak_memory_stats,
+    require_torch_accelerator,
+    torch_device,
+)
 
 
 class DummyBlock(torch.nn.Module):
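The switch from direct torch.cuda.* calls to these backend_* helpers is what makes the rest of the file device-agnostic: each helper takes torch_device (a device-type string such as "cuda" or "xpu") and routes to the matching torch backend module. A minimal sketch of that dispatch idea, purely illustrative and not the actual diffusers.utils.testing_utils implementation:

    import torch

    def empty_cache_for(device: str) -> None:
        # Illustrative only: look up torch.cuda, torch.xpu, ... by device type
        # and call its empty_cache() if the backend provides one; otherwise no-op.
        backend = getattr(torch, torch.device(device).type, None)
        if backend is not None and hasattr(backend, "empty_cache"):
            backend.empty_cache()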
@@ -107,7 +113,7 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-@require_torch_gpu
+@require_torch_accelerator
 class GroupOffloadTests(unittest.TestCase):
     in_features = 64
     hidden_features = 256
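require_torch_accelerator relaxes the old require_torch_gpu gate so the class also runs on non-CUDA accelerators such as XPU. A rough, hypothetical stand-in for such a decorator (the real one lives in diffusers.utils.testing_utils and may check more backends):

    import unittest
    import torch

    def require_accelerator(test_case):
        # Hypothetical stand-in: skip the test unless CUDA or XPU reports a device.
        available = torch.cuda.is_available() or (
            hasattr(torch, "xpu") and torch.xpu.is_available()
        )
        return unittest.skipUnless(available, "test requires an accelerator")(test_case)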
@@ -125,8 +131,8 @@ def tearDown(self):
         del self.model
         del self.input
         gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_peak_memory_stats(torch_device)
 
     def get_model(self):
         torch.manual_seed(0)
@@ -141,8 +147,8 @@ def test_offloading_forward_pass(self):
         @torch.no_grad()
         def run_forward(model):
             gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
+            backend_empty_cache(torch_device)
+            backend_reset_peak_memory_stats(torch_device)
             self.assertTrue(
                 all(
                     module._diffusers_hook.get_hook("group_offloading") is not None
@@ -152,7 +158,7 @@ def run_forward(model):
             )
             model.eval()
             output = model(self.input)[0].cpu()
-            max_memory_allocated = torch.cuda.max_memory_allocated()
+            max_memory_allocated = backend_max_memory_allocated(torch_device)
             return output, max_memory_allocated
 
         self.model.to(torch_device)
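run_forward measures the peak memory of one forward pass with the usual reset-then-read pattern: clear the allocator cache, zero the peak counter, run the model, then read the high-water mark. A device-agnostic sketch of that pattern using the torch backend modules directly (illustrative; the test uses the diffusers backend_* wrappers, and this assumes the backend exposes these memory-stat functions):

    import gc
    import torch

    def measure_peak_memory(fn, device: str):
        # Illustrative reset-then-read peak-memory measurement.
        backend = getattr(torch, torch.device(device).type)  # e.g. torch.cuda or torch.xpu
        gc.collect()
        backend.empty_cache()
        backend.reset_peak_memory_stats()
        result = fn()
        return result, backend.max_memory_allocated()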
@@ -187,10 +193,10 @@ def run_forward(model):
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5))
 
         # Memory assertions - offloading should reduce memory usage
-        self.assertTrue(mem4 <= mem5 < mem2 < mem3 < mem1 < mem_baseline)
+        self.assertTrue(mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline)
 
-    def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
-        if torch.device(torch_device).type != "cuda":
+    def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self):
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
         self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
         logger = get_logger("diffusers.models.modeling_utils")
@@ -199,8 +205,8 @@ def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
             self.model.to(torch_device)
         self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
 
-    def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
-        if torch.device(torch_device).type != "cuda":
+    def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self):
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
         pipe = DummyPipeline(self.model)
         self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
@@ -210,19 +216,20 @@ def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
             pipe.to(torch_device)
         self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
 
-    def test_error_raised_if_streams_used_and_no_cuda_device(self):
-        original_is_available = torch.cuda.is_available
-        torch.cuda.is_available = lambda: False
+    def test_error_raised_if_streams_used_and_no_accelerator_device(self):
+        torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
+        original_is_available = torch_accelerator_module.is_available
+        torch_accelerator_module.is_available = lambda: False
         with self.assertRaises(ValueError):
             self.model.enable_group_offload(
-                onload_device=torch.device("cuda"), offload_type="leaf_level", use_stream=True
+                onload_device=torch.device(torch_device), offload_type="leaf_level", use_stream=True
             )
-        torch.cuda.is_available = original_is_available
+        torch_accelerator_module.is_available = original_is_available
 
     def test_error_raised_if_supports_group_offloading_false(self):
         self.model._supports_group_offloading = False
         with self.assertRaisesRegex(ValueError, "does not support group offloading"):
-            self.model.enable_group_offload(onload_device=torch.device("cuda"))
+            self.model.enable_group_offload(onload_device=torch.device(torch_device))
 
     def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
         pipe = DummyPipeline(self.model)
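The rewritten stream test no longer hard-codes torch.cuda: getattr(torch, torch_device, torch.cuda) resolves to torch.xpu when torch_device is "xpu" and falls back to torch.cuda otherwise, and that module's is_available is monkeypatched to simulate a machine without an accelerator. A small standalone sketch of the same patch-and-restore pattern (torch_device is assumed here to be a plain device-type string):

    import torch

    torch_device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"

    # Pick the backend module matching the active device; fall back to torch.cuda.
    torch_accelerator_module = getattr(torch, torch_device, torch.cuda)

    # Pretend no accelerator is present, then restore the original check.
    original_is_available = torch_accelerator_module.is_available
    torch_accelerator_module.is_available = lambda: False
    try:
        assert not torch_accelerator_module.is_available()
    finally:
        torch_accelerator_module.is_available = original_is_available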
@@ -249,7 +256,7 @@ def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module
         pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
 
     def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
-        if torch.device(torch_device).type != "cuda":
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
         model = DummyModelWithMultipleBlocks(
             in_features=self.in_features,

tests/pipelines/test_pipeline_utils.py

Lines changed: 5 additions & 5 deletions
@@ -19,7 +19,7 @@
     UNet2DConditionModel,
 )
 from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
-from diffusers.utils.testing_utils import require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import require_torch_accelerator, torch_device
 
 
 class IsSafetensorsCompatibleTests(unittest.TestCase):
@@ -850,9 +850,9 @@ def test_video_to_video(self):
         self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
 
 
-@require_torch_gpu
+@require_torch_accelerator
 class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase):
-    expected_pipe_device = torch.device("cuda:0")
+    expected_pipe_device = torch.device(f"{torch_device}:0")
     expected_pipe_dtype = torch.float64
 
     def get_dummy_components_image_generation(self):
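expected_pipe_device keeps an explicit ":0" index rather than comparing against a bare "cuda"/"xpu" string, which matters because torch.device equality compares both the device type and the index. A quick illustration:

    import torch

    # An unindexed device (index None) is not equal to the same device with index 0.
    assert torch.device("cuda") != torch.device("cuda:0")
    assert torch.device("cuda:0") == torch.device("cuda:0")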
@@ -921,8 +921,8 @@ def test_deterministic_device(self):
         pipe.to(device=torch_device, dtype=torch.float32)
 
         pipe.unet.to(device="cpu")
-        pipe.vae.to(device="cuda")
-        pipe.text_encoder.to(device="cuda:0")
+        pipe.vae.to(device=torch_device)
+        pipe.text_encoder.to(device=f"{torch_device}:0")
 
         pipe_device = pipe.device