 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.utils import get_logger
 from diffusers.utils.import_utils import compare_versions
-from diffusers.utils.testing_utils import require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_peak_memory_stats,
+    require_torch_accelerator,
+    torch_device,
+)


 class DummyBlock(torch.nn.Module):
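The `backend_*` helpers imported above replace direct `torch.cuda` calls so the same tests can run on CUDA or XPU. As a rough sketch of the idea only (the real implementations live in `diffusers.utils.testing_utils` and may differ), such a helper dispatches on the device type:

```python
import torch


def empty_cache_sketch(device: str) -> None:
    # Illustrative dispatcher only, not the actual backend_empty_cache code.
    device_type = torch.device(device).type
    if device_type == "cuda":
        torch.cuda.empty_cache()
    elif device_type == "xpu":
        torch.xpu.empty_cache()  # assumes a PyTorch build with XPU support
    # "cpu" has no allocator cache to clear, so fall through as a no-op
```

`backend_reset_peak_memory_stats` and `backend_max_memory_allocated` follow the same pattern, forwarding to the matching accelerator module for the given `torch_device`.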
@@ -107,7 +113,7 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return x


-@require_torch_gpu
+@require_torch_accelerator
 class GroupOffloadTests(unittest.TestCase):
     in_features = 64
     hidden_features = 256
@@ -125,8 +131,8 @@ def tearDown(self):
         del self.model
         del self.input
         gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_peak_memory_stats(torch_device)

     def get_model(self):
         torch.manual_seed(0)
@@ -141,8 +147,8 @@ def test_offloading_forward_pass(self):
         @torch.no_grad()
         def run_forward(model):
             gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
+            backend_empty_cache(torch_device)
+            backend_reset_peak_memory_stats(torch_device)
             self.assertTrue(
                 all(
                     module._diffusers_hook.get_hook("group_offloading") is not None
@@ -152,7 +158,7 @@ def run_forward(model):
             )
             model.eval()
             output = model(self.input)[0].cpu()
-            max_memory_allocated = torch.cuda.max_memory_allocated()
+            max_memory_allocated = backend_max_memory_allocated(torch_device)
             return output, max_memory_allocated

         self.model.to(torch_device)
@@ -187,10 +193,10 @@ def run_forward(model):
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5))

         # Memory assertions - offloading should reduce memory usage
-        self.assertTrue(mem4 <= mem5 < mem2 < mem3 < mem1 < mem_baseline)
+        self.assertTrue(mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline)

-    def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
-        if torch.device(torch_device).type != "cuda":
+    def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self):
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
         self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
         logger = get_logger("diffusers.models.modeling_utils")
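The warning tests now gate on the device *type* rather than requiring CUDA specifically. `torch.device(...).type` drops any device index, so both `"cuda:0"` and plain `"cuda"` pass the check. A minimal illustration, assuming a recent PyTorch that recognizes the `"xpu"` device type:

```python
import torch

# .type normalizes device strings, which is what the gate above compares against
assert torch.device("cuda:0").type == "cuda"
assert torch.device("xpu").type == "xpu"
assert torch.device("cpu").type not in ["cuda", "xpu"]
```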
@@ -199,8 +205,8 @@ def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
             self.model.to(torch_device)
         self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])

-    def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
-        if torch.device(torch_device).type != "cuda":
+    def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self):
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
         pipe = DummyPipeline(self.model)
         self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
@@ -210,19 +216,20 @@ def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
             pipe.to(torch_device)
         self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])

-    def test_error_raised_if_streams_used_and_no_cuda_device(self):
-        original_is_available = torch.cuda.is_available
-        torch.cuda.is_available = lambda: False
+    def test_error_raised_if_streams_used_and_no_accelerator_device(self):
+        torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
+        original_is_available = torch_accelerator_module.is_available
+        torch_accelerator_module.is_available = lambda: False
         with self.assertRaises(ValueError):
             self.model.enable_group_offload(
-                onload_device=torch.device("cuda"), offload_type="leaf_level", use_stream=True
+                onload_device=torch.device(torch_device), offload_type="leaf_level", use_stream=True
             )
-        torch.cuda.is_available = original_is_available
+        torch_accelerator_module.is_available = original_is_available

     def test_error_raised_if_supports_group_offloading_false(self):
         self.model._supports_group_offloading = False
         with self.assertRaisesRegex(ValueError, "does not support group offloading"):
-            self.model.enable_group_offload(onload_device=torch.device("cuda"))
+            self.model.enable_group_offload(onload_device=torch.device(torch_device))

     def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
         pipe = DummyPipeline(self.model)
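The stream test above fakes an unavailable accelerator by monkeypatching `is_available` on the module returned by `getattr(torch, torch_device, torch.cuda)`. The same lookup-and-patch pattern, sketched in isolation (the helper name below is illustrative, not part of the diff):

```python
import torch


def get_torch_accelerator_module(device: str = "cuda"):
    # Resolve "cuda" -> torch.cuda, "xpu" -> torch.xpu, etc.; fall back to
    # torch.cuda when the device string does not name a torch submodule.
    return getattr(torch, torch.device(device).type, torch.cuda)


# Example: temporarily pretend the accelerator is unavailable, then restore.
accelerator = get_torch_accelerator_module("cuda")
original_is_available = accelerator.is_available
accelerator.is_available = lambda: False
try:
    assert not accelerator.is_available()
finally:
    accelerator.is_available = original_is_available
```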
@@ -249,7 +256,7 @@ def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module
         pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)

     def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
-        if torch.device(torch_device).type != "cuda":
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
         model = DummyModelWithMultipleBlocks(
             in_features=self.in_features,