|
14 | 14 | """
|
15 | 15 | Utilities associated with offloading functionality provided by `accelerate`.
|
16 | 16 |
|
17 |
| -| ----------------------------------------------------------------------------------------------------- | # noqa: E501 |
18 |
| -| Operation | Without offloading support | With offloading support | # noqa: E501 |
19 |
| -| --------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501 |
20 |
| -| Add | module.register_parameter(name, param) | register_offload_parameter(module, name, param) | # noqa: E501 |
21 |
| -| Check | N/A | has_offloaded_params(module) | # noqa: E501 |
22 |
| -| Onload | N/A | with align_module_device(module) | # noqa: E501 |
23 |
| -| Update | module.name.data.copy_(new_data) | update_offload_parameter(module, name, new_data) | # noqa: E501 |
24 |
| -| Delete | del module.name | delete_offload_parameter(module, name) | # noqa: E501 |
25 |
| -| ----------------------------------------------------------------------------------------------------- | # noqa: E501 |
| 17 | +| ------------------------------------------------------------------------------------------------------ | # noqa: E501 |
| 18 | +| Operation | Without offloading support | With offloading support | # noqa: E501 |
| 19 | +| ---------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501 |
| 20 | +| Add | module.register_parameter(name, param) | register_offload_parameter(module, name, param) | # noqa: E501 |
| 21 | +| Check | N/A | has_offloaded_params(module) | # noqa: E501 |
| 22 | +| Onload | N/A | with align_module_device(module) | # noqa: E501 |
| 23 | +| Update | module.name.data.copy_(new_data) | update_offload_parameter(module, name, new_data) | # noqa: E501 |
| 24 | +| Delete | del module.name | delete_offload_parameter(module, name) | # noqa: E501 |
| 25 | +| Add Module | module.register_module(name, child) | register_offload_module(name, child) | # noqa: E501 |
| 26 | +| Del Module | del module.name | delete_offload_module(module, name) | # noqa: E501 |
| 27 | +| ------------------------------------------------------------------------------------------------------ | # noqa: E501 |
26 | 28 | """
|
27 | 29 |
|
28 | 30 | import contextlib
|
29 | 31 | import warnings
|
30 | 32 | from functools import wraps
|
31 |
| -from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Union |
| 33 | +from operator import attrgetter |
| 34 | +from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union |
32 | 35 |
|
33 | 36 | import torch
|
34 | 37 |
|
35 | 38 |
|
36 | 39 | try:
|
37 |
| - from accelerate import dispatch_model |
38 | 40 | from accelerate.hooks import (
|
39 | 41 | AlignDevicesHook,
|
40 | 42 | add_hook_to_module,
|
|
45 | 47 | from accelerate.utils import (
|
46 | 48 | OffloadedWeightsLoader,
|
47 | 49 | PrefixedDataset,
|
| 50 | + find_tied_parameters, |
48 | 51 | set_module_tensor_to_device,
|
49 | 52 | )
|
50 | 53 |
|
51 | 54 | _has_accelerate = True
|
| 55 | + |
52 | 56 | except ImportError:
|
53 | 57 | _has_accelerate = False
|
54 | 58 | AlignDevicesHook = None
|
|
58 | 62 | PrefixedDataset = None
|
59 | 63 | set_module_tensor_to_device = None
|
60 | 64 | named_module_tensors = None
|
61 |
| - dispatch_model = None |
62 | 65 | attach_align_device_hook = None
|
| 66 | + find_tied_parameters = None |
63 | 67 |
|
64 | 68 |
|
65 | 69 | __all__ = [
|
|
78 | 82 | "align_module_device",
|
79 | 83 | "register_offload_module",
|
80 | 84 | "delete_offload_module",
|
81 |
| - "force_cpu_offload", |
| 85 | + "offloaded_dispatch", |
82 | 86 | ]
|
83 | 87 |
|
84 | 88 |
|
85 | 89 | def check_accelerate(fallback: Any):
|
86 | 90 | def decorator(func: Callable[[Any], Any]):
|
87 | 91 | if not _has_accelerate:
|
88 |
| - |
89 | 92 | if fallback == "error":
|
90 | 93 |
|
91 | 94 | @wraps(func)
|
@@ -479,46 +482,44 @@ def delete_offload_module(base: torch.nn.Module, name: str):
|
479 | 482 |
|
480 | 483 |
|
481 | 484 | @check_accelerate(fallback="error")
|
482 |
| -def force_cpu_offload( |
483 |
| - module: torch.nn.Module, execution_device: torch.device |
| 485 | +def offloaded_dispatch( |
| 486 | + module: torch.nn.Module, |
| 487 | + execution_device: torch.device, |
| 488 | + offload_device: Union[torch.device, Literal["disk"]] = torch.device("cpu"), |
484 | 489 | ) -> torch.nn.Module:
|
485 | 490 | """
|
486 |
| - Force cpu offloading a module, primarily used for testing |
| 491 | + Unlike `dispatch_model`, this function forces a module (and its submodules) to |
| 492 | + offload all parameters and replace them with meta tensors, utiliizing the |
| 493 | + `AlignDevicesHook` to control onloading and offloading. |
487 | 494 |
|
488 | 495 | :param module: module containing parameters to offload
|
489 |
| - :param execution_device: execution device submodules |
490 |
| - :return: module with hooks to perform cpu offloading |
| 496 | + :param execution_device: device that modules will onload and execute on |
| 497 | + :param offload_device: device that module parameters will offload to |
| 498 | + :return: module with offloading device hooks |
491 | 499 | """
|
492 |
| - # edge case: there is a bug in `dispatch_model` which causes |
493 |
| - # the function to only work if the model contains submodules |
494 |
| - if next(module.children(), None) is None: |
495 |
| - attach_align_device_hook( |
496 |
| - module, |
497 |
| - execution_device=execution_device, |
498 |
| - offload=True, |
499 |
| - weights_map=module.state_dict(), |
500 |
| - tied_params_map={}, |
501 |
| - ) |
502 |
| - return module |
503 |
| - |
504 |
| - device_map = {} |
505 |
| - |
506 |
| - def collect_device_map(name: List[str], module: torch.nn.Module): |
507 |
| - if next(module.parameters(recurse=False), None) is not None: |
508 |
| - device_map[".".join(name)] = "cpu" |
509 |
| - return |
510 |
| - |
511 |
| - else: |
512 |
| - for submodule_name, submodule in module.named_children(): |
513 |
| - name.append(submodule_name) |
514 |
| - collect_device_map(name, submodule) |
515 |
| - name.pop() |
516 |
| - |
517 |
| - collect_device_map([], module) |
518 |
| - |
519 |
| - return dispatch_model( |
520 |
| - module, device_map, main_device=execution_device, force_hooks=True |
| 500 | + if offload_device == "disk": |
| 501 | + raise NotImplementedError("Disk offloading is not currently supported") |
| 502 | + |
| 503 | + # create weights map |
| 504 | + weights_map = OffloadedWeightsLoader(state_dict=module.state_dict(), device="cpu") |
| 505 | + |
| 506 | + # create tied params map |
| 507 | + tied_params = find_tied_parameters(module) |
| 508 | + tied_params_map = {} |
| 509 | + for group in tied_params: |
| 510 | + for param_name in group: |
| 511 | + data_ptr = attrgetter(param_name)(module).data_ptr() |
| 512 | + tied_params_map[data_ptr] = {} |
| 513 | + |
| 514 | + # recursively attaches hooks to all submodules |
| 515 | + attach_align_device_hook( |
| 516 | + module, |
| 517 | + execution_device=execution_device, |
| 518 | + offload=True, |
| 519 | + weights_map=weights_map, |
| 520 | + tied_params_map=tied_params_map, |
521 | 521 | )
|
| 522 | + return module |
522 | 523 |
|
523 | 524 |
|
524 | 525 | """ Upstreamed Functions """
|
|
0 commit comments