Skip to content

Commit 57d171a

Browse files
committed
add delete_offload_module
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent e32d5b5 commit 57d171a

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

src/compressed_tensors/utils/offload.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
"align_modules",
7878
"align_module_device",
7979
"register_offload_module",
80+
"delete_offload_module",
8081
"force_cpu_offload",
8182
]
8283

@@ -398,7 +399,6 @@ def align_modules(
398399
yield
399400

400401

401-
@check_accelerate(fallback=None)
402402
def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
403403
"""
404404
Register a submodule with offloading if the parent module is offloaded
@@ -459,6 +459,20 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
459459
# online way, assume that all pointers are shared. This comes at no runtime cost
460460

461461

462+
def delete_offload_module(base: torch.nn.Module, name: str):
    """
    Delete a submodule from a model which may contain offloading.

    Any offloaded parameters owned by the submodule are removed first so the
    accelerate hooks' weights maps stay in sync before the module is detached.

    :param base: parent module to delete submodule from
    :param name: name of submodule on parent
    """
    module: torch.nn.Module = getattr(base, name)

    # NOTE(review): named_parameters() recurses by default and yields dotted
    # names such as "sub.weight"; assuming delete_offload_parameter deletes
    # the attribute on the module it is given, a dotted name would fail.
    # Walk each (sub)module and delete only its directly-owned parameters —
    # identical behavior for flat modules, correct for nested ones.
    for submodule in module.modules():
        for param_name, _ in list(submodule.named_parameters(recurse=False)):
            delete_offload_parameter(submodule, param_name)

    delattr(base, name)
462476
@check_accelerate(fallback="error")
463477
def force_cpu_offload(
464478
module: torch.nn.Module, execution_device: torch.device

tests/test_utils/test_offload.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from compressed_tensors.utils import (
1717
align_module_device,
1818
align_modules,
19+
delete_offload_module,
1920
delete_offload_parameter,
2021
disable_hf_hook,
2122
force_cpu_offload,
@@ -368,6 +369,32 @@ def test_register_offload_module(exec_device):
368369
child(torch.empty(2, device=exec_device))
369370

370371

372+
@requires_gpu
@requires_accelerate()
@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
def test_delete_offload_module(exec_device):
    """Deleting a registered submodule removes it from both offloaded and
    non-offloaded parents."""
    # no offloading
    model = ExampleModel()
    child = torch.nn.Linear(2, 3)
    register_offload_module(model, "child", child)
    register_offload_module(model.linear, "child", child)
    delete_offload_module(model, "child")
    delete_offload_module(model.linear, "child")
    # PEP 8 idiom: `x not in y` rather than `not x in y`
    assert child not in model.children()
    assert child not in model.linear.children()

    # with offloading
    model = ExampleModel()
    child = torch.nn.Linear(2, 3)
    force_cpu_offload(model, exec_device)
    register_offload_module(model, "child", child)
    register_offload_module(model.linear, "child", child)
    delete_offload_module(model, "child")
    delete_offload_module(model.linear, "child")
    assert child not in model.children()
    assert child not in model.linear.children()
371398
@requires_gpu
372399
@requires_accelerate()
373400
@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])

0 commit comments

Comments (0)