from compressed_tensors.utils import (
    align_module_device,
    align_modules,
+   delete_offload_module,
    delete_offload_parameter,
    disable_hf_hook,
    force_cpu_offload,
@@ -344,9 +345,8 @@ def test_offload_to_weights_map():

@requires_gpu
@requires_accelerate()
-def test_register_offload_module():
-    execution_device = torch.device("cuda")
-
+@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
+def test_register_offload_module(exec_device):
    # no offloading
    model = ExampleModel()
    child = torch.nn.Linear(2, 3)
@@ -358,37 +358,62 @@ def test_register_offload_module():
    # with offloading
    model = ExampleModel()
    child = torch.nn.Linear(2, 3)
-    force_cpu_offload(model, execution_device)
+    force_cpu_offload(model, exec_device)
    register_offload_module(model, "child", child)
    register_offload_module(model.linear, "child", child)
    assert child in model.children()
    assert child in model.linear.children()

    # can run modules
    model(torch.empty(1))
-    child(torch.empty(2, device=execution_device))
+    child(torch.empty(2, device=exec_device))


@requires_gpu
@requires_accelerate()
-def test_force_cpu_offload():
-    execution_device = torch.device("cuda")
+@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
+def test_delete_offload_module(exec_device):
+    # no offloading
+    model = ExampleModel()
+    child = torch.nn.Linear(2, 3)
+    register_offload_module(model, "child", child)
+    register_offload_module(model.linear, "child", child)
+    delete_offload_module(model, "child")
+    delete_offload_module(model.linear, "child")
+    assert child not in model.children()
+    assert child not in model.linear.children()

+    # with offloading
+    model = ExampleModel()
+    child = torch.nn.Linear(2, 3)
+    force_cpu_offload(model, exec_device)
+    register_offload_module(model, "child", child)
+    register_offload_module(model.linear, "child", child)
+    delete_offload_module(model, "child")
+    delete_offload_module(model.linear, "child")
+    assert child not in model.children()
+    assert child not in model.linear.children()
+
+
+@requires_gpu
+@requires_accelerate()
+@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
+def test_force_cpu_offload(exec_device):
    # single module
    module = torch.nn.Linear(1, 2)
-    module = force_cpu_offload(module, execution_device)
+    module = force_cpu_offload(module, exec_device)
    assert has_offloaded_params(module)
    assert module._hf_hook.offload
    assert module.weight.device == torch.device("meta")
    assert "weight" in module._hf_hook.weights_map
    assert module._hf_hook.tied_params_map is not None

    # can run
-    module(torch.empty(1, device=execution_device))
+    module(torch.empty(1, device=exec_device))

    # model
    model = ExampleModel()
-    model = force_cpu_offload(model, execution_device)
+    model = force_cpu_offload(model, exec_device)
    assert not has_offloaded_params(model)

    assert has_offloaded_params(model.linear)
@@ -398,4 +423,4 @@ def test_force_cpu_offload():
    assert model.linear._hf_hook.tied_params_map is not None

    # can run
-    model(torch.empty(1, device=execution_device))
+    model(torch.empty(1, device=exec_device))
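
For quick reference, a minimal standalone sketch of the register/delete round-trip these tests exercise (a hypothetical usage example, not part of the diff). The helper signatures follow their uses above; the `ExampleModel` body is an assumption, since the diff only shows that the fixture exposes a `linear` submodule and accepts a 1-element input.

import torch
from compressed_tensors.utils import (
    delete_offload_module,
    force_cpu_offload,
    register_offload_module,
)


class ExampleModel(torch.nn.Module):
    # assumed stand-in for the test fixture; the real definition is not in this diff
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 2)

    def forward(self, x):
        return self.linear(x)


# offload the parent's parameters; a CPU execution device mirrors the new
# parametrization above, so this runs without a GPU
model = force_cpu_offload(ExampleModel(), torch.device("cpu"))
child = torch.nn.Linear(2, 3)

# attach the child so it is managed alongside the parent's offloading hooks
register_offload_module(model, "child", child)
assert child in model.children()

# detach it again; the child should no longer appear among the children
delete_offload_module(model, "child")
assert child not in model.children()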