
Commit 7b5a7a4

[Transform] Accelerate Utilities (#328)
* add utilities
* add tests
* add additional tests
* add delete_offload_module

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 33a3f9d commit 7b5a7a4

File tree

2 files changed: +229 -6 lines

src/compressed_tensors/utils/offload.py

Lines changed: 134 additions & 1 deletion
@@ -28,15 +28,18 @@
 import contextlib
 import warnings
 from functools import wraps
-from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Union

 import torch


 try:
+    from accelerate import dispatch_model
     from accelerate.hooks import (
         AlignDevicesHook,
         add_hook_to_module,
+        attach_align_device_hook,
+        named_module_tensors,
         remove_hook_from_module,
     )
     from accelerate.utils import (
@@ -54,6 +57,9 @@
     OffloadedWeightsLoader = None
     PrefixedDataset = None
     set_module_tensor_to_device = None
+    named_module_tensors = None
+    dispatch_model = None
+    attach_align_device_hook = None


 __all__ = [
@@ -70,13 +76,21 @@
     "disable_offload",
     "align_modules",
     "align_module_device",
+    "register_offload_module",
+    "delete_offload_module",
+    "force_cpu_offload",
 ]


 def check_accelerate(fallback: Any):
     def decorator(func: Callable[[Any], Any]):
         if not _has_accelerate:

+            if fallback == "error":
+                raise ValueError(
+                    "Please install `accelerate` in order to use this function"
+                )
+
             @wraps(func)
             def fallback_fn(*args, **kwargs):
                 return fallback
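Aside: the new `fallback == "error"` branch changes how decorated functions fail when `accelerate` is missing. Previously every decorated function was silently swapped for a stub returning the fallback value; now `fallback="error"` raises at decoration time instead. A minimal sketch of that behavior, with `_has_accelerate` hard-coded to False to simulate a missing install (the `return func` path is assumed from surrounding context, since the diff does not show it):

import contextlib
from functools import wraps
from typing import Any, Callable

_has_accelerate = False  # illustrative: pretend `accelerate` is not installed

def check_accelerate(fallback: Any):
    def decorator(func: Callable[[Any], Any]):
        if not _has_accelerate:
            # new in this commit: fail fast instead of returning a stub
            if fallback == "error":
                raise ValueError(
                    "Please install `accelerate` in order to use this function"
                )

            @wraps(func)
            def fallback_fn(*args, **kwargs):
                return fallback

            return fallback_fn
        return func

    return decorator

@check_accelerate(fallback=contextlib.nullcontext())
def disable_offload(module):
    ...  # real body never runs without accelerate

with disable_offload(object()):  # stub returns the nullcontext fallback
    pass

try:
    @check_accelerate(fallback="error")
    def force_cpu_offload(module, execution_device):
        ...
except ValueError as err:
    print(err)  # Please install `accelerate` in order to use this function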
@@ -346,6 +360,7 @@ def delete_from_weights_map(
     )


+@check_accelerate(fallback=contextlib.nullcontext())
 @contextlib.contextmanager
 def disable_offload(module: torch.nn.Module):
     """
@@ -362,6 +377,7 @@ def disable_offload(module: torch.nn.Module):
     yield


+@check_accelerate(fallback=contextlib.nullcontext())
 @contextlib.contextmanager
 def align_modules(
     modules: Union[torch.nn.Module, Iterable[torch.nn.Module]],
@@ -383,6 +399,123 @@ def align_modules(
     yield


+def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
+    """
+    Register a submodule with offloading if the parent module is offloaded
+
+    :param base: module to attach the submodule to
+    :param name: name of the submodule
+    :param module: submodule to attach
+    """
+
+    if has_offloaded_params(base):
+        hook: AlignDevicesHook = base._hf_hook
+        assert hook.offload
+        assert hook.weights_map is not None
+        assert hook.tied_params_map is not None
+
+        # offloading kwargs for submodule
+        place_submodules = False
+        offload_buffers = True
+
+        # copy device offloading arguments from parent
+        current_device = next(base.parameters()).device  # assume base has parameters
+        offload_device = get_offloaded_device(base)
+
+        # offload parameters to weights map
+        for param_name, param in named_module_tensors(
+            module, include_buffers=offload_buffers, recurse=place_submodules
+        ):
+            offloaded = param.to(offload_device)
+            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+            offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)
+
+            # if the parent places submodules, offload here
+            if hook.place_submodules:
+                set_module_tensor_to_device(module, param_name, current_device)
+
+        # if the parent does not place submodules, then add a hook
+        # parameters are offloaded by `add_hook_to_module`
+        if not hook.place_submodules:
+            weights_map = PrefixedDataset(
+                hook.weights_map.dataset, prefix=f"{hook.weights_map.prefix}{name}."
+            )
+
+            submodule_hook = AlignDevicesHook(
+                execution_device=hook.execution_device,
+                offload=hook.offload,
+                io_same_device=False,
+                weights_map=weights_map,
+                offload_buffers=offload_buffers,
+                place_submodules=place_submodules,
+                skip_keys=None,
+                tied_params_map=hook.tied_params_map,
+            )
+            add_hook_to_module(module, submodule_hook)
+
+    base.register_module(name, module)
+
+    # (1): Since we cannot know which pointers are shared when we add parameters in an
+    # online way, assume that all pointers are shared. This comes at no runtime cost.
+
+
+def delete_offload_module(base: torch.nn.Module, name: str):
+    """
+    Delete a submodule from a model which may contain offloading
+
+    :param base: parent module to delete the submodule from
+    :param name: name of the submodule on the parent
+    """
+    module: torch.nn.Module = getattr(base, name)
+
+    for param_name, _ in list(module.named_parameters()):
+        delete_offload_parameter(module, param_name)
+
+    delattr(base, name)
+
+
+@check_accelerate(fallback="error")
+def force_cpu_offload(
+    module: torch.nn.Module, execution_device: torch.device
+) -> torch.nn.Module:
+    """
+    Force cpu offloading of a module, primarily used for testing
+
+    :param module: module containing parameters to offload
+    :param execution_device: execution device for submodules
+    :return: module with hooks to perform cpu offloading
+    """
+    # edge case: there is a bug in `dispatch_model` which causes
+    # the function to only work if the model contains submodules
+    if next(module.children(), None) is None:
+        attach_align_device_hook(
+            module,
+            execution_device=execution_device,
+            offload=True,
+            weights_map=module.state_dict(),
+            tied_params_map={},
+        )
+        return module
+
+    device_map = {}
+
+    def collect_device_map(name: List[str], module: torch.nn.Module):
+        if next(module.parameters(recurse=False), None) is not None:
+            device_map[".".join(name)] = "cpu"
+            return
+
+        else:
+            for submodule_name, submodule in module.named_children():
+                name.append(submodule_name)
+                collect_device_map(name, submodule)
+                name.pop()
+
+    collect_device_map([], module)
+
+    return dispatch_model(
+        module, device_map, main_device=execution_device, force_hooks=True
+    )
+
+
 """ Upstreamed Functions """
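Putting the three new utilities together, here is a short usage sketch. It is illustrative only: `TinyModel` is a hypothetical stand-in, and it assumes `torch`, `accelerate`, and a build of `compressed-tensors` containing this commit are installed:

import torch
from compressed_tensors.utils import (
    delete_offload_module,
    force_cpu_offload,
    register_offload_module,
)

class TinyModel(torch.nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 2)

    def forward(self, x):
        return self.linear(x)

exec_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# dispatch with cpu offloading; parameters live on "meta" until forward time
model = force_cpu_offload(TinyModel(), exec_device)

# attach a new child under the offloaded linear; its parameters are added
# to the parent's weights map and it receives its own AlignDevicesHook
register_offload_module(model.linear, "child", torch.nn.Linear(2, 3))
model(torch.empty(1))

# detach it again, removing its offloaded parameters as well
delete_offload_module(model.linear, "child")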

tests/test_utils/test_offload.py

Lines changed: 95 additions & 5 deletions
@@ -16,10 +16,13 @@
 from compressed_tensors.utils import (
     align_module_device,
     align_modules,
+    delete_offload_module,
     delete_offload_parameter,
     disable_hf_hook,
+    force_cpu_offload,
     get_execution_device,
     has_offloaded_params,
+    register_offload_module,
     register_offload_parameter,
     update_offload_parameter,
 )
@@ -37,9 +40,17 @@ def forward(self, x):
         return x * self.a + self.b


+class ExampleModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(1, 2)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
 @requires_accelerate()
 def test_has_offloaded_params():
-    from accelerate.big_modeling import cpu_offload_with_hook
     from accelerate.hooks import attach_align_device_hook, remove_hook_from_module

     module = ExampleModule()
@@ -48,10 +59,6 @@ def test_has_offloaded_params():
     attach_align_device_hook(module, offload=False)
     assert not has_offloaded_params(module)

-    remove_hook_from_module(module)
-    module, _ = cpu_offload_with_hook(module)
-    assert not has_offloaded_params(module)
-
     remove_hook_from_module(module)
     attach_align_device_hook(module, offload=True, weights_map=module.state_dict())
     assert has_offloaded_params(module)
@@ -334,3 +341,86 @@ def test_offload_to_weights_map():
     weights_map = PrefixedDataset(OffloadedWeightsLoader({name: old_value}), prefix)
     offload_to_weights_map(weights_map, name, new_value)
     assert weights_map[name] == new_value
+
+
+@requires_gpu
+@requires_accelerate()
+@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
+def test_register_offload_module(exec_device):
+    # no offloading
+    model = ExampleModel()
+    child = torch.nn.Linear(2, 3)
+    register_offload_module(model, "child", child)
+    register_offload_module(model.linear, "child", child)
+    assert child in model.children()
+    assert child in model.linear.children()
+
+    # with offloading
+    model = ExampleModel()
+    child = torch.nn.Linear(2, 3)
+    force_cpu_offload(model, exec_device)
+    register_offload_module(model, "child", child)
+    register_offload_module(model.linear, "child", child)
+    assert child in model.children()
+    assert child in model.linear.children()
+
+    # can run modules
+    model(torch.empty(1))
+    child(torch.empty(2, device=exec_device))
+
+
+@requires_gpu
+@requires_accelerate()
+@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
+def test_delete_offload_module(exec_device):
+    # no offloading
+    model = ExampleModel()
+    child = torch.nn.Linear(2, 3)
+    register_offload_module(model, "child", child)
+    register_offload_module(model.linear, "child", child)
+    delete_offload_module(model, "child")
+    delete_offload_module(model.linear, "child")
+    assert child not in model.children()
+    assert child not in model.linear.children()
+
+    # with offloading
+    model = ExampleModel()
+    child = torch.nn.Linear(2, 3)
+    force_cpu_offload(model, exec_device)
+    register_offload_module(model, "child", child)
+    register_offload_module(model.linear, "child", child)
+    delete_offload_module(model, "child")
+    delete_offload_module(model.linear, "child")
+    assert child not in model.children()
+    assert child not in model.linear.children()
+
+
+@requires_gpu
+@requires_accelerate()
+@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
+def test_force_cpu_offload(exec_device):
+    # single module
+    module = torch.nn.Linear(1, 2)
+    module = force_cpu_offload(module, exec_device)
+    assert has_offloaded_params(module)
+    assert module._hf_hook.offload
+    assert module.weight.device == torch.device("meta")
+    assert "weight" in module._hf_hook.weights_map
+    assert module._hf_hook.tied_params_map is not None
+
+    # can run
+    module(torch.empty(1, device=exec_device))
+
+    # model
+    model = ExampleModel()
+    model = force_cpu_offload(model, exec_device)
+    assert not has_offloaded_params(model)
+
+    assert has_offloaded_params(model.linear)
+    assert model.linear._hf_hook.offload
+    assert model.linear.weight.device == torch.device("meta")
+    assert "weight" in model.linear._hf_hook.weights_map
+    assert model.linear._hf_hook.tied_params_map is not None
+
+    # can run
+    model(torch.empty(1, device=exec_device))
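One detail worth calling out from `test_force_cpu_offload`: for a model with submodules, `force_cpu_offload` builds a `device_map` by recursing until it finds modules that directly own parameters, mapping each one's dotted path to "cpu". For the `ExampleModel` fixture above this yields a single entry. A standalone sketch of that recursion (logic copied from the diff; the scaffolding around it is illustrative):

import torch

class ExampleModel(torch.nn.Module):  # mirrors the test fixture above
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 2)

    def forward(self, x):
        return self.linear(x)

device_map = {}

def collect_device_map(name, module):
    # same recursion as force_cpu_offload: stop at the first module that
    # directly owns parameters and map its dotted path to "cpu"
    if next(module.parameters(recurse=False), None) is not None:
        device_map[".".join(name)] = "cpu"
        return
    for submodule_name, submodule in module.named_children():
        name.append(submodule_name)
        collect_device_map(name, submodule)
        name.pop()

collect_device_map([], ExampleModel())
print(device_map)  # {'linear': 'cpu'}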
