Commit 955f2f5

Merge
1 parent e7f08e1 commit 955f2f5

File tree

5 files changed: +278 -82 lines changed

.github/workflows/report.yml

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ jobs:
       shell: bash

     - name: report to reportportal
-      uses: neuralmagic/nm-actions/actions/reportportal_submit_execution_results@v1.15.0
+      uses: neuralmagic/nm-actions/actions/reportportal_submit_execution_results@v1.22.0
       with:
         droute_username: ${{ secrets.DROUTE_USERNAME }}
         droute_password: ${{ secrets.DROUTE_PASSWORD }}

src/compressed_tensors/utils/offload.py

Lines changed: 147 additions & 65 deletions
@@ -14,27 +14,30 @@
 """
 Utilities associated with offloading functionality provided by `accelerate`.

-| ----------------------------------------------------------------------------------------------------- | # noqa: E501
-| Operation | Without offloading support             | With offloading support                          | # noqa: E501
-| --------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501
-| Add       | module.register_parameter(name, param) | register_offload_parameter(module, name, param)  | # noqa: E501
-| Check     | N/A                                    | has_offloaded_params(module)                     | # noqa: E501
-| Onload    | N/A                                    | with align_module_device(module)                 | # noqa: E501
-| Update    | module.name.data.copy_(new_data)       | update_offload_parameter(module, name, new_data) | # noqa: E501
-| Delete    | del module.name                        | delete_offload_parameter(module, name)           | # noqa: E501
-| ----------------------------------------------------------------------------------------------------- | # noqa: E501
+| ------------------------------------------------------------------------------------------------------ | # noqa: E501
+| Operation  | Without offloading support             | With offloading support                          | # noqa: E501
+| ---------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501
+| Add        | module.register_parameter(name, param) | register_offload_parameter(module, name, param)  | # noqa: E501
+| Check      | N/A                                    | has_offloaded_params(module)                     | # noqa: E501
+| Onload     | N/A                                    | with align_module_device(module)                 | # noqa: E501
+| Update     | module.name.data.copy_(new_data)       | update_offload_parameter(module, name, new_data) | # noqa: E501
+| Delete     | del module.name                        | delete_offload_parameter(module, name)           | # noqa: E501
+| Add Module | module.register_module(name, child)    | register_offload_module(name, child)             | # noqa: E501
+| Del Module | del module.name                        | delete_offload_module(module, name)              | # noqa: E501
+| ------------------------------------------------------------------------------------------------------ | # noqa: E501
 """

 import contextlib
 import warnings
 from functools import wraps
-from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Union
+from operator import attrgetter
+from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple, Union

 import torch
+from compressed_tensors.utils import patch_attr


 try:
-    from accelerate import dispatch_model
     from accelerate.hooks import (
         AlignDevicesHook,
         add_hook_to_module,
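For orientation, the table above corresponds to a workflow like the following minimal sketch (assumes accelerate is installed; the `Linear` module and `scale` parameter are illustrative only, not part of this commit):

    import torch
    from compressed_tensors.utils import (
        align_module_device,
        delete_offload_parameter,
        has_offloaded_params,
        register_offload_parameter,
        update_offload_parameter,
    )

    linear = torch.nn.Linear(4, 4)

    # Add: register a new parameter, keeping any offloaded copy in sync
    register_offload_parameter(linear, "scale", torch.nn.Parameter(torch.ones(4)))

    # Check + Onload: detect offloading and temporarily onload for compute
    if has_offloaded_params(linear):
        with align_module_device(linear):
            _ = linear(torch.rand(1, 4))

    # Update + Delete: write through to the offloaded copy, then remove
    update_offload_parameter(linear, "scale", torch.full((4,), 2.0))
    delete_offload_parameter(linear, "scale")
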
@@ -45,10 +48,12 @@
     from accelerate.utils import (
         OffloadedWeightsLoader,
         PrefixedDataset,
+        find_tied_parameters,
         set_module_tensor_to_device,
     )

     _has_accelerate = True
+
 except ImportError:
     _has_accelerate = False
     AlignDevicesHook = None
@@ -58,8 +63,8 @@
     PrefixedDataset = None
     set_module_tensor_to_device = None
     named_module_tensors = None
-    dispatch_model = None
     attach_align_device_hook = None
+    find_tied_parameters = None


 __all__ = [
@@ -78,14 +83,15 @@
     "align_module_device",
     "register_offload_module",
     "delete_offload_module",
-    "force_cpu_offload",
+    "offloaded_dispatch",
+    "disable_offloading",
+    "remove_dispatch",
 ]


 def check_accelerate(fallback: Any):
     def decorator(func: Callable[[Any], Any]):
         if not _has_accelerate:
-
             if fallback == "error":

                 @wraps(func)
@@ -165,22 +171,22 @@ def update_parameter_data(

 def get_execution_device(module: torch.nn.Module) -> torch.device:
     """
-    Get the device which inputs should be moved to before module execution
+    Get the device which inputs should be moved to before module execution.
+    Assume that modules execute in the same order as returned by `model.modules()`

     :param module: module to check, may be offloaded
     :return: onload device of module
     """
-    if has_offloaded_params(module):
-        return module._hf_hook.execution_device
+    for submodule in module.modules():
+        if has_offloaded_params(submodule):
+            return submodule._hf_hook.execution_device

-    first_param = next(module.parameters(), None)
-    if first_param is None:
-        warnings.warn(
-            f"Unable able to infer execution device of {module}, falling back to CPU"
-        )
-        return torch.device("cpu")
+        param = next(submodule.parameters(recurse=False), None)
+        if param is not None:
+            return param.device

-    return first_param.device
+    warnings.warn(f"Unable to get execution device of {module}, falling back to CPU")
+    return torch.device("cpu")


 def register_offload_parameter(
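The effect of this rewrite, sketched below: the new implementation walks submodules in `model.modules()` order, so a dispatched model whose parameters all live on the meta device still reports the hook's execution device. The sketch assumes accelerate, a CUDA device, and that these helpers are re-exported from `compressed_tensors.utils` (as the tests in this commit do); the two-layer model is illustrative.

    import torch
    from compressed_tensors.utils import get_execution_device, offloaded_dispatch

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))

    # no hooks anywhere: falls through to the first materialized parameter
    assert get_execution_device(model) == torch.device("cpu")

    # after dispatch, parameters are meta tensors, but the first offloaded
    # submodule's AlignDevicesHook carries the execution device
    offloaded_dispatch(model, execution_device=torch.device("cuda:0"))
    assert get_execution_device(model) == torch.device("cuda:0")
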
@@ -201,17 +207,32 @@ def register_offload_parameter(
     has_onload = any(p.device != torch.device("meta") for p in module.parameters())
     module.register_parameter(name, parameter)

+    # do everything AlignDevicesHook.init_hook does
+    # https://github.com/huggingface/accelerate/blob/main/src/accelerate/hooks.py#L281
     if has_offloaded_params(module):
-        weights_map = module._hf_hook.weights_map
-        offload_to_weights_map(weights_map, name, parameter.data, offload_device)
+        hook: AlignDevicesHook = module._hf_hook
+        assert hook.weights_map is not None
+
+        # append to original_devices
+        hook.original_devices[name] = parameter.device
+
+        # append to weights map
+        offload_to_weights_map(hook.weights_map, name, parameter.data, offload_device)
+
+        # append to tied_params_map
+        offloaded = hook.weights_map[name]
+        if hook.tied_params_map is not None:
+            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+
+        # perform offloading
         if not has_onload:
             set_module_tensor_to_device(module, name, "meta")


 def update_offload_parameter(
     module: torch.nn.Module,
     name: str,
-    data: Optional[torch.Tensor],
+    data: torch.Tensor,
     offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
 ):
     """
@@ -224,15 +245,15 @@ def update_offload_parameter(
     :param offload_device: device on which weight will be offloaded to. If None is
         provided, then infer device from parameters on module
     """
-    param = getattr(module, name)
+    param: torch.nn.Parameter = getattr(module, name)
     if param.data.shape != data.shape:
         warnings.warn(
             f"Shape of parameter being updated {param.data.shape} does not match shape "
             f"of update data {data.shape}"
         )

     # copy data into onloaded parameter if applicable
-    if param.device != torch.device("meta"):
+    if param.device != torch.device("meta") and data is not param.data:
         param.data.copy_(data)

     # update offload dict
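The narrowed annotation (`data: torch.Tensor`) matches how the function is actually called, and the new `data is not param.data` guard makes it safe to pass a parameter's own tensor, which the teardown of `disable_offloading` below relies on. A small sketch, runnable without any dispatch:

    import torch
    from compressed_tensors.utils import update_offload_parameter

    module = torch.nn.Linear(4, 4)

    # passing the parameter's own tensor no longer triggers a redundant
    # self-copy; only the offload dict (when present) would be refreshed
    update_offload_parameter(module, "weight", module.weight.data)

    # fresh data is still copied into the onloaded parameter
    update_offload_parameter(module, "weight", torch.zeros(4, 4))
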
@@ -417,7 +438,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
         hook: AlignDevicesHook = base._hf_hook
         assert hook.offload
         assert hook.weights_map is not None
-        assert hook.tied_params_map is not None

         # offloading kwargs for submodule
         place_submodules = False
@@ -432,7 +452,8 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
             module, include_buffers=offload_buffers, recurse=place_submodules
         ):
             offloaded = param.to(offload_device)
-            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+            if hook.tied_params_map is not None:
+                hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
             offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)

         # if the parent places submodules, offload here
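For context, the caller-facing behavior being adjusted here: `register_offload_module` folds a new child's parameters into the parent's weights map and, when a tied-params map exists, conservatively marks each offloaded pointer as potentially tied, per note (1) at the end of the file. An illustrative sketch (assumes accelerate and a CUDA device; `Block` and `head` are made up for the example):

    import torch
    from compressed_tensors.utils import offloaded_dispatch, register_offload_module

    class Block(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Parameter(torch.rand(4, 4))

    parent = Block()
    offloaded_dispatch(parent, execution_device=torch.device("cuda:0"))

    # the child's weights are offloaded under the "head" prefix and managed
    # by the same AlignDevicesHook machinery as the parent's parameters
    register_offload_module(parent, "head", torch.nn.Linear(4, 2))
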
@@ -460,9 +481,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M

     base.register_module(name, module)

-    # (1): Since we cannot know which pointers are shared when we add parameters in an
-    # online way, assume that all pointers are shared. This comes at no runtime cost
-

 def delete_offload_module(base: torch.nn.Module, name: str):
     """
@@ -479,46 +497,106 @@ def delete_offload_module(base: torch.nn.Module, name: str):


 @check_accelerate(fallback="error")
-def force_cpu_offload(
-    module: torch.nn.Module, execution_device: torch.device
+def offloaded_dispatch(
+    module: torch.nn.Module,
+    execution_device: torch.device,
+    offload_device: Union[torch.device, Literal["disk"]] = torch.device("cpu"),
 ) -> torch.nn.Module:
     """
-    Force cpu offloading a module, primarily used for testing
+    Unlike `dispatch_model`, this function forces a module (and its submodules) to
+    offload all parameters and replace them with meta tensors, utilizing the
+    `AlignDevicesHook` to control onloading and offloading.

     :param module: module containing parameters to offload
-    :param execution_device: execution device submodules
-    :return: module with hooks to perform cpu offloading
-    """
-    # edge case: there is a bug in `dispatch_model` which causes
-    # the function to only work if the model contains submodules
-    if next(module.children(), None) is None:
-        attach_align_device_hook(
-            module,
-            execution_device=execution_device,
-            offload=True,
-            weights_map=module.state_dict(),
-            tied_params_map={},
-        )
-        return module
+    :param execution_device: device that modules will onload and execute on
+    :param offload_device: device that module parameters will offload to
+    :return: module with offloading device hooks
+    """
+    if offload_device == "disk":
+        raise NotImplementedError("Disk offloading is not currently supported")
+
+    # remove any existing hooks
+    remove_dispatch(module)
+
+    # create weights map
+    state_dict = module.state_dict()
+    state_dict = {key: val.to(offload_device) for key, val in state_dict.items()}
+    weights_map = OffloadedWeightsLoader(state_dict=state_dict, device=offload_device)
+
+    # create tied params map
+    tied_params = find_tied_parameters(module)
+    tied_params_map = {}
+    for group in tied_params:
+        for param_name in group:
+            data_ptr = attrgetter(param_name)(module).data_ptr()
+            tied_params_map[data_ptr] = {}
+
+    # recursively attaches hooks to all submodules
+    attach_align_device_hook(
+        module,
+        execution_device=execution_device,
+        offload=True,
+        weights_map=weights_map,
+        tied_params_map=tied_params_map,
+    )

-    device_map = {}
+    # when saving a model, `PretrainedModel.save_pretrained` will only
+    # onload weights if the following requirements are met
+    # if (
+    #     hasattr(self, "hf_device_map")
+    #     and len(set(self.hf_device_map.values())) > 1
+    #     and ("cpu" in self.hf_device_map.values()
+    #          or "disk" in self.hf_device_map.values())
+    # ):
+    # because this function always offloads, disregard actual devices and
+    # always use `cpu` and `cuda:0` to guarantee this condition passes
+    setattr(module, "hf_device_map", {"fake_offload": "cpu", "fake_exec": "cuda:0"})

-    def collect_device_map(name: List[str], module: torch.nn.Module):
-        if next(module.parameters(recurse=False), None) is not None:
-            device_map[".".join(name)] = "cpu"
-            return
+    return module

-        else:
-            for submodule_name, submodule in module.named_children():
-                name.append(submodule_name)
-                collect_device_map(name, submodule)
-                name.pop()

-    collect_device_map([], module)
+def remove_dispatch(module: torch.nn.Module) -> torch.nn.Module:
+    """
+    Remove any existing dispatches from module
+
+    :param module: module which may be dispatched with hf hooks
+    :return: module without dispatch
+    """
+    remove_hook_from_module(module, recurse=True)
+    if hasattr(module, "hf_device_map"):
+        delattr(module, "hf_device_map")
+
+    return module

-    return dispatch_model(
-        module, device_map, main_device=execution_device, force_hooks=True
-    )
+
+@contextlib.contextmanager
+def disable_offloading():
+    """
+    Keep modules onloaded and disable offloading until this context exits.
+    Affects modules which have been hooked with accelerate's `AlignDevicesHook`
+    """
+    original_pre_forward = AlignDevicesHook.pre_forward
+    onloaded_modules: Dict[torch.nn.Module, Tuple[AlignDevicesHook, bool]] = dict()
+
+    # onload once and disable any future onloading/offloading steps
+    def keep_onload_pre_forward(self: AlignDevicesHook, module, *args, **kwargs):
+        ret = original_pre_forward(self, module, *args, **kwargs)
+        if module not in onloaded_modules:
+            onloaded_modules[module] = (self, self.offload)
+            self.offload = False
+        return ret
+
+    # use the patched pre_forward function within the context
+    with patch_attr(AlignDevicesHook, "pre_forward", keep_onload_pre_forward):
+        yield
+
+    # manually offload all modules that were onloaded
+    # update any parameters which may have changed
+    for module, (hook, offload) in onloaded_modules.items():
+        hook.offload = offload
+        for name, param in module.named_parameters(recurse=False):
+            update_offload_parameter(module, name, param.data)
+        hook.post_forward(module, None)


 """ Upstreamed Functions """
@@ -588,3 +666,7 @@ def align_module_device(

     else:
         yield
+
+
+# (1): Since we cannot know which pointers are shared when we add parameters in an
+# online way, assume that all pointers are shared. This has virtually no runtime cost
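The `# (1)` markers earlier in the diff point at this relocated note. The tied-pointer bookkeeping it describes can be observed with accelerate's `find_tied_parameters`, which `offloaded_dispatch` uses to seed `tied_params_map` (toy weight tying, illustrative only):

    import torch
    from accelerate.utils import find_tied_parameters

    model = torch.nn.ModuleDict(
        {"embed": torch.nn.Linear(4, 4), "head": torch.nn.Linear(4, 4)}
    )
    model["head"].weight = model["embed"].weight  # tie the two weights

    # groups of names sharing storage, e.g. [['embed.weight', 'head.weight']];
    # offloaded_dispatch records one tied_params_map entry per shared data_ptr
    print(find_tied_parameters(model))
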

tests/test_transform/factory/test_correctness.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@
     TransformScheme,
     apply_transform_config,
 )
-from compressed_tensors.utils import force_cpu_offload
+from compressed_tensors.utils import offloaded_dispatch
 from tests.testing_utils import requires_accelerate, requires_gpu

@@ -65,7 +65,7 @@ def test_correctness_model(scheme_kwargs, model_apply, offload=False):
     # load model
     model = model_apply[0]
     if offload:
-        model = force_cpu_offload(model, torch.device("cuda"))
+        model = offloaded_dispatch(model, torch.device("cuda"))

     # get output
     input = torch.rand((17, model.fcs[0].in_features))

tests/test_transform/factory/test_memory.py

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@
     TransformScheme,
     apply_transform_config,
 )
-from compressed_tensors.utils import align_modules, force_cpu_offload
+from compressed_tensors.utils import align_modules, offloaded_dispatch
 from tests.test_transform.conftest import TransformableModel
 from tests.testing_utils import requires_accelerate, requires_gpu

@@ -41,7 +41,7 @@ def test_memory_sharing(scheme_kwargs, offload=False):
     # load model (maybe with offloading)
     model = TransformableModel(2, 2, 4, 4, 8, 8)
     if offload:
-        force_cpu_offload(model, torch.device("cuda"))
+        offloaded_dispatch(model, torch.device("cuda"))

     # add transforms to model
     config = TransformConfig(
