Commit d8a10ec

add utilities
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent: 3f5705d

File tree: 1 file changed (+97, -1 lines changed)

src/compressed_tensors/utils/offload.py

Lines changed: 97 additions & 1 deletion
@@ -28,15 +28,17 @@
 import contextlib
 import warnings
 from functools import wraps
-from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Union
 
 import torch
 
 
 try:
+    from accelerate import dispatch_model
     from accelerate.hooks import (
         AlignDevicesHook,
         add_hook_to_module,
+        named_module_tensors,
         remove_hook_from_module,
     )
     from accelerate.utils import (
@@ -54,6 +56,8 @@
     OffloadedWeightsLoader = None
     PrefixedDataset = None
     set_module_tensor_to_device = None
+    named_module_tensors = None
+    dispatch_model = None
 
 
 __all__ = [
@@ -70,13 +74,20 @@
     "disable_offload",
     "align_modules",
     "align_module_device",
+    "register_offload_module",
+    "force_cpu_offload",
 ]
 
 
 def check_accelerate(fallback: Any):
     def decorator(func: Callable[[Any], Any]):
         if not _has_accelerate:
 
+            if fallback == "error":
+                raise ValueError(
+                    "Please install `accelerate` in order to use this function"
+                )
+
             @wraps(func)
             def fallback_fn(*args, **kwargs):
                 return fallback
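
Note: the `check_accelerate` change above adds an "error" fallback mode. Below is a minimal sketch of the intended behavior, assuming only what the hunk shows; the decorated function is hypothetical and not part of this commit.

from compressed_tensors.utils.offload import check_accelerate

# Hypothetical function used only to illustrate the new fallback mode.
@check_accelerate(fallback="error")
def requires_accelerate():
    ...

# With `accelerate` installed, the decorator leaves the function untouched.
# Without it, decoration now raises
# ValueError("Please install `accelerate` in order to use this function")
# instead of silently substituting the fallback value.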
@@ -346,6 +357,7 @@ def delete_from_weights_map(
         )
 
 
+@check_accelerate(fallback=contextlib.nullcontext())
 @contextlib.contextmanager
 def disable_offload(module: torch.nn.Module):
     """
@@ -362,6 +374,7 @@ def disable_offload(module: torch.nn.Module):
         yield
 
 
+@check_accelerate(fallback=contextlib.nullcontext())
 @contextlib.contextmanager
 def align_modules(
     modules: Union[torch.nn.Module, Iterable[torch.nn.Module]],
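
The two `@check_accelerate(fallback=contextlib.nullcontext())` decorators added above let `disable_offload` and `align_modules` degrade to no-op context managers when `accelerate` is not installed. A small sketch of the calling pattern, using an illustrative stand-in module:

import torch
from compressed_tensors.utils.offload import align_modules

model = torch.nn.Linear(8, 8)  # illustrative stand-in module

# With accelerate available, offloaded parameters are aligned onto their
# execution device for the duration of the block; without accelerate, the
# call now returns contextlib.nullcontext() rather than failing.
with align_modules(model):
    _ = model(torch.randn(1, 8))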
@@ -383,6 +396,89 @@ def align_modules(
         yield
 
 
+@check_accelerate(fallback=None)
+def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
+    """
+    Register a submodule with offloading if the parent module is offloaded
+
+    :param base: module to attach submodule to
+    :param name: name of submodule
+    :param module: submodule to attach
+    """
+
+    if has_offloaded_params(base):
+        hook: AlignDevicesHook = base._hf_hook
+        assert hook.offload
+        assert hook.weights_map is not None
+        assert hook.tied_params_map is not None
+
+        # offloading kwargs for submodule
+        place_submodules = False
+        offload_buffers = True
+
+        # copy device offloading arguments from parent
+        current_device = next(base.parameters()).device  # assume base has parameters
+        offload_device = get_offloaded_device(base)
+
+        # offload parameters to weights map
+        for param_name, param in named_module_tensors(
+            module, include_buffers=offload_buffers, recurse=place_submodules
+        ):
+            offloaded = param.to(offload_device)
+            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+            offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)
+
+            # if the parent places submodules, offload here
+            if hook.place_submodules:
+                set_module_tensor_to_device(module, param_name, current_device)
+
+        # if the parent does not place submodules, then add a hook
+        # parameters are offloaded by `add_hook_to_module`
+        if not hook.place_submodules:
+            weights_map = PrefixedDataset(
+                hook.weights_map.dataset, prefix=f"{hook.weights_map.prefix}{name}."
+            )
+
+            submodule_hook = AlignDevicesHook(
+                execution_device=hook.execution_device,
+                offload=hook.offload,
+                io_same_device=False,
+                weights_map=weights_map,
+                offload_buffers=offload_buffers,
+                place_submodules=place_submodules,
+                skip_keys=None,
+                tied_params_map=hook.tied_params_map,
+            )
+            add_hook_to_module(module, submodule_hook)
+
+    base.register_module(name, module)
+
+    # (1): Since we cannot know which pointers are shared when we add parameters in an
+    # online way, assume that all pointers are shared. This comes at no runtime cost
+
+
+@check_accelerate(fallback="error")
+def force_cpu_offload(module: torch.nn.Module, execution_device: torch.device):
+    device_map = {}
+
+    def dfs(name: List[str], module: torch.nn.Module):
+        if next(module.parameters(recurse=False), None) is not None:
+            device_map[".".join(name)] = "cpu"
+            return
+
+        else:
+            for submodule_name, submodule in module.named_children():
+                name.append(submodule_name)
+                dfs(name, submodule)
+                name.pop()
+
+    dfs([], module)
+
+    return dispatch_model(
+        module, device_map, main_device=execution_device, force_hooks=True
+    )
+
+
 """ Upstreamed Functions """
 
 
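For context, a hedged usage sketch of the two new utilities; the toy model, layer sizes, and device selection below are assumptions for illustration, not part of the commit.

import torch
from compressed_tensors.utils.offload import force_cpu_offload, register_offload_module

# Illustrative toy model; any torch.nn.Module should work here.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))
exec_device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Keep parameters of each parameterized submodule on CPU, attaching accelerate
# hooks that move them to `exec_device` only while that submodule executes.
force_cpu_offload(model, exec_device)

# Attach a new submodule to an already-offloaded parent; per the docstring above,
# it is registered with offloading so its parameters join the parent's weights map.
register_offload_module(model[0], "adapter", torch.nn.Linear(16, 16))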
0 commit comments
