Commit 4949912

[Utils] add align_modules (#282)
* add align_modules
* better implementation
* add align_modules
* add docstrings
* remove comment
* docstring and typo

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 60d78a5 commit 4949912

File tree

2 files changed: 70 additions, 1 deletion

2 files changed

+70
-1
lines changed

src/compressed_tensors/utils/offload.py

Lines changed: 40 additions & 1 deletion
@@ -28,7 +28,7 @@
 import contextlib
 import warnings
 from functools import wraps
-from typing import Any, Callable, Dict, Literal, Optional, Union
+from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union

 import torch

@@ -67,6 +67,8 @@
     "delete_offload_parameter",
     "has_offloaded_params",
     "disable_hf_hook",
+    "disable_offload",
+    "align_modules",
     "align_module_device",
 ]

@@ -344,6 +346,43 @@ def delete_from_weights_map(
     )


+@contextlib.contextmanager
+def disable_offload(module: torch.nn.Module):
+    """
+    Context manager to disable module onloading and offloading. Parameters will stay on
+    their current device
+
+    :param module: module to disable offloading for
+    """
+    if has_offloaded_params(module):
+        module._hf_hook.offload = False
+        yield
+        module._hf_hook.offload = True
+    else:
+        yield
+
+
+@contextlib.contextmanager
+def align_modules(
+    modules: Union[torch.nn.Module, Iterable[torch.nn.Module]],
+    execution_device: Optional[torch.device] = None,
+):
+    """
+    Context manager for onloading modules to a device, and disabling onload and offload
+    attempts triggered by forward calls. Used for sequential onloading of layers
+
+    :param modules: `torch.nn.Module` or iterable of `torch.nn.Module`s to onload
+    :param execution_device: device to onload to
+    """
+    modules = (modules,) if isinstance(modules, torch.nn.Module) else modules
+
+    with contextlib.ExitStack() as stack:
+        for module in modules:
+            stack.enter_context(align_module_device(module, execution_device))
+            stack.enter_context(disable_offload(module))  # disable redundant onloading
+        yield
+
+
 """ Upstreamed Functions """
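Taken together, align_modules composes align_module_device (which onloads each module's parameters from the accelerate hook's weights map) with disable_offload (which clears the hook's offload flag), so a group of modules stays resident across repeated forward calls instead of being onloaded and offloaded once per call. Below is a minimal runnable sketch of the intended usage, assuming accelerate is installed; the module names and sizes are illustrative, not part of this commit:

import torch
from accelerate.hooks import attach_align_device_hook
from compressed_tensors.utils import align_modules

m0, m1 = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)
model = torch.nn.Sequential(m0, m1)

# Offload all weights: parameters are replaced with meta tensors, and the
# real values are kept in the hook's weights map for per-forward onloading
attach_align_device_hook(
    model,
    execution_device=torch.device("cpu"),
    offload=True,
    weights_map=model.state_dict(),
)

x = torch.randn(1, 4)
with align_modules((m0, m1), execution_device=torch.device("cpu")):
    # Both modules are onloaded once and stay resident for every call;
    # the hook's redundant per-forward onload/offload is suppressed
    for _ in range(3):
        x = model(x)
# On exit, parameters are offloaded back to the meta device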

tests/test_utils/test_offload.py

Lines changed: 30 additions & 0 deletions
@@ -15,6 +15,7 @@
 import torch
 from compressed_tensors.utils import (
     align_module_device,
+    align_modules,
     delete_offload_parameter,
     disable_hf_hook,
     get_execution_device,
@@ -248,6 +249,35 @@ def test_disable_hf_hook_model_recurse():
     assert hasattr(module2, "_hf_hook")


+@requires_accelerate()
+def test_align_modules():
+    from accelerate.hooks import attach_align_device_hook
+
+    module0 = ExampleModule()
+    module1 = ExampleModule()
+    module2 = ExampleModule()
+    model = torch.nn.Sequential(module0, torch.nn.Sequential(module1, module2))
+    attach_align_device_hook(
+        model,
+        execution_device=torch.device("cpu"),
+        offload=True,
+        weights_map=model.state_dict(),
+    )
+
+    assert module0.a.device == torch.device("meta")
+    assert module1.a.device == torch.device("meta")
+    assert module2.a.device == torch.device("meta")
+
+    with align_modules((module0, module1)):
+        assert module0.a.device != torch.device("meta")
+        assert module1.a.device != torch.device("meta")
+        assert module2.a.device == torch.device("meta")
+
+    assert module0.a.device == torch.device("meta")
+    assert module1.a.device == torch.device("meta")
+    assert module2.a.device == torch.device("meta")
+
+
 @requires_accelerate()
 def test_offload_to_weights_map():
     from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset
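For reference, ExampleModule is defined earlier in tests/test_utils/test_offload.py and is not shown in this diff. A hypothetical stand-in consistent with the assertions above (a module exposing a parameter named `a`) could look like:

import torch

class ExampleModule(torch.nn.Module):
    # Hypothetical stand-in: the real definition lives earlier in the
    # test file and is not part of this diff
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.zeros(4))

    def forward(self, x):
        return x + self.a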
