@@ -26,6 +26,7 @@
 from diffusers.utils import is_accelerate_version, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    backend_empty_cache,
     is_bitsandbytes_available,
     is_torch_available,
     is_transformers_available,
@@ -35,7 +36,7 @@
     require_bitsandbytes_version_greater,
     require_peft_backend,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_transformers_version_greater,
     slow,
     torch_device,
@@ -66,7 +67,7 @@ def get_some_linear_layer(model):
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @slow
 class Base4bitTests(unittest.TestCase):
     # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected)
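Taken together, these hunks replace CUDA-only test plumbing (`require_torch_gpu`, `torch.cuda.empty_cache()`, hard-coded `"cuda"` strings) with the device-agnostic helpers exposed by `diffusers.utils.testing_utils`. A minimal sketch of the setup/teardown shape the changed classes converge on, assuming those helpers; the class name below is illustrative and not part of the actual test file:

import gc
import unittest

from diffusers.utils.testing_utils import backend_empty_cache, require_torch_accelerator, torch_device


@require_torch_accelerator
class DeviceAgnosticBnbTests(unittest.TestCase):
    # Illustrative skeleton: free allocator caches on whichever backend
    # `torch_device` resolves to (CUDA, XPU, ...) instead of assuming CUDA.
    def setUp(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        gc.collect()
        backend_empty_cache(torch_device)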
@@ -84,13 +85,16 @@ class Base4bitTests(unittest.TestCase):

     def get_dummy_inputs(self):
         prompt_embeds = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
+            torch_device,
         )
         pooled_prompt_embeds = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt",
+            torch_device,
         )
         latent_model_input = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt",
+            torch_device,
         )

         input_dict_for_transformer = {
@@ -106,7 +110,7 @@ def get_dummy_inputs(self):
 class BnB4BitBasicTests(Base4bitTests):
     def setUp(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

         # Models
         self.model_fp16 = SD3Transformer2DModel.from_pretrained(
@@ -128,7 +132,7 @@ def tearDown(self):
         del self.model_4bit

         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_quantization_num_parameters(self):
         r"""
@@ -224,7 +228,7 @@ def test_keep_modules_in_fp32(self):
                 self.assertTrue(module.weight.dtype == torch.uint8)

         # test if inference works.
-        with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16):
+        with torch.no_grad() and torch.amp.autocast(torch_device, dtype=torch.float16):
             input_dict_for_transformer = self.get_dummy_inputs()
             model_inputs = {
                 k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
@@ -266,9 +270,9 @@ def test_device_assignment(self):
         self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)

         # Move back to CUDA device
-        for device in [0, "cuda", "cuda:0", "call()"]:
+        for device in [0, f"{torch_device}", f"{torch_device}:0", "call()"]:
             if device == "call()":
-                self.model_4bit.cuda(0)
+                self.model_4bit.to(f"{torch_device}:0")
             else:
                 self.model_4bit.to(device)
             self.assertEqual(self.model_4bit.device, torch.device(0))
@@ -286,7 +290,7 @@ def test_device_and_dtype_assignment(self):

         with self.assertRaises(ValueError):
             # Tries with a `device` and `dtype`
-            self.model_4bit.to(device="cuda:0", dtype=torch.float16)
+            self.model_4bit.to(device=f"{torch_device}:0", dtype=torch.float16)

         with self.assertRaises(ValueError):
             # Tries with a cast
@@ -297,7 +301,7 @@ def test_device_and_dtype_assignment(self):
             self.model_4bit.half()

         # This should work
-        self.model_4bit.to("cuda")
+        self.model_4bit.to(torch_device)

         # Test if we did not break anything
         self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
@@ -321,7 +325,7 @@ def test_device_and_dtype_assignment(self):
             _ = self.model_fp16.float()

         # Check that this does not throw an error
-        _ = self.model_fp16.cuda()
+        _ = self.model_fp16.to(torch_device)

     def test_bnb_4bit_wrong_config(self):
         r"""
@@ -398,7 +402,7 @@ def test_training(self):
         model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})

         # Step 4: Check if the gradient is not None
-        with torch.amp.autocast("cuda", dtype=torch.float16):
+        with torch.amp.autocast(torch_device, dtype=torch.float16):
             out = self.model_4bit(**model_inputs)[0]
             out.norm().backward()

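The autocast changes follow the same logic: `torch.amp.autocast` takes a device type string as its first argument, so passing `torch_device` (e.g. `"cuda"` or `"xpu"`) instead of a literal `"cuda"` keeps mixed-precision inference and the gradient check device-agnostic. A small self-contained sketch of the pattern, assuming an accelerator is available as `torch_device`:

import torch
from diffusers.utils.testing_utils import torch_device

x = torch.randn(8, 8, device=torch_device)
w = torch.randn(8, 8, device=torch_device, requires_grad=True)

# The first positional argument of autocast is the device *type*,
# which is exactly what `torch_device` holds in these tests.
with torch.amp.autocast(torch_device, dtype=torch.float16):
    out = (x @ w).norm()

out.backward()
assert w.grad is not None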
@@ -412,7 +416,7 @@ def test_training(self):
 class SlowBnb4BitTests(Base4bitTests):
     def setUp(self) -> None:
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

         nf4_config = BitsAndBytesConfig(
             load_in_4bit=True,
@@ -431,7 +435,7 @@ def tearDown(self):
         del self.pipeline_4bit

         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_quality(self):
         output = self.pipeline_4bit(
@@ -501,7 +505,7 @@ def test_moving_to_cpu_throws_warning(self):
         reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
         strict=True,
     )
-    def test_pipeline_cuda_placement_works_with_nf4(self):
+    def test_pipeline_device_placement_works_with_nf4(self):
         transformer_nf4_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_quant_type="nf4",
@@ -532,7 +536,7 @@ def test_pipeline_cuda_placement_works_with_nf4(self):
             transformer=transformer_4bit,
             text_encoder_3=text_encoder_3_4bit,
             torch_dtype=torch.float16,
-        ).to("cuda")
+        ).to(torch_device)

         # Check if inference works.
         _ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
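The renamed placement test builds an NF4-quantized pipeline and moves it with `.to(torch_device)` rather than `.to("cuda")`. Roughly, the flow it exercises looks like the sketch below; the checkpoint id is illustrative and the exact models and fixtures used by the test differ:

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel, StableDiffusion3Pipeline
from diffusers.utils.testing_utils import torch_device

# NF4 4-bit quantization config, as used throughout these tests.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Illustrative checkpoint id; the test loads its own fixture checkpoints.
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
transformer_4bit = SD3Transformer2DModel.from_pretrained(
    repo_id, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.float16
)

pipeline_4bit = StableDiffusion3Pipeline.from_pretrained(
    repo_id, transformer=transformer_4bit, torch_dtype=torch.float16
).to(torch_device)  # device-agnostic placement instead of .to("cuda")

_ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)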
@@ -696,7 +700,7 @@ def test_lora_loading(self):
 class BaseBnb4BitSerializationTests(Base4bitTests):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
         r"""