@@ -206,9 +206,24 @@ def register_offload_parameter(
206
206
has_onload = any (p .device != torch .device ("meta" ) for p in module .parameters ())
207
207
module .register_parameter (name , parameter )
208
208
209
+ # do everything AlignDevicesHook.init_hook does
210
+ # https://github.com/huggingface/accelerate/blob/main/src/accelerate/hooks.py#L281
209
211
if has_offloaded_params (module ):
210
- weights_map = module ._hf_hook .weights_map
211
- offload_to_weights_map (weights_map , name , parameter .data , offload_device )
212
+ hook : AlignDevicesHook = module ._hf_hook
213
+ assert hook .weights_map is not None
214
+
215
+ # append to original_devices
216
+ hook .original_devices [name ] = parameter .device
217
+
218
+ # append to weights map
219
+ offload_to_weights_map (hook .weights_map , name , parameter .data , offload_device )
220
+
221
+ # append to tied_params_map
222
+ offloaded = hook .weights_map [name ]
223
+ if hook .tied_params_map is not None :
224
+ hook .tied_params_map [offloaded .data_ptr ()] = {} # (1)
225
+
226
+ # perform offloading
212
227
if not has_onload :
213
228
set_module_tensor_to_device (module , name , "meta" )
214
229
@@ -422,7 +437,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
422
437
hook : AlignDevicesHook = base ._hf_hook
423
438
assert hook .offload
424
439
assert hook .weights_map is not None
425
- assert hook .tied_params_map is not None
426
440
427
441
# offloading kwargs for submodule
428
442
place_submodules = False
@@ -437,7 +451,8 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
437
451
module , include_buffers = offload_buffers , recurse = place_submodules
438
452
):
439
453
offloaded = param .to (offload_device )
440
- hook .tied_params_map [offloaded .data_ptr ()] = {} # (1)
454
+ if hook .tied_params_map is not None :
455
+ hook .tied_params_map [offloaded .data_ptr ()] = {} # (1)
441
456
offload_to_weights_map (hook .weights_map , f"{ name } .{ param_name } " , offloaded )
442
457
443
458
# if the parent places submodules, offload here
@@ -465,9 +480,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
465
480
466
481
base .register_module (name , module )
467
482
468
- # (1): Since we cannot know which pointers are shared when we add parameters in an
469
- # online way, assume that all pointers are shared. This comes at no runtime cost
470
-
471
483
472
484
def delete_offload_module (base : torch .nn .Module , name : str ):
473
485
"""
@@ -623,3 +635,7 @@ def align_module_device(
623
635
624
636
else :
625
637
yield
638
+
639
+
640
+ # (1): Since we cannot know which pointers are shared when we add parameters in an
641
+ # online way, assume that all pointers are shared. This has virtually no runtime cost
0 commit comments