Skip to content

Commit 0ffe3c3

Browse files
authored
[Accelerate] Extend functionality of register_offload_parameter (#356)
* extend register_offload_parameter
* add link
* remove dreggs

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 3fb2844 commit 0ffe3c3

File tree

2 files changed

+64
-7
lines changed

2 files changed

+64
-7
lines changed

src/compressed_tensors/utils/offload.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,24 @@ def register_offload_parameter(
206206
has_onload = any(p.device != torch.device("meta") for p in module.parameters())
207207
module.register_parameter(name, parameter)
208208

209+
# do everything AlignDevicesHook.init_hook does
210+
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/hooks.py#L281
209211
if has_offloaded_params(module):
210-
weights_map = module._hf_hook.weights_map
211-
offload_to_weights_map(weights_map, name, parameter.data, offload_device)
212+
hook: AlignDevicesHook = module._hf_hook
213+
assert hook.weights_map is not None
214+
215+
# append to original_devices
216+
hook.original_devices[name] = parameter.device
217+
218+
# append to weights map
219+
offload_to_weights_map(hook.weights_map, name, parameter.data, offload_device)
220+
221+
# append to tied_params_map
222+
offloaded = hook.weights_map[name]
223+
if hook.tied_params_map is not None:
224+
hook.tied_params_map[offloaded.data_ptr()] = {} # (1)
225+
226+
# perform offloading
212227
if not has_onload:
213228
set_module_tensor_to_device(module, name, "meta")
214229

@@ -422,7 +437,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
422437
hook: AlignDevicesHook = base._hf_hook
423438
assert hook.offload
424439
assert hook.weights_map is not None
425-
assert hook.tied_params_map is not None
426440

427441
# offloading kwargs for submodule
428442
place_submodules = False
@@ -437,7 +451,8 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
437451
module, include_buffers=offload_buffers, recurse=place_submodules
438452
):
439453
offloaded = param.to(offload_device)
440-
hook.tied_params_map[offloaded.data_ptr()] = {} # (1)
454+
if hook.tied_params_map is not None:
455+
hook.tied_params_map[offloaded.data_ptr()] = {} # (1)
441456
offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)
442457

443458
# if the parent places submodules, offload here
@@ -465,9 +480,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
465480

466481
base.register_module(name, module)
467482

468-
# (1): Since we cannot know which pointers are shared when we add parameters in an
469-
# online way, assume that all pointers are shared. This comes at no runtime cost
470-
471483

472484
def delete_offload_module(base: torch.nn.Module, name: str):
473485
"""
@@ -623,3 +635,7 @@ def align_module_device(
623635

624636
else:
625637
yield
638+
639+
640+
# (1): Since we cannot know which pointers are shared when we add parameters in an
641+
# online way, assume that all pointers are shared. This has virtually no runtime cost

tests/test_utils/test_offload.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,47 @@ def test_register_offload_parameter():
149149
assert module.a.device == module.b.device == module.c.device == torch.device("meta")
150150

151151

152+
@requires_accelerate()
153+
@requires_gpu
154+
def test_register_offload_parameter_hook_replacement():
155+
module = ExampleModule()
156+
parameter_c = torch.nn.Parameter(torch.tensor(1.0, device="cuda"))
157+
parameter_d = torch.nn.Parameter(torch.tensor(1.0, device="cpu"))
158+
159+
offloaded_dispatch(module, "cuda")
160+
register_offload_parameter(module, "c", parameter_c)
161+
register_offload_parameter(module, "d", parameter_d)
162+
163+
with disable_hf_hook(module):
164+
assert module.a.device == torch.device("cpu")
165+
assert module.b.device == torch.device("cpu")
166+
assert module.c.device == torch.device("cuda:0")
167+
assert module.d.device == torch.device("cpu")
168+
169+
assert module.a.device == torch.device("meta")
170+
assert module.b.device == torch.device("meta")
171+
assert module.c.device == torch.device("meta")
172+
assert module.d.device == torch.device("meta")
173+
assert module._hf_hook.weights_map["a"].device == torch.device("cpu")
174+
assert module._hf_hook.weights_map["b"].device == torch.device("cpu")
175+
assert module._hf_hook.weights_map["c"].device == torch.device("cpu")
176+
assert module._hf_hook.weights_map["d"].device == torch.device("cpu")
177+
178+
179+
@requires_accelerate()
180+
@requires_gpu
181+
def test_register_offload_parameter_shared():
182+
module = ExampleModule()
183+
parameter = torch.nn.Parameter(torch.tensor(1.0))
184+
185+
offloaded_dispatch(module, "cuda")
186+
register_offload_parameter(module, "c", parameter)
187+
register_offload_parameter(module, "d", parameter)
188+
189+
with align_module_device(module):
190+
assert module.c is module.d
191+
192+
152193
@requires_accelerate()
153194
def test_update_offload_parameter():
154195
from accelerate.hooks import attach_align_device_hook

0 commit comments

Comments (0)