@@ -28,15 +28,17 @@
 import contextlib
 import warnings
 from functools import wraps
-from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Union

 import torch


 try:
+    from accelerate import dispatch_model
     from accelerate.hooks import (
         AlignDevicesHook,
         add_hook_to_module,
+        named_module_tensors,
         remove_hook_from_module,
     )
     from accelerate.utils import (
@@ -54,6 +56,8 @@
     OffloadedWeightsLoader = None
     PrefixedDataset = None
     set_module_tensor_to_device = None
+    named_module_tensors = None
+    dispatch_model = None


 __all__ = [
@@ -70,13 +74,20 @@
     "disable_offload",
     "align_modules",
     "align_module_device",
+    "register_offload_module",
+    "force_cpu_offload",
 ]


 def check_accelerate(fallback: Any):
     def decorator(func: Callable[[Any], Any]):
         if not _has_accelerate:

+            if fallback == "error":
+                raise ValueError(
+                    "Please install `accelerate` in order to use this function"
+                )
+
             @wraps(func)
             def fallback_fn(*args, **kwargs):
                 return fallback
@@ -346,6 +357,7 @@ def delete_from_weights_map(
         )


+@check_accelerate(fallback=contextlib.nullcontext())
 @contextlib.contextmanager
 def disable_offload(module: torch.nn.Module):
     """
@@ -362,6 +374,7 @@ def disable_offload(module: torch.nn.Module):
         yield


+@check_accelerate(fallback=contextlib.nullcontext())
 @contextlib.contextmanager
 def align_modules(
     modules: Union[torch.nn.Module, Iterable[torch.nn.Module]],
@@ -383,6 +396,89 @@ def align_modules(
         yield


+@check_accelerate(fallback=None)
+def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
+    """
+    Register a submodule with offloading if the parent module is offloaded
+
+    :param base: module to attach submodule to
+    :param name: name of submodule
+    :param module: submodule to attach
+    """
+
+    if has_offloaded_params(base):
+        hook: AlignDevicesHook = base._hf_hook
+        assert hook.offload
+        assert hook.weights_map is not None
+        assert hook.tied_params_map is not None
+
+        # offloading kwargs for submodule
+        place_submodules = False
+        offload_buffers = True
+
+        # copy device offloading arguments from parent
+        current_device = next(base.parameters()).device  # assume base has parameters
+        offload_device = get_offloaded_device(base)
+
+        # offload parameters to weights map
+        for param_name, param in named_module_tensors(
+            module, include_buffers=offload_buffers, recurse=place_submodules
+        ):
+            offloaded = param.to(offload_device)
+            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+            offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)
+
+            # if the parent places submodules, offload here
+            if hook.place_submodules:
+                set_module_tensor_to_device(module, param_name, current_device)
+
+        # if the parent does not place submodules, then add a hook
+        # parameters are offloaded by `add_hook_to_module`
+        if not hook.place_submodules:
+            weights_map = PrefixedDataset(
+                hook.weights_map.dataset, prefix=f"{hook.weights_map.prefix}{name}."
+            )
+
+            submodule_hook = AlignDevicesHook(
+                execution_device=hook.execution_device,
+                offload=hook.offload,
+                io_same_device=False,
+                weights_map=weights_map,
+                offload_buffers=offload_buffers,
+                place_submodules=place_submodules,
+                skip_keys=None,
+                tied_params_map=hook.tied_params_map,
+            )
+            add_hook_to_module(module, submodule_hook)
+
+    base.register_module(name, module)
+
+    # (1): Since we cannot know which pointers are shared when we add parameters in an
+    # online way, assume that all pointers are shared. This comes at no runtime cost
+
+
+@check_accelerate(fallback="error")
+def force_cpu_offload(module: torch.nn.Module, execution_device: torch.device):
+    device_map = {}
+
+    def dfs(name: List[str], module: torch.nn.Module):
+        if next(module.parameters(recurse=False), None) is not None:
+            device_map[".".join(name)] = "cpu"
+            return
+
+        else:
+            for submodule_name, submodule in module.named_children():
+                name.append(submodule_name)
+                dfs(name, submodule)
+                name.pop()
+
+    dfs([], module)
+
+    return dispatch_model(
+        module, device_map, main_device=execution_device, force_hooks=True
+    )
+
+
 """ Upstreamed Functions """


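For orientation, here is a minimal usage sketch of the two helpers added above. It is not part of the commit: `offload_utils` is a stand-in import path for the module this diff touches (the file path is not shown here), and it assumes `accelerate` is installed and a CUDA device is available.

# Sketch only, not from the commit. "offload_utils" is a hypothetical stand-in
# for the module modified by this diff.
import torch
from offload_utils import force_cpu_offload, register_offload_module

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))

# Keep weights offloaded on CPU while executing on the GPU; dispatch_model
# attaches AlignDevicesHook instances (force_hooks=True).
model = force_cpu_offload(model, execution_device=torch.device("cuda:0"))

# Attach a new submodule to an offloaded layer; its parameters are written into
# the parent's offloaded weights map instead of staying on the execution device.
register_offload_module(model[0], "adapter", torch.nn.Linear(16, 16, bias=False))

y = model(torch.randn(1, 16))  # hooks onload each layer for the forward pass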
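The `dfs` helper inside `force_cpu_offload` maps every submodule that directly owns parameters to `"cpu"`, keyed by its dotted module name, and recurses through parameterless containers. A small illustration of the `device_map` it would build, derived from the code above rather than taken from the commit:

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),                       # owns parameters directly -> "0"
    torch.nn.Sequential(torch.nn.Linear(8, 8)),  # parameterless container, recursed -> "1.0"
)

# device_map built by dfs([], model) before it is passed to dispatch_model:
# {"0": "cpu", "1.0": "cpu"}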