|
28 | 28 | import contextlib |
29 | 29 | import warnings |
30 | 30 | from functools import wraps |
31 | | -from typing import Any, Callable, Dict, Literal, Optional, Union |
| 31 | +from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union |
32 | 32 |
|
33 | 33 | import torch |
34 | 34 |
|
|
67 | 67 | "delete_offload_parameter", |
68 | 68 | "has_offloaded_params", |
69 | 69 | "disable_hf_hook", |
| 70 | + "disable_offload", |
| 71 | + "align_modules", |
70 | 72 | "align_module_device", |
71 | 73 | ] |
72 | 74 |
|
@@ -344,6 +346,43 @@ def delete_from_weights_map( |
344 | 346 | ) |
345 | 347 |
|
346 | 348 |
|
@contextlib.contextmanager
def disable_offload(module: torch.nn.Module):
    """
    Context manager to disable module onloading and offloading. Parameters will stay on
    their current device

    The hook's ``offload`` flag is restored to the value it had on entry when the
    context exits, including when the wrapped code raises.

    :param module: module to disable offloading for
    """
    if not has_offloaded_params(module):
        # nothing to toggle for this module; act as a no-op context
        yield
        return

    # hold on to the hook we modify so that exactly this object is restored
    hook = module._hf_hook
    previous_offload = hook.offload
    hook.offload = False
    try:
        yield
    finally:
        # without `finally`, an exception in the body would leave offloading
        # permanently disabled for this module
        hook.offload = previous_offload
| 363 | + |
| 364 | + |
@contextlib.contextmanager
def align_modules(
    modules: Union[torch.nn.Module, Iterable[torch.nn.Module]],
    execution_device: Optional[torch.device] = None,
):
    """
    Context manager which onloads one or more modules to a device and, while active,
    disables the onload and offload attempts triggered by forward calls. Used for
    sequential onloading of layers

    :param modules: `torch.nn.Module` or iterable of `torch.nn.Module`s to onload
    :param execution_device: device to onload to
    """
    # accept a bare module by wrapping it in a one-element tuple
    if isinstance(modules, torch.nn.Module):
        modules = (modules,)

    with contextlib.ExitStack() as exit_stack:
        for target in modules:
            # onload the module first ...
            exit_stack.enter_context(align_module_device(target, execution_device))
            # ... then disable redundant onloading for the duration of the context
            exit_stack.enter_context(disable_offload(target))
        yield
| 384 | + |
| 385 | + |
347 | 386 | """ Upstreamed Functions """ |
348 | 387 |
|
349 | 388 |
|
|
0 commit comments