import contextlib
import warnings
from functools import wraps
-from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Union

import torch


try:
+    from accelerate import dispatch_model
    from accelerate.hooks import (
        AlignDevicesHook,
        add_hook_to_module,
+        attach_align_device_hook,
+        named_module_tensors,
        remove_hook_from_module,
    )
    from accelerate.utils import (
    OffloadedWeightsLoader = None
    PrefixedDataset = None
    set_module_tensor_to_device = None
+    named_module_tensors = None
+    dispatch_model = None
+    attach_align_device_hook = None


__all__ = [
    "disable_offload",
    "align_modules",
    "align_module_device",
+    "register_offload_module",
+    "delete_offload_module",
+    "force_cpu_offload",
]


def check_accelerate(fallback: Any):
    def decorator(func: Callable[[Any], Any]):
        if not _has_accelerate:

+            if fallback == "error":
+                raise ValueError(
+                    "Please install `accelerate` in order to use this function"
+                )
+
            @wraps(func)
            def fallback_fn(*args, **kwargs):
                return fallback
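
With this change, `check_accelerate(fallback="error")` fails loudly when `accelerate` is missing, while other fallbacks still degrade silently to the sentinel value. A minimal self-contained sketch of the pattern, assuming `accelerate` is absent (the `return fallback_fn` / `return func` plumbing lies outside the hunk and is reconstructed here for illustration):

```python
import contextlib
from functools import wraps
from typing import Any, Callable

_has_accelerate = False  # assumption for this sketch: accelerate is not installed


def check_accelerate(fallback: Any):
    def decorator(func: Callable):
        if not _has_accelerate:
            if fallback == "error":
                # hard requirement: raise as soon as the function is defined
                raise ValueError(
                    "Please install `accelerate` in order to use this function"
                )

            @wraps(func)
            def fallback_fn(*args, **kwargs):
                # soft fallback: skip the real body and return the sentinel
                return fallback

            return fallback_fn
        return func

    return decorator


@check_accelerate(fallback=contextlib.nullcontext())
def disable_offload(module):
    ...  # real implementation requires accelerate


# without accelerate, the call returns a no-op context manager
with disable_offload(object()):
    pass
```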
@@ -346,6 +360,7 @@ def delete_from_weights_map(
        )


+@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def disable_offload(module: torch.nn.Module):
    """
@@ -362,6 +377,7 @@ def disable_offload(module: torch.nn.Module):
        yield


+@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def align_modules(
    modules: Union[torch.nn.Module, Iterable[torch.nn.Module]],
@@ -383,6 +399,123 @@ def align_modules(
        yield


+def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
+    """
+    Register a submodule with offloading if the parent module is offloaded
+
+    :param base: module to attach submodule to
+    :param name: name of submodule
+    :param module: submodule to attach
+    """
+
+    if has_offloaded_params(base):
+        hook: AlignDevicesHook = base._hf_hook
+        assert hook.offload
+        assert hook.weights_map is not None
+        assert hook.tied_params_map is not None
+
+        # offloading kwargs for submodule
+        place_submodules = False
+        offload_buffers = True
+
+        # copy device offloading arguments from parent
+        current_device = next(base.parameters()).device  # assume base has parameters
+        offload_device = get_offloaded_device(base)
+
+        # offload parameters to weights map
+        for param_name, param in named_module_tensors(
+            module, include_buffers=offload_buffers, recurse=place_submodules
+        ):
+            offloaded = param.to(offload_device)
+            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+            offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)
+
+            # if the parent places submodules, offload here
+            if hook.place_submodules:
+                set_module_tensor_to_device(module, param_name, current_device)
+
+        # if the parent does not place submodules, then add a hook
+        # parameters are offloaded by `add_hook_to_module`
+        if not hook.place_submodules:
+            weights_map = PrefixedDataset(
+                hook.weights_map.dataset, prefix=f"{hook.weights_map.prefix}{name}."
+            )
+
+            submodule_hook = AlignDevicesHook(
+                execution_device=hook.execution_device,
+                offload=hook.offload,
+                io_same_device=False,
+                weights_map=weights_map,
+                offload_buffers=offload_buffers,
+                place_submodules=place_submodules,
+                skip_keys=None,
+                tied_params_map=hook.tied_params_map,
+            )
+            add_hook_to_module(module, submodule_hook)
+
+    base.register_module(name, module)
+
+    # (1): Since we cannot know which pointers are shared when we add parameters in an
+    # online way, assume that all pointers are shared. This comes at no runtime cost
+
+
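
A brief usage sketch: on a parent without offloading this reduces to plain `register_module`; when the parent carries an `AlignDevicesHook`, the new submodule's tensors are additionally pushed into the parent's weights map. The module names and shapes below are illustrative only:

```python
import torch

# stand-in for a (possibly offloaded) parent module
backbone = torch.nn.Sequential(torch.nn.Linear(8, 8))

register_offload_module(backbone, "adapter", torch.nn.Linear(8, 8, bias=False))
assert hasattr(backbone, "adapter")
```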
+def delete_offload_module(base: torch.nn.Module, name: str):
+    """
+    Delete a submodule from a model which may contain offloading
+    :param base: parent module to delete submodule from
+    :param name: name of submodule on parent
+    """
+    module: torch.nn.Module = getattr(base, name)
+
+    for param_name, _ in list(module.named_parameters()):
+        delete_offload_parameter(module, param_name)
+
+    delattr(base, name)
+
+
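
Continuing the sketch above, removing the submodule deletes each of its parameters (and any offloaded copy tracked for them) before detaching it from the parent:

```python
delete_offload_module(backbone, "adapter")
assert not hasattr(backbone, "adapter")
```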
+@check_accelerate(fallback="error")
+def force_cpu_offload(
+    module: torch.nn.Module, execution_device: torch.device
+) -> torch.nn.Module:
+    """
+    Force cpu offloading of a module, primarily used for testing
+
+    :param module: module containing parameters to offload
+    :param execution_device: execution device for submodules
+    :return: module with hooks to perform cpu offloading
+    """
+    # edge case: there is a bug in `dispatch_model` which causes
+    # the function to only work if the model contains submodules
+    if next(module.children(), None) is None:
+        attach_align_device_hook(
+            module,
+            execution_device=execution_device,
+            offload=True,
+            weights_map=module.state_dict(),
+            tied_params_map={},
+        )
+        return module
+
+    device_map = {}
+
+    def collect_device_map(name: List[str], module: torch.nn.Module):
+        if next(module.parameters(recurse=False), None) is not None:
+            device_map[".".join(name)] = "cpu"
+            return
+
+        else:
+            for submodule_name, submodule in module.named_children():
+                name.append(submodule_name)
+                collect_device_map(name, submodule)
+                name.pop()
+
+    collect_device_map([], module)
+
+    return dispatch_model(
+        module, device_map, main_device=execution_device, force_hooks=True
+    )
+
+
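
A short test-style sketch of how this might be used, assuming `accelerate` and a CUDA device are available; for the two-layer model below, `collect_device_map` produces roughly `{"0": "cpu", "1": "cpu"}` before dispatch:

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))
model = force_cpu_offload(model, execution_device=torch.device("cuda:0"))

# weights are kept on cpu and onloaded to cuda:0 only around each submodule's forward
with torch.no_grad():
    out = model(torch.zeros(1, 16, device="cuda:0"))
```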
""" Upstreamed Functions """