 import warnings
 from functools import wraps
 from operator import attrgetter
-from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union
+from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple, Union

 import torch
+from compressed_tensors.utils import patch_attr


 try:
@@ -83,6 +84,8 @@
     "register_offload_module",
     "delete_offload_module",
     "offloaded_dispatch",
+    "disable_offloading",
+    "remove_dispatch",
 ]

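The two new exports above sit alongside the existing helpers. A minimal import sketch, assuming these names are re-exported through `compressed_tensors.utils` the same way the other helpers are:

# assumed re-export path for the new utilities added in this patch
from compressed_tensors.utils import disable_offloading, remove_dispatch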
@@ -168,22 +171,22 @@ def update_parameter_data(

 def get_execution_device(module: torch.nn.Module) -> torch.device:
     """
-    Get the device which inputs should be moved to before module execution
+    Get the device which inputs should be moved to before module execution.
+    Assume that modules execute in the same order as returned by `model.modules()`

     :param module: module to check, may be offloaded
     :return: onload device of module
     """
-    if has_offloaded_params(module):
-        return module._hf_hook.execution_device
+    for submodule in module.modules():
+        if has_offloaded_params(submodule):
+            return submodule._hf_hook.execution_device

-    first_param = next(module.parameters(), None)
-    if first_param is None:
-        warnings.warn(
-            f"Unable able to infer execution device of {module}, falling back to CPU"
-        )
-        return torch.device("cpu")
+        param = next(submodule.parameters(recurse=False), None)
+        if param is not None:
+            return param.device

-    return first_param.device
+    warnings.warn(f"Unable to get execution device of {module}, falling back to CPU")
+    return torch.device("cpu")


 def register_offload_parameter(
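With this change the execution device is resolved by walking `module.modules()` in order: the first offloaded submodule contributes its hook's execution device, otherwise the device of the first directly-held parameter is returned. A minimal usage sketch; `layer` and `hidden_states` are hypothetical names, not part of this patch:

# hypothetical usage: move inputs to wherever the (possibly offloaded) layer will execute
exec_device = get_execution_device(layer)       # e.g. cuda:0 even while weights sit on cpu/meta
hidden_states = hidden_states.to(exec_device)
output = layer(hidden_states)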
@@ -204,17 +207,32 @@ def register_offload_parameter(
     has_onload = any(p.device != torch.device("meta") for p in module.parameters())
     module.register_parameter(name, parameter)

+    # do everything AlignDevicesHook.init_hook does
+    # https://github.com/huggingface/accelerate/blob/main/src/accelerate/hooks.py#L281
     if has_offloaded_params(module):
-        weights_map = module._hf_hook.weights_map
-        offload_to_weights_map(weights_map, name, parameter.data, offload_device)
+        hook: AlignDevicesHook = module._hf_hook
+        assert hook.weights_map is not None
+
+        # append to original_devices
+        hook.original_devices[name] = parameter.device
+
+        # append to weights map
+        offload_to_weights_map(hook.weights_map, name, parameter.data, offload_device)
+
+        # append to tied_params_map
+        offloaded = hook.weights_map[name]
+        if hook.tied_params_map is not None:
+            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+
+        # perform offloading
         if not has_onload:
             set_module_tensor_to_device(module, name, "meta")


 def update_offload_parameter(
     module: torch.nn.Module,
     name: str,
-    data: Optional[torch.Tensor],
+    data: torch.Tensor,
     offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
 ):
     """
@@ -227,15 +245,15 @@ def update_offload_parameter(
     :param offload_device: device on which weight will be offloaded to. If None is
         provided, then infer device from parameters on module
     """
-    param = getattr(module, name)
+    param: torch.nn.Parameter = getattr(module, name)
     if param.data.shape != data.shape:
         warnings.warn(
             f"Shape of parameter being updated {param.data.shape} does not match shape "
             f"of update data {data.shape}"
         )

     # copy data into onloaded parameter if applicable
-    if param.device != torch.device("meta"):
+    if param.device != torch.device("meta") and data is not param.data:
         param.data.copy_(data)

     # update offload dict
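The added `data is not param.data` check means that passing a parameter's own tensor back (as `disable_offloading` does on exit) skips the redundant self-copy and only refreshes the offloaded copy. An illustrative call, with `linear` and `new_weight` as placeholder names:

# illustrative: push updated weight data into both the onloaded parameter (if any)
# and the hook's offload dict
update_offload_parameter(linear, "weight", new_weight)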
@@ -420,7 +438,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
         hook: AlignDevicesHook = base._hf_hook
         assert hook.offload
         assert hook.weights_map is not None
-        assert hook.tied_params_map is not None

         # offloading kwargs for submodule
         place_submodules = False
@@ -435,7 +452,8 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
             module, include_buffers=offload_buffers, recurse=place_submodules
         ):
             offloaded = param.to(offload_device)
-            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+            if hook.tied_params_map is not None:
+                hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
             offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)

         # if the parent places submodules, offload here
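`register_offload_module` attaches a child module under an offloaded parent and copies its tensors into the parent's weights map; after this change the `tied_params_map` entry is only recorded when accelerate actually created that map. A hedged usage sketch with made-up module names and sizes:

# illustrative: attach a new submodule to an offloaded parent, then remove it later
proj = torch.nn.Linear(4096, 4096)        # hypothetical child module
register_offload_module(layer, "extra_proj", proj)
# ... use layer.extra_proj ...
delete_offload_module(layer, "extra_proj")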
@@ -463,9 +481,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M

     base.register_module(name, module)

-    # (1): Since we cannot know which pointers are shared when we add parameters in an
-    # online way, assume that all pointers are shared. This comes at no runtime cost
-

 def delete_offload_module(base: torch.nn.Module, name: str):
     """
@@ -500,8 +515,13 @@ def offloaded_dispatch(
     if offload_device == "disk":
         raise NotImplementedError("Disk offloading is not currently supported")

+    # remove any existing hooks
+    remove_dispatch(module)
+
     # create weights map
-    weights_map = OffloadedWeightsLoader(state_dict=module.state_dict(), device="cpu")
+    state_dict = module.state_dict()
+    state_dict = {key: val.to(offload_device) for key, val in state_dict.items()}
+    weights_map = OffloadedWeightsLoader(state_dict=state_dict, device=offload_device)

     # create tied params map
     tied_params = find_tied_parameters(module)
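`offloaded_dispatch` now strips any prior dispatch before attaching hooks and builds the weights map on `offload_device` rather than hard-coding `"cpu"`. A hedged end-to-end sketch, assuming the function also takes the target `execution_device` (not visible in these hunks) and that the helper is re-exported from `compressed_tensors.utils`; the checkpoint id is a placeholder:

# illustrative: offload every weight, execute on the first cuda device
import torch
from transformers import AutoModelForCausalLM
from compressed_tensors.utils import offloaded_dispatch

model = AutoModelForCausalLM.from_pretrained("<model-id>", torch_dtype=torch.bfloat16)
offloaded_dispatch(model, execution_device=torch.device("cuda:0"))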
@@ -519,9 +539,66 @@ def offloaded_dispatch(
         weights_map=weights_map,
         tied_params_map=tied_params_map,
     )
+
+    # when saving a model, `PretrainedModel.save_pretrained` will only
+    # onload weights if the following requirements are met
+    # if (
+    #     hasattr(self, "hf_device_map")
+    #     and len(set(self.hf_device_map.values())) > 1
+    #     and ("cpu" in self.hf_device_map.values()
+    #     or "disk" in self.hf_device_map.values())
+    # ):
+    # because this function always offloads, disregard actual devices and
+    # always use `cpu` and `cuda:0` to guarantee this condition passes
+    setattr(module, "hf_device_map", {"fake_offload": "cpu", "fake_exec": "cuda:0"})
+
     return module


+def remove_dispatch(module: torch.nn.Module) -> torch.nn.Module:
+    """
+    Remove any existing dispatches from module
+
+    :param module: module which may be dispatched with hf hooks
+    :return: module without dispatch
+    """
+    remove_hook_from_module(module, recurse=True)
+    if hasattr(module, "hf_device_map"):
+        delattr(module, "hf_device_map")
+
+    return module
+
+
+@contextlib.contextmanager
+def disable_offloading():
+    """
+    Keep modules onloaded and disable offloading until this context exits.
+    Affects modules which have been hooked with accelerate's `AlignDevicesHook`
+    """
+    original_pre_forward = AlignDevicesHook.pre_forward
+    onloaded_modules: Dict[torch.nn.Module, Tuple[AlignDevicesHook, bool]] = dict()
+
+    # onload once and disable any future onloading/offloading steps
+    def keep_onload_pre_forward(self: AlignDevicesHook, module, *args, **kwargs):
+        ret = original_pre_forward(self, module, *args, **kwargs)
+        if module not in onloaded_modules:
+            onloaded_modules[module] = (self, self.offload)
+            self.offload = False
+        return ret
+
+    # use the patched pre_forward function within the context
+    with patch_attr(AlignDevicesHook, "pre_forward", keep_onload_pre_forward):
+        yield
+
+    # manually offload all modules that were onloaded
+    # update any parameters which may have changed
+    for module, (hook, offload) in onloaded_modules.items():
+        hook.offload = offload
+        for name, param in module.named_parameters(recurse=False):
+            update_offload_parameter(module, name, param.data)
+        hook.post_forward(module, None)
+
+
 """ Upstreamed Functions """

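The new context manager patches `AlignDevicesHook.pre_forward` through `patch_attr`, so each hooked module is onloaded on its first forward and then left onloaded because its hook's offload flag is temporarily cleared; on exit the flag is restored, any changed parameters are written back via `update_offload_parameter`, and `post_forward` offloads the modules again. A hedged usage sketch; `model` and `calibration_batches` are placeholder names:

# illustrative: run several forward passes without re-onloading weights for every call
with disable_offloading():
    for batch in calibration_batches:     # hypothetical iterable of model inputs
        model(**batch)
# leaving the context restores each hook's offload flag and re-offloads the modules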
@@ -589,3 +666,7 @@ def align_module_device(

     else:
         yield
+
+
+# (1): Since we cannot know which pointers are shared when we add parameters in an
+# online way, assume that all pointers are shared. This has virtually no runtime cost