19
19
delete_offload_module ,
20
20
delete_offload_parameter ,
21
21
disable_hf_hook ,
22
+ disable_offloading ,
22
23
get_execution_device ,
23
24
has_offloaded_params ,
24
25
offloaded_dispatch ,
@@ -397,29 +398,37 @@ def test_delete_offload_module(exec_device):
397
398
398
399
@requires_gpu
399
400
@requires_accelerate ()
400
- @pytest .mark .parametrize ("exec_device" , [torch .device ("cpu" ), torch .device ("cuda" )])
401
- def test_offloaded_dispatch (exec_device ):
401
+ @pytest .mark .parametrize (
402
+ "exec_device,offload_device" ,
403
+ [
404
+ (torch .device ("cpu" ), torch .device ("cpu" )),
405
+ (torch .device ("cpu" ), torch .device ("cuda:0" )),
406
+ (torch .device ("cuda:0" ), torch .device ("cpu" )),
407
+ (torch .device ("cuda:0" ), torch .device ("cuda:0" )),
408
+ ],
409
+ )
410
+ def test_offloaded_dispatch (exec_device , offload_device ):
402
411
# single module
403
- module = torch .nn .Linear (1 , 2 )
404
- module = offloaded_dispatch (module , exec_device )
412
+ module = torch .nn .Linear (1 , 2 , device = offload_device )
413
+ module = offloaded_dispatch (module , exec_device , offload_device )
405
414
assert has_offloaded_params (module )
406
415
assert module ._hf_hook .offload
407
416
assert module .weight .device == torch .device ("meta" )
408
- assert "weight" in module ._hf_hook .weights_map
417
+ assert module ._hf_hook .weights_map [ "weight" ]. device == offload_device
409
418
assert module ._hf_hook .tied_params_map is not None
410
419
411
420
# can run
412
421
module (torch .empty (1 , device = exec_device ))
413
422
414
423
# model
415
424
model = ExampleModel ()
416
- model = offloaded_dispatch (model , exec_device )
425
+ model = offloaded_dispatch (model , exec_device , offload_device )
417
426
assert not has_offloaded_params (model )
418
427
419
428
assert has_offloaded_params (model .linear )
420
429
assert model .linear ._hf_hook .offload
421
430
assert model .linear .weight .device == torch .device ("meta" )
422
- assert "weight" in model .linear ._hf_hook .weights_map
431
+ assert model .linear ._hf_hook .weights_map [ "weight" ]. device == offload_device
423
432
assert model .linear ._hf_hook .tied_params_map is not None
424
433
425
434
# can run
@@ -429,4 +438,43 @@ def test_offloaded_dispatch(exec_device):
429
438
parameter = torch .nn .Parameter (torch .tensor (1.0 ))
430
439
register_offload_parameter (module , "new_param" , parameter )
431
440
assert module .new_param .device == torch .device ("meta" )
432
- assert module ._hf_hook .weights_map ["new_param" ].device == torch .device ("cpu" )
441
+ assert module ._hf_hook .weights_map ["new_param" ].device == offload_device
442
+
443
+
444
@requires_gpu
@requires_accelerate()
@pytest.mark.parametrize(
    "exec_device,offload_device",
    [
        (torch.device("cpu"), torch.device("cpu")),
        (torch.device("cpu"), torch.device("cuda:0")),
        (torch.device("cuda:0"), torch.device("cpu")),
        (torch.device("cuda:0"), torch.device("cuda:0")),
    ],
)
def test_disable_offloading(exec_device, offload_device):
    """Check `disable_offloading` for plain and offloaded modules.

    A module that was never offloaded must behave the same inside the
    context. An offloaded module's weight must be found on `exec_device`
    after each forward pass inside the context, and be back on the meta
    device (with the stored copy on `offload_device`) after the context exits.
    """
    meta_device = torch.device("meta")
    linear = torch.nn.Linear(1, 2, device=exec_device)

    def check_forward_stays_on_exec_device():
        # Run one forward pass, then verify both the weight and the result
        # are located on the execution device.
        result = linear(torch.empty(1, device=exec_device))
        assert linear.weight.device == exec_device
        assert result.device == exec_device

    def check_offloaded_state():
        # The live parameter is a meta placeholder; the stored copy sits in
        # the hook's weights map on the offload device.
        assert linear.weight.device == meta_device
        assert linear._hf_hook.weights_map["weight"].device == offload_device

    # non-offloaded modules are unaffected
    with disable_offloading():
        check_forward_stays_on_exec_device()

    # offloaded modules stay on device until context exit
    offloaded_dispatch(linear, exec_device, offload_device)
    check_offloaded_state()

    with disable_offloading():
        # nothing has been onloaded yet: entering the context alone does not
        # move the weight off the meta device
        assert linear.weight.device == meta_device

        # first forward onloads; the weight is still there for the second
        check_forward_stays_on_exec_device()
        check_forward_stays_on_exec_device()

    # leaving the context restores the offloaded state
    check_offloaded_state()
0 commit comments