
Commit 54f5b4e

rename, simplify (#354)

Authored by Kyle Sayers
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

1 parent 52e7074 · commit 54f5b4e

File tree: 4 files changed (+67, -60 lines)

src/compressed_tensors/utils/offload.py

Lines changed: 49 additions & 48 deletions
@@ -14,27 +14,29 @@
 """
 Utilities associated with offloading functionality provided by `accelerate`.

-| ----------------------------------------------------------------------------------------------------- | # noqa: E501
-| Operation | Without offloading support             | With offloading support                          | # noqa: E501
-| --------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501
-| Add       | module.register_parameter(name, param) | register_offload_parameter(module, name, param)  | # noqa: E501
-| Check     | N/A                                    | has_offloaded_params(module)                     | # noqa: E501
-| Onload    | N/A                                    | with align_module_device(module)                 | # noqa: E501
-| Update    | module.name.data.copy_(new_data)       | update_offload_parameter(module, name, new_data) | # noqa: E501
-| Delete    | del module.name                        | delete_offload_parameter(module, name)           | # noqa: E501
-| ----------------------------------------------------------------------------------------------------- | # noqa: E501
+| ------------------------------------------------------------------------------------------------------ | # noqa: E501
+| Operation  | Without offloading support             | With offloading support                          | # noqa: E501
+| ---------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501
+| Add        | module.register_parameter(name, param) | register_offload_parameter(module, name, param)  | # noqa: E501
+| Check      | N/A                                    | has_offloaded_params(module)                     | # noqa: E501
+| Onload     | N/A                                    | with align_module_device(module)                 | # noqa: E501
+| Update     | module.name.data.copy_(new_data)       | update_offload_parameter(module, name, new_data) | # noqa: E501
+| Delete     | del module.name                        | delete_offload_parameter(module, name)           | # noqa: E501
+| Add Module | module.register_module(name, child)    | register_offload_module(name, child)             | # noqa: E501
+| Del Module | del module.name                        | delete_offload_module(module, name)              | # noqa: E501
+| ------------------------------------------------------------------------------------------------------ | # noqa: E501
 """

 import contextlib
 import warnings
 from functools import wraps
-from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Union
+from operator import attrgetter
+from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union

 import torch


 try:
-    from accelerate import dispatch_model
     from accelerate.hooks import (
         AlignDevicesHook,
         add_hook_to_module,
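The docstring table above is the main documentation change: it gains `Add Module` and `Del Module` rows for the `register_offload_module` and `delete_offload_module` helpers. As a rough sketch of how the offloading-aware column is intended to be used (only the helper names and signatures come from the table; the `torch.nn.Linear` setup is illustrative):

    import torch
    from compressed_tensors.utils import (
        align_module_device,
        delete_offload_parameter,
        has_offloaded_params,
        register_offload_parameter,
        update_offload_parameter,
    )

    module = torch.nn.Linear(4, 4)
    scale = torch.nn.Parameter(torch.ones(4))

    # these helpers work whether or not accelerate hooks are attached
    register_offload_parameter(module, "scale", scale)
    update_offload_parameter(module, "scale", torch.full((4,), 2.0))

    if has_offloaded_params(module):
        # temporarily onload the module's parameters for execution
        with align_module_device(module):
            y = module(torch.randn(1, 4))

    delete_offload_parameter(module, "scale")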
@@ -45,10 +47,12 @@
     from accelerate.utils import (
         OffloadedWeightsLoader,
         PrefixedDataset,
+        find_tied_parameters,
         set_module_tensor_to_device,
     )

     _has_accelerate = True
+
 except ImportError:
     _has_accelerate = False
     AlignDevicesHook = None
@@ -58,8 +62,8 @@
     PrefixedDataset = None
     set_module_tensor_to_device = None
     named_module_tensors = None
-    dispatch_model = None
     attach_align_device_hook = None
+    find_tied_parameters = None


 __all__ = [
@@ -78,14 +82,13 @@
     "align_module_device",
     "register_offload_module",
     "delete_offload_module",
-    "force_cpu_offload",
+    "offloaded_dispatch",
 ]


 def check_accelerate(fallback: Any):
     def decorator(func: Callable[[Any], Any]):
         if not _has_accelerate:
-
             if fallback == "error":

                 @wraps(func)
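The `check_accelerate` hunk only removes a stray blank line, and the diff shows just the top of the decorator. For orientation, a minimal sketch of the guard pattern visible in the context lines (the fallback bodies below are an assumption reconstructed from that context, not the file's exact code):

    from functools import wraps
    from typing import Any, Callable

    _has_accelerate = False  # set by the try/except import guard above

    def check_accelerate(fallback: Any):
        def decorator(func: Callable[[Any], Any]):
            if not _has_accelerate:
                if fallback == "error":

                    @wraps(func)
                    def fallback_fn(*args, **kwargs):
                        raise ValueError(
                            f"{func.__name__} requires `accelerate` to be installed"
                        )

                else:

                    @wraps(func)
                    def fallback_fn(*args, **kwargs):
                        return fallback

                return fallback_fn
            return func
        return decorator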
@@ -479,46 +482,44 @@ def delete_offload_module(base: torch.nn.Module, name: str):


 @check_accelerate(fallback="error")
-def force_cpu_offload(
-    module: torch.nn.Module, execution_device: torch.device
+def offloaded_dispatch(
+    module: torch.nn.Module,
+    execution_device: torch.device,
+    offload_device: Union[torch.device, Literal["disk"]] = torch.device("cpu"),
 ) -> torch.nn.Module:
     """
-    Force cpu offloading a module, primarily used for testing
+    Unlike `dispatch_model`, this function forces a module (and its submodules) to
+    offload all parameters and replace them with meta tensors, utilizing the
+    `AlignDevicesHook` to control onloading and offloading.

     :param module: module containing parameters to offload
-    :param execution_device: execution device submodules
-    :return: module with hooks to perform cpu offloading
+    :param execution_device: device that modules will onload and execute on
+    :param offload_device: device that module parameters will offload to
+    :return: module with offloading device hooks
     """
-    # edge case: there is a bug in `dispatch_model` which causes
-    # the function to only work if the model contains submodules
-    if next(module.children(), None) is None:
-        attach_align_device_hook(
-            module,
-            execution_device=execution_device,
-            offload=True,
-            weights_map=module.state_dict(),
-            tied_params_map={},
-        )
-        return module
-
-    device_map = {}
-
-    def collect_device_map(name: List[str], module: torch.nn.Module):
-        if next(module.parameters(recurse=False), None) is not None:
-            device_map[".".join(name)] = "cpu"
-            return
-
-        else:
-            for submodule_name, submodule in module.named_children():
-                name.append(submodule_name)
-                collect_device_map(name, submodule)
-                name.pop()
-
-    collect_device_map([], module)
-
-    return dispatch_model(
-        module, device_map, main_device=execution_device, force_hooks=True
+    if offload_device == "disk":
+        raise NotImplementedError("Disk offloading is not currently supported")
+
+    # create weights map which backs offloaded parameters with CPU tensors
+    weights_map = OffloadedWeightsLoader(state_dict=module.state_dict(), device="cpu")
+
+    # create tied params map, keyed by the data pointer of each tied parameter
+    tied_params = find_tied_parameters(module)
+    tied_params_map = {}
+    for group in tied_params:
+        for param_name in group:
+            data_ptr = attrgetter(param_name)(module).data_ptr()
+            tied_params_map[data_ptr] = {}
+
+    # recursively attach hooks to all submodules
+    attach_align_device_hook(
+        module,
+        execution_device=execution_device,
+        offload=True,
+        weights_map=weights_map,
+        tied_params_map=tied_params_map,
     )
+    return module


 """ Upstreamed Functions """

tests/test_transform/factory/test_correctness.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
     TransformFactory,
     TransformScheme,
 )
-from compressed_tensors.utils import align_modules, force_cpu_offload
+from compressed_tensors.utils import offloaded_dispatch
 from tests.testing_utils import requires_accelerate, requires_gpu


@@ -75,7 +75,7 @@ def test_correctness_model(scheme, offload=False):
     # load model
     model = TransformableModel(2, 4, 8, 16, 32, 64)
     if offload:
-        model = force_cpu_offload(model, torch.device("cuda"))
+        model = offloaded_dispatch(model, torch.device("cuda"))

     # create factory
     scheme.apply = [

tests/test_transform/factory/test_memory.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
     TransformFactory,
     TransformScheme,
 )
-from compressed_tensors.utils import align_modules, force_cpu_offload
+from compressed_tensors.utils import align_modules, offloaded_dispatch
 from tests.testing_utils import requires_accelerate, requires_gpu


@@ -58,7 +58,7 @@ def test_memory_sharing(scheme, offload=False):
     # load model (maybe with offloading)
     model = TransformableModel(2, 2, 4, 4, 8, 8)
     if offload:
-        force_cpu_offload(model, torch.device("cuda"))
+        offloaded_dispatch(model, torch.device("cuda"))

     # add transforms to model
     factory.apply_to_model(model)

tests/test_utils/test_offload.py

Lines changed: 14 additions & 8 deletions
@@ -19,9 +19,9 @@
     delete_offload_module,
     delete_offload_parameter,
     disable_hf_hook,
-    force_cpu_offload,
     get_execution_device,
     has_offloaded_params,
+    offloaded_dispatch,
     register_offload_module,
     register_offload_parameter,
     update_offload_parameter,
@@ -111,15 +111,15 @@ def test_register_offload_parameter():

     # register a param prior to offloading
     register_offload_parameter(module, "c", parameter)
-    assert hasattr(module, "c") and module.c == parameter
+    assert module.c == parameter

     # offloading, check that added param was offloaded
     attach_align_device_hook(module, offload=True, weights_map=module.state_dict())
     assert "c" in module._hf_hook.weights_map

     # register a param after offloading, check that added param was offloaded
     register_offload_parameter(module, "d", parameter)
-    assert hasattr(module, "d") and module.d.device == torch.device("meta")
+    assert module.d.device == torch.device("meta")
     assert module._hf_hook.weights_map["d"].device == torch.device("cpu")

     # added parameters can be onloaded and offloaded
@@ -358,7 +358,7 @@ def test_register_offload_module(exec_device):
     # with offloading
     model = ExampleModel()
     child = torch.nn.Linear(2, 3)
-    force_cpu_offload(model, exec_device)
+    offloaded_dispatch(model, exec_device)
     register_offload_module(model, "child", child)
     register_offload_module(model.linear, "child", child)
     assert child in model.children()
@@ -386,7 +386,7 @@ def test_delete_offload_module(exec_device):
     # with offloading
     model = ExampleModel()
     child = torch.nn.Linear(2, 3)
-    force_cpu_offload(model, exec_device)
+    offloaded_dispatch(model, exec_device)
     register_offload_module(model, "child", child)
     register_offload_module(model.linear, "child", child)
     delete_offload_module(model, "child")
@@ -398,10 +398,10 @@
 @requires_gpu
 @requires_accelerate()
 @pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
-def test_force_cpu_offload(exec_device):
+def test_offloaded_dispatch(exec_device):
     # single module
     module = torch.nn.Linear(1, 2)
-    module = force_cpu_offload(module, exec_device)
+    module = offloaded_dispatch(module, exec_device)
     assert has_offloaded_params(module)
     assert module._hf_hook.offload
     assert module.weight.device == torch.device("meta")
@@ -413,7 +413,7 @@ def test_force_cpu_offload(exec_device):

     # model
     model = ExampleModel()
-    model = force_cpu_offload(model, exec_device)
+    model = offloaded_dispatch(model, exec_device)
     assert not has_offloaded_params(model)

     assert has_offloaded_params(model.linear)
@@ -424,3 +424,9 @@ def test_force_cpu_offload(exec_device):
     # can run
     model(torch.empty(1, device=exec_device))
+
+    # can add new params
+    parameter = torch.nn.Parameter(torch.tensor(1.0))
+    register_offload_parameter(module, "new_param", parameter)
+    assert module.new_param.device == torch.device("meta")
+    assert module._hf_hook.weights_map["new_param"].device == torch.device("cpu")
