@@ -344,9 +344,8 @@ def test_offload_to_weights_map():
344
344
345
345
@requires_gpu
346
346
@requires_accelerate ()
347
- def test_register_offload_module ():
348
- execution_device = torch .device ("cuda" )
349
-
347
+ @pytest .mark .parametrize ("exec_device" , [torch .device ("cpu" ), torch .device ("cuda" )])
348
+ def test_register_offload_module (exec_device ):
350
349
# no offloading
351
350
model = ExampleModel ()
352
351
child = torch .nn .Linear (2 , 3 )
@@ -358,37 +357,36 @@ def test_register_offload_module():
358
357
# with offloading
359
358
model = ExampleModel ()
360
359
child = torch .nn .Linear (2 , 3 )
361
- force_cpu_offload (model , execution_device )
360
+ force_cpu_offload (model , exec_device )
362
361
register_offload_module (model , "child" , child )
363
362
register_offload_module (model .linear , "child" , child )
364
363
assert child in model .children ()
365
364
assert child in model .linear .children ()
366
365
367
366
# can run modules
368
367
model (torch .empty (1 ))
369
- child (torch .empty (2 , device = execution_device ))
368
+ child (torch .empty (2 , device = exec_device ))
370
369
371
370
372
371
@requires_gpu
373
372
@requires_accelerate ()
374
- def test_force_cpu_offload ():
375
- execution_device = torch .device ("cuda" )
376
-
373
+ @pytest .mark .parametrize ("exec_device" , [torch .device ("cpu" ), torch .device ("cuda" )])
374
+ def test_force_cpu_offload (exec_device ):
377
375
# single module
378
376
module = torch .nn .Linear (1 , 2 )
379
- module = force_cpu_offload (module , execution_device )
377
+ module = force_cpu_offload (module , exec_device )
380
378
assert has_offloaded_params (module )
381
379
assert module ._hf_hook .offload
382
380
assert module .weight .device == torch .device ("meta" )
383
381
assert "weight" in module ._hf_hook .weights_map
384
382
assert module ._hf_hook .tied_params_map is not None
385
383
386
384
# can run
387
- module (torch .empty (1 , device = execution_device ))
385
+ module (torch .empty (1 , device = exec_device ))
388
386
389
387
# model
390
388
model = ExampleModel ()
391
- model = force_cpu_offload (model , execution_device )
389
+ model = force_cpu_offload (model , exec_device )
392
390
assert not has_offloaded_params (model )
393
391
394
392
assert has_offloaded_params (model .linear )
@@ -398,4 +396,4 @@ def test_force_cpu_offload():
398
396
assert model .linear ._hf_hook .tied_params_map is not None
399
397
400
398
# can run
401
- model (torch .empty (1 , device = execution_device ))
399
+ model (torch .empty (1 , device = exec_device ))
0 commit comments