"""
Utilities associated with offloading functionality provided by `accelerate`.

-| --------------------------------------------------------------------------------------------------- |  # noqa: E501
-| Operation | Without offloading support              | With offloading support                           |  # noqa: E501
-| --------- | ---------------------------------------- | ------------------------------------------------- |  # noqa: E501
-| Add       | module.register_parameter(name, param)   | register_offload_parameter(module, name, param)   |  # noqa: E501
-| Check     | N/A                                       | has_offloaded_params(module)                       |  # noqa: E501
-| Onload    | N/A                                       | with align_module_device(module)                   |  # noqa: E501
-| Update    | module.name.data.copy_(new_data)          | update_offload_parameter(module, name, new_data)   |  # noqa: E501
-| Delete    | del module.name                           | delete_offload_parameter(module, name)             |  # noqa: E501
-| --------------------------------------------------------------------------------------------------- |  # noqa: E501
+| ---------------------------------------------------------------------------------------------------- |  # noqa: E501
+| Operation  | Without offloading support              | With offloading support                           |  # noqa: E501
+| ---------- | ---------------------------------------- | ------------------------------------------------- |  # noqa: E501
+| Add        | module.register_parameter(name, param)   | register_offload_parameter(module, name, param)   |  # noqa: E501
+| Check      | N/A                                       | has_offloaded_params(module)                       |  # noqa: E501
+| Onload     | N/A                                       | with align_module_device(module)                   |  # noqa: E501
+| Update     | module.name.data.copy_(new_data)          | update_offload_parameter(module, name, new_data)   |  # noqa: E501
+| Delete     | del module.name                           | delete_offload_parameter(module, name)             |  # noqa: E501
+| Add Module | module.register_module(name, child)      | register_offload_module(module, name, child)       |  # noqa: E501
+| Del Module | del module.name                           | delete_offload_module(module, name)                |  # noqa: E501
+| ---------------------------------------------------------------------------------------------------- |  # noqa: E501
"""

import contextlib
import warnings
from functools import wraps
-from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Union
+from operator import attrgetter
+from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple, Union

import torch
+from compressed_tensors.utils import patch_attr


try:
-    from accelerate import dispatch_model
    from accelerate.hooks import (
        AlignDevicesHook,
        add_hook_to_module,
    from accelerate.utils import (
        OffloadedWeightsLoader,
        PrefixedDataset,
+        find_tied_parameters,
        set_module_tensor_to_device,
    )

    _has_accelerate = True
+
except ImportError:
    _has_accelerate = False
    AlignDevicesHook = None
    PrefixedDataset = None
    set_module_tensor_to_device = None
    named_module_tensors = None
-    dispatch_model = None
    attach_align_device_hook = None
+    find_tied_parameters = None


__all__ = [
    "align_module_device",
    "register_offload_module",
    "delete_offload_module",
-    "force_cpu_offload",
+    "offloaded_dispatch",
+    "disable_offloading",
+    "remove_dispatch",
]


def check_accelerate(fallback: Any):
    def decorator(func: Callable[[Any], Any]):
        if not _has_accelerate:
-
            if fallback == "error":

                @wraps(func)
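The table above maps each parameter operation to its offloading-aware helper. A minimal usage sketch, assuming these helpers are re-exported from `compressed_tensors.utils` (as the `patch_attr` import in this hunk suggests) and using an ordinary, non-offloaded `Linear` for illustration:

```python
import torch
from compressed_tensors.utils import (
    align_module_device,
    delete_offload_parameter,
    has_offloaded_params,
    register_offload_parameter,
    update_offload_parameter,
)

linear = torch.nn.Linear(4, 4)  # same calls work whether or not the module is offloaded

# Add: registers the parameter and, if the module is offloaded, writes it to the weights map
register_offload_parameter(linear, "scale", torch.nn.Parameter(torch.ones(4)))

# Update: copies new data into the onloaded parameter and/or its offloaded copy
update_offload_parameter(linear, "scale", torch.full((4,), 2.0))

# Check / Onload: both are effectively no-ops for a module without offloaded parameters
with align_module_device(linear):
    print(linear.scale.device, has_offloaded_params(linear))

# Delete: removes the parameter and any offloaded copy
delete_offload_parameter(linear, "scale")
```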
@@ -165,22 +171,22 @@ def update_parameter_data(

def get_execution_device(module: torch.nn.Module) -> torch.device:
    """
-    Get the device which inputs should be moved to before module execution
+    Get the device which inputs should be moved to before module execution.
+    Assume that modules execute in the same order as returned by `model.modules()`

    :param module: module to check, may be offloaded
    :return: onload device of module
    """
-    if has_offloaded_params(module):
-        return module._hf_hook.execution_device
+    for submodule in module.modules():
+        if has_offloaded_params(submodule):
+            return submodule._hf_hook.execution_device

-    first_param = next(module.parameters(), None)
-    if first_param is None:
-        warnings.warn(
-            f"Unable able to infer execution device of {module}, falling back to CPU"
-        )
-        return torch.device("cpu")
+        param = next(submodule.parameters(recurse=False), None)
+        if param is not None:
+            return param.device

-    return first_param.device
+    warnings.warn(f"Unable to get execution device of {module}, falling back to CPU")
+    return torch.device("cpu")


def register_offload_parameter(
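A small sketch of the lookup order implemented by the new `get_execution_device`: the first offloaded submodule's execution device wins, otherwise the device of the first directly-held parameter in `module.modules()` order, otherwise CPU with a warning. Shown with plain, undispatched modules:

```python
import torch
from compressed_tensors.utils import get_execution_device

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

# no offloaded submodules, so the device of the first parameter found is returned
assert get_execution_device(model) == torch.device("cpu")

# a module with no parameters at all falls back to CPU (and emits a warning)
assert get_execution_device(torch.nn.ReLU()) == torch.device("cpu")
```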
@@ -201,17 +207,32 @@ def register_offload_parameter(
    has_onload = any(p.device != torch.device("meta") for p in module.parameters())
    module.register_parameter(name, parameter)

+    # do everything AlignDevicesHook.init_hook does
+    # https://github.com/huggingface/accelerate/blob/main/src/accelerate/hooks.py#L281
    if has_offloaded_params(module):
-        weights_map = module._hf_hook.weights_map
-        offload_to_weights_map(weights_map, name, parameter.data, offload_device)
+        hook: AlignDevicesHook = module._hf_hook
+        assert hook.weights_map is not None
+
+        # append to original_devices
+        hook.original_devices[name] = parameter.device
+
+        # append to weights map
+        offload_to_weights_map(hook.weights_map, name, parameter.data, offload_device)
+
+        # append to tied_params_map
+        offloaded = hook.weights_map[name]
+        if hook.tied_params_map is not None:
+            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+
+    # perform offloading
    if not has_onload:
        set_module_tensor_to_device(module, name, "meta")


def update_offload_parameter(
    module: torch.nn.Module,
    name: str,
-    data: Optional[torch.Tensor],
+    data: torch.Tensor,
    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
    """
@@ -224,15 +245,15 @@ def update_offload_parameter(
    :param offload_device: device on which weight will be offloaded to. If None is
        provided, then infer device from parameters on module
    """
-    param = getattr(module, name)
+    param: torch.nn.Parameter = getattr(module, name)
    if param.data.shape != data.shape:
        warnings.warn(
            f"Shape of parameter being updated {param.data.shape} does not match shape "
            f"of update data {data.shape}"
        )

    # copy data into onloaded parameter if applicable
-    if param.device != torch.device("meta"):
+    if param.device != torch.device("meta") and data is not param.data:
        param.data.copy_(data)

    # update offload dict
@@ -417,7 +438,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
        hook: AlignDevicesHook = base._hf_hook
        assert hook.offload
        assert hook.weights_map is not None
-        assert hook.tied_params_map is not None

        # offloading kwargs for submodule
        place_submodules = False
@@ -432,7 +452,8 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
            module, include_buffers=offload_buffers, recurse=place_submodules
        ):
            offloaded = param.to(offload_device)
-            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+            if hook.tied_params_map is not None:
+                hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
            offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)

        # if the parent places submodules, offload here
@@ -460,9 +481,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M

    base.register_module(name, module)

-    # (1): Since we cannot know which pointers are shared when we add parameters in an
-    # online way, assume that all pointers are shared. This comes at no runtime cost
-

def delete_offload_module(base: torch.nn.Module, name: str):
    """
@@ -479,46 +497,106 @@ def delete_offload_module(base: torch.nn.Module, name: str):


@check_accelerate(fallback="error")
-def force_cpu_offload(
-    module: torch.nn.Module, execution_device: torch.device
+def offloaded_dispatch(
+    module: torch.nn.Module,
+    execution_device: torch.device,
+    offload_device: Union[torch.device, Literal["disk"]] = torch.device("cpu"),
) -> torch.nn.Module:
    """
-    Force cpu offloading a module, primarily used for testing
+    Unlike `dispatch_model`, this function forces a module (and its submodules) to
+    offload all parameters and replace them with meta tensors, utilizing the
+    `AlignDevicesHook` to control onloading and offloading.

    :param module: module containing parameters to offload
-    :param execution_device: execution device submodules
-    :return: module with hooks to perform cpu offloading
-    """
-    # edge case: there is a bug in `dispatch_model` which causes
-    # the function to only work if the model contains submodules
-    if next(module.children(), None) is None:
-        attach_align_device_hook(
-            module,
-            execution_device=execution_device,
-            offload=True,
-            weights_map=module.state_dict(),
-            tied_params_map={},
-        )
-        return module
+    :param execution_device: device that modules will onload and execute on
+    :param offload_device: device that module parameters will offload to
+    :return: module with offloading device hooks
+    """
+    if offload_device == "disk":
+        raise NotImplementedError("Disk offloading is not currently supported")
+
+    # remove any existing hooks
+    remove_dispatch(module)
+
+    # create weights map
+    state_dict = module.state_dict()
+    state_dict = {key: val.to(offload_device) for key, val in state_dict.items()}
+    weights_map = OffloadedWeightsLoader(state_dict=state_dict, device=offload_device)
+
+    # create tied params map
+    tied_params = find_tied_parameters(module)
+    tied_params_map = {}
+    for group in tied_params:
+        for param_name in group:
+            data_ptr = attrgetter(param_name)(module).data_ptr()
+            tied_params_map[data_ptr] = {}
+
+    # recursively attaches hooks to all submodules
+    attach_align_device_hook(
+        module,
+        execution_device=execution_device,
+        offload=True,
+        weights_map=weights_map,
+        tied_params_map=tied_params_map,
+    )

-    device_map = {}
+    # when saving a model, `PreTrainedModel.save_pretrained` will only
+    # onload weights if the following requirements are met
+    # if (
+    #     hasattr(self, "hf_device_map")
+    #     and len(set(self.hf_device_map.values())) > 1
+    #     and ("cpu" in self.hf_device_map.values()
+    #          or "disk" in self.hf_device_map.values())
+    # ):
+    # because this function always offloads, disregard actual devices and
+    # always use `cpu` and `cuda:0` to guarantee this condition passes
+    setattr(module, "hf_device_map", {"fake_offload": "cpu", "fake_exec": "cuda:0"})

-    def collect_device_map(name: List[str], module: torch.nn.Module):
-        if next(module.parameters(recurse=False), None) is not None:
-            device_map[".".join(name)] = "cpu"
-            return
+    return module

-        else:
-            for submodule_name, submodule in module.named_children():
-                name.append(submodule_name)
-                collect_device_map(name, submodule)
-                name.pop()

-    collect_device_map([], module)
+def remove_dispatch(module: torch.nn.Module) -> torch.nn.Module:
+    """
+    Remove any existing dispatches from module
+
+    :param module: module which may be dispatched with hf hooks
+    :return: module without dispatch
+    """
+    remove_hook_from_module(module, recurse=True)
+    if hasattr(module, "hf_device_map"):
+        delattr(module, "hf_device_map")
+
+    return module

-    return dispatch_model(
-        module, device_map, main_device=execution_device, force_hooks=True
-    )
+
+@contextlib.contextmanager
+def disable_offloading():
+    """
+    Keep modules onloaded and disable offloading until this context exits.
+    Affects modules which have been hooked with accelerate's `AlignDevicesHook`
+    """
+    original_pre_forward = AlignDevicesHook.pre_forward
+    onloaded_modules: Dict[torch.nn.Module, Tuple[AlignDevicesHook, bool]] = dict()
+
+    # onload once and disable any future onloading/offloading steps
+    def keep_onload_pre_forward(self: AlignDevicesHook, module, *args, **kwargs):
+        ret = original_pre_forward(self, module, *args, **kwargs)
+        if module not in onloaded_modules:
+            onloaded_modules[module] = (self, self.offload)
+            self.offload = False
+        return ret
+
+    # use the patched pre_forward function within the context
+    with patch_attr(AlignDevicesHook, "pre_forward", keep_onload_pre_forward):
+        yield
+
+    # manually offload all modules that were onloaded
+    # update any parameters which may have changed
+    for module, (hook, offload) in onloaded_modules.items():
+        hook.offload = offload
+        for name, param in module.named_parameters(recurse=False):
+            update_offload_parameter(module, name, param.data)
+        hook.post_forward(module, None)


""" Upstreamed Functions """
@@ -588,3 +666,7 @@ def align_module_device(

    else:
        yield
+
+
+# (1): Since we cannot know which pointers are shared when we add parameters in an
+# online way, assume that all pointers are shared. This has virtually no runtime cost
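The `(1)` footnote covers the tied-parameter bookkeeping performed by `register_offload_parameter` and `register_offload_module`. A rough sketch of the module-level operations from the table at the top of this file, assuming a layer already dispatched with `offloaded_dispatch` and a CUDA device:

```python
import torch
from compressed_tensors.utils import (
    delete_offload_module,
    offloaded_dispatch,
    register_offload_module,
)

layer = torch.nn.Linear(8, 8)
offloaded_dispatch(layer, execution_device=torch.device("cuda:0"))

# attach a child module: its parameters are copied into the parent's weights map,
# and each offloaded tensor pointer is recorded in tied_params_map (footnote (1))
register_offload_module(layer, "head", torch.nn.Linear(8, 2))
print(layer.head.weight.device)  # expected to be meta once the child is hooked

# detach the child module and drop its entries from the offload maps
delete_offload_module(layer, "head")
```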