
Commit d77bcef

Merge branch 'kylesayrs/transform-accelerate-utilities' into kylesayrs/transform_factory

2 parents: 809e367 + 57d171a
File tree: 2 files changed (+51, -12 lines)

src/compressed_tensors/utils/offload.py (15 additions, 1 deletion)
@@ -77,6 +77,7 @@
     "align_modules",
     "align_module_device",
     "register_offload_module",
+    "delete_offload_module",
     "force_cpu_offload",
 ]

@@ -398,7 +399,6 @@ def align_modules(
     yield


-@check_accelerate(fallback=None)
 def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
     """
     Register a submodule with offloading if the parent module is offloaded

@@ -459,6 +459,20 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
     # online way, assume that all pointers are shared. This comes at no runtime cost


+def delete_offload_module(base: torch.nn.Module, name: str):
+    """
+    Delete a submodule from a model which may contain offloading
+    :param base: parent module to delete submodule from
+    :param name: name of submodule on parent
+    """
+    module: torch.nn.Module = getattr(base, name)
+
+    for param_name, _ in list(module.named_parameters()):
+        delete_offload_parameter(module, param_name)
+
+    delattr(base, name)
+
+
 @check_accelerate(fallback="error")
 def force_cpu_offload(
     module: torch.nn.Module, execution_device: torch.device
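
For orientation, a minimal usage sketch of the new helper pair (a sketch, assuming `accelerate` is installed and that `compressed_tensors.utils` exports these functions as shown in this diff; the layer sizes are arbitrary):

import torch
from compressed_tensors.utils import (
    delete_offload_module,
    force_cpu_offload,
    register_offload_module,
)

# offload a toy model: leaf parameters move to the "meta" device and the
# accelerate hook restores them on the execution device during forward
model = torch.nn.Sequential(torch.nn.Linear(4, 4))
model = force_cpu_offload(model, torch.device("cpu"))

# attach a submodule so it participates in the parent's offloading
register_offload_module(model, "child", torch.nn.Linear(4, 4))

# remove it again: each offloaded parameter copy is deleted before the
# attribute itself is dropped via delattr
delete_offload_module(model, "child")
assert not hasattr(model, "child")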

tests/test_utils/test_offload.py (36 additions, 11 deletions)
@@ -16,6 +16,7 @@
 from compressed_tensors.utils import (
     align_module_device,
     align_modules,
+    delete_offload_module,
     delete_offload_parameter,
     disable_hf_hook,
     force_cpu_offload,

@@ -344,9 +345,8 @@ def test_offload_to_weights_map():

 @requires_gpu
 @requires_accelerate()
-def test_register_offload_module():
-    execution_device = torch.device("cuda")
-
+@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
+def test_register_offload_module(exec_device):
     # no offloading
     model = ExampleModel()
     child = torch.nn.Linear(2, 3)

@@ -358,37 +358,62 @@ def test_register_offload_module():
     # with offloading
     model = ExampleModel()
     child = torch.nn.Linear(2, 3)
-    force_cpu_offload(model, execution_device)
+    force_cpu_offload(model, exec_device)
     register_offload_module(model, "child", child)
     register_offload_module(model.linear, "child", child)
     assert child in model.children()
     assert child in model.linear.children()

     # can run modules
     model(torch.empty(1))
-    child(torch.empty(2, device=execution_device))
+    child(torch.empty(2, device=exec_device))


 @requires_gpu
 @requires_accelerate()
-def test_force_cpu_offload():
-    execution_device = torch.device("cuda")
+@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
+def test_delete_offload_module(exec_device):
+    # no offloading
+    model = ExampleModel()
+    child = torch.nn.Linear(2, 3)
+    register_offload_module(model, "child", child)
+    register_offload_module(model.linear, "child", child)
+    delete_offload_module(model, "child")
+    delete_offload_module(model.linear, "child")
+    assert not child in model.children()
+    assert not child in model.linear.children()

+    # with offloading
+    model = ExampleModel()
+    child = torch.nn.Linear(2, 3)
+    force_cpu_offload(model, exec_device)
+    register_offload_module(model, "child", child)
+    register_offload_module(model.linear, "child", child)
+    delete_offload_module(model, "child")
+    delete_offload_module(model.linear, "child")
+    assert not child in model.children()
+    assert not child in model.linear.children()
+
+
+@requires_gpu
+@requires_accelerate()
+@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
+def test_force_cpu_offload(exec_device):
     # single module
     module = torch.nn.Linear(1, 2)
-    module = force_cpu_offload(module, execution_device)
+    module = force_cpu_offload(module, exec_device)
     assert has_offloaded_params(module)
     assert module._hf_hook.offload
     assert module.weight.device == torch.device("meta")
     assert "weight" in module._hf_hook.weights_map
     assert module._hf_hook.tied_params_map is not None

     # can run
-    module(torch.empty(1, device=execution_device))
+    module(torch.empty(1, device=exec_device))

     # model
     model = ExampleModel()
-    model = force_cpu_offload(model, execution_device)
+    model = force_cpu_offload(model, exec_device)
     assert not has_offloaded_params(model)

     assert has_offloaded_params(model.linear)

@@ -398,4 +423,4 @@ def test_force_cpu_offload():
     assert model.linear._hf_hook.tied_params_map is not None

     # can run
-    model(torch.empty(1, device=execution_device))
+    model(torch.empty(1, device=exec_device))
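
Side note on the test refactor: the hardcoded `execution_device = torch.device("cuda")` is replaced by `pytest.mark.parametrize`, so each test body now runs once per execution device. A standalone sketch of the same pattern (a hypothetical test, not part of this commit):

import pytest
import torch

@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
def test_runs_per_device(exec_device):
    # pytest invokes this body once per entry in the device list
    if exec_device.type == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    x = torch.empty(1, device=exec_device)
    assert x.device.type == exec_device.type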
