[wip][poc] make group offloading work with disk/nvme transfers #11682
Changes from all commits:
```diff
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple, Union
 
+import safetensors.torch
 import torch
 
 from ..utils import get_logger, is_accelerate_available
@@ -59,6 +61,7 @@ def __init__(
         record_stream: Optional[bool] = False,
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
+        offload_to_disk_path: Optional[str] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -72,7 +75,26 @@ def __init__(
         self.record_stream = record_stream
         self.onload_self = onload_self
         self.low_cpu_mem_usage = low_cpu_mem_usage
-        self.cpu_param_dict = self._init_cpu_param_dict()
+
+        self.offload_to_disk_path = offload_to_disk_path
+        self._is_offloaded_to_disk = False
+
+        if self.offload_to_disk_path:
+            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
+
+            all_tensors = []
+            for module in self.modules:
+                all_tensors.extend(list(module.parameters()))
+                all_tensors.extend(list(module.buffers()))
+            all_tensors.extend(self.parameters)
+            all_tensors.extend(self.buffers)
+            all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates
+
+            self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
+            self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
+            self.cpu_param_dict = {}
+        else:
+            self.cpu_param_dict = self._init_cpu_param_dict()
 
         if self.stream is None and self.record_stream:
             raise ValueError("`record_stream` cannot be True when `stream` is None.")
@@ -124,6 +146,30 @@ def onload_(self):
         context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
         current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
 
+        if self.offload_to_disk_path:
+            if self.stream is not None:
+                # Wait for previous Host->Device transfer to complete
+                self.stream.synchronize()
+
+            with context:
+                if self.stream is not None:
+                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
+                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
+                    for key, tensor_obj in self.key_to_tensor.items():
+                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
+                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                        if self.record_stream:
+                            tensor_obj.data.record_stream(current_stream)
```
Comment on lines +156 to +162:

I think a cleaner approach would be to provide a callable to …

Do we know if there would be other alternatives to this code path? If not, I think it's better as is. From skimming through the documentation of …
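To make the streamed branch easier to follow in isolation, here is a standalone sketch of the disk -> pinned CPU -> accelerator pattern it uses. This is not part of the PR: the helper name, file path, and CUDA-only assumption are illustrative only.

```python
import safetensors.torch
import torch


def onload_from_disk_async(path: str, device: torch.device, stream: torch.cuda.Stream) -> dict:
    # Read the serialized tensors into (pageable) CPU memory first.
    cpu_tensors = safetensors.torch.load_file(path, device="cpu")
    onloaded = {}
    with torch.cuda.stream(stream):
        for key, cpu_tensor in cpu_tensors.items():
            # Pinning the host tensor makes the H2D copy truly asynchronous,
            # so the transfer can overlap with compute on the default stream.
            onloaded[key] = cpu_tensor.pin_memory().to(device, non_blocking=True)
    return onloaded


# Hypothetical usage (requires a CUDA device):
# stream = torch.cuda.Stream()
# weights = onload_from_disk_async("group_0.safetensors", torch.device("cuda"), stream)
# stream.synchronize()  # make sure the copies finished before using `weights`
```

Note that `pin_memory()` makes a page-locked copy, so the group being onloaded transiently occupies roughly twice its size in host memory while the pageable copy is still alive.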
```diff
+                else:
+                    # Load directly to the target device (synchronous)
+                    onload_device = (
+                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                    )
+                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                    for key, tensor_obj in self.key_to_tensor.items():
+                        tensor_obj.data = loaded_tensors[key]
+            return
+
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
@@ -169,6 +215,26 @@ def onload_(self):
     @torch.compiler.disable()
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
+        if self.offload_to_disk_path:
+            # TODO: we can potentially optimize this code path by checking if _all_ the desired
+            # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
+            # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+            # we perform a write.
+            # Check if the file has been saved in this session or if it already exists on disk.
+            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+                tensors_to_save = {
+                    key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
+                }
+                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+            # The group is now considered offloaded to disk for the rest of the session.
+            self._is_offloaded_to_disk = True
+
+            # We do this to free up the RAM which is still holding the tensor data.
+            for tensor_obj in self.tensor_to_key.keys():
+                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
```
Review comments on the lines above:

What's the reason for this to be different from the non-disk-offload counterpart? That is, is there a reason we're not doing …

So, we first free up the memory of the accelerator with: …

However, since we're also optimizing for RAM usage (this can be made clearer through documentation, I believe), we need to free up the RAM that is still holding the tensor data. After the data has been safely written from RAM to disk, this step replaces the large data tensor in RAM with a memory-less placeholder, which allows the memory to be released.
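As a minimal illustration of that explanation (hypothetical helper, not from the PR; assumes the tensor already lives on the CPU): the data is persisted first, and only then is `.data` swapped for an uninitialized placeholder so the original storage can be freed.

```python
import safetensors.torch
import torch


def spill_tensor_to_disk(tensor: torch.Tensor, path: str, key: str = "tensor_0") -> None:
    # 1. Persist the data from RAM to a safetensors file on disk.
    safetensors.torch.save_file({key: tensor.data.contiguous()}, path)
    # 2. Replace the in-RAM storage with an uninitialized tensor of the same
    #    shape/dtype, so the original (large) storage can be released.
    tensor.data = torch.empty_like(tensor.data, device="cpu")
```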
````diff
+            return
 
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
@@ -205,13 +271,12 @@ class GroupOffloadingHook(ModelHook):
 
     _is_stateful = False
 
-    def __init__(
-        self,
-        group: ModuleGroup,
-        next_group: Optional[ModuleGroup] = None,
-    ) -> None:
+    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
         self.group = group
         self.next_group = next_group
+        # map param/buffer name -> file path
+        self.param_to_path: Dict[str, str] = {}
+        self.buffer_to_path: Dict[str, str] = {}
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
@@ -358,6 +423,7 @@ def apply_group_offloading(
     onload_device: torch.device,
     offload_device: torch.device = torch.device("cpu"),
    offload_type: str = "block_level",
+    offload_to_disk_path: Optional[str] = None,
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
@@ -401,6 +467,8 @@ def apply_group_offloading(
         offload_type (`str`, defaults to "block_level"):
             The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
             "block_level".
+        offload_to_disk_path (`str`, *optional*):
+            The path to the directory where offloaded parameters will be stored.
         num_blocks_per_group (`int`, *optional*):
             The number of blocks per group when using offload_type="block_level". This is required when using
             offload_type="block_level".
@@ -418,6 +486,8 @@ def apply_group_offloading(
             option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
             the CPU memory is a bottleneck but may counteract the benefits of using streams.
 
+    (TODO: include example with `offload_to_disk_path`)
+
     Example:
     ```python
     >>> from diffusers import CogVideoXTransformer3DModel
````
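The docstring above leaves a TODO for an `offload_to_disk_path` example. A hypothetical sketch of what such usage could look like, adapted from the existing `CogVideoXTransformer3DModel` example in the docstring (the exact API surface may still change in this WIP PR, and the offload directory path is a placeholder):

```python
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_group_offloading

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
apply_group_offloading(
    transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    offload_to_disk_path="/path/to/nvme/offload_dir",  # offloaded groups are spilled here
    use_stream=True,
)
```

Compared to plain CPU offloading, the offloaded parameter groups are written to `offload_to_disk_path` instead of staying resident in system RAM.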
```diff
@@ -458,6 +528,7 @@ def apply_group_offloading(
             num_blocks_per_group=num_blocks_per_group,
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
             non_blocking=non_blocking,
             stream=stream,
             record_stream=record_stream,
@@ -468,6 +539,7 @@ def apply_group_offloading(
             module=module,
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
             non_blocking=non_blocking,
             stream=stream,
             record_stream=record_stream,
@@ -481,6 +553,7 @@ def _apply_group_offloading_block_level(
     module: torch.nn.Module,
     num_blocks_per_group: int,
     offload_device: torch.device,
+    offload_to_disk_path: Optional[str],
     onload_device: torch.device,
     non_blocking: bool,
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
@@ -496,6 +569,7 @@ def _apply_group_offloading_block_level(
             The module to which group offloading is applied.
         offload_device (`torch.device`):
             The device to which the group of modules are offloaded. This should typically be the CPU.
+        offload_to_disk_path: TODO
         onload_device (`torch.device`):
             The device to which the group of modules are onloaded.
         non_blocking (`bool`):
@@ -535,6 +609,7 @@ def _apply_group_offloading_block_level(
             modules=current_modules,
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
             offload_leader=current_modules[-1],
             onload_leader=current_modules[0],
             non_blocking=non_blocking,
@@ -567,6 +642,7 @@ def _apply_group_offloading_block_level(
         modules=unmatched_modules,
         offload_device=offload_device,
         onload_device=onload_device,
+        offload_to_disk_path=offload_to_disk_path,
         offload_leader=module,
         onload_leader=module,
         parameters=parameters,
@@ -586,6 +662,7 @@ def _apply_group_offloading_leaf_level(
     module: torch.nn.Module,
     offload_device: torch.device,
     onload_device: torch.device,
+    offload_to_disk_path: Optional[str],
     non_blocking: bool,
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
@@ -604,6 +681,7 @@ def _apply_group_offloading_leaf_level(
             The device to which the group of modules are offloaded. This should typically be the CPU.
         onload_device (`torch.device`):
             The device to which the group of modules are onloaded.
+        offload_to_disk_path: TODO
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.
@@ -629,6 +707,7 @@ def _apply_group_offloading_leaf_level(
             modules=[submodule],
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
             offload_leader=submodule,
             onload_leader=submodule,
             non_blocking=non_blocking,
@@ -675,6 +754,7 @@ def _apply_group_offloading_leaf_level(
             onload_device=onload_device,
             offload_leader=parent_module,
             onload_leader=parent_module,
+            offload_to_disk_path=offload_to_disk_path,
             parameters=parameters,
             buffers=buffers,
             non_blocking=non_blocking,
@@ -693,6 +773,7 @@ def _apply_group_offloading_leaf_level(
             modules=[],
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
             offload_leader=module,
             onload_leader=module,
             parameters=None,
```
Comment on the `# Remove duplicates` line in `__init__`:

(nit): will there be duplicates? I cannot think of a quick example, so maybe we can remove …

There shouldn't really be. But I kept it to prevent edge cases while reading something similar.
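On the duplicates question: one concrete case where the same tensor can show up twice is weight tying, where two submodules in the same group share one parameter. A small sketch (hypothetical modules, not from this PR) showing why the `dict.fromkeys` dedup is a cheap safeguard:

```python
import torch.nn as nn

# Two linear layers that share (tie) the same weight parameter.
embed = nn.Linear(4, 8, bias=False)
head = nn.Linear(4, 8, bias=False)
head.weight = embed.weight  # weight tying

modules = [embed, head]
all_tensors = []
for module in modules:
    all_tensors.extend(list(module.parameters()))

print(len(all_tensors))                       # 2 entries collected ...
print(len(list(dict.fromkeys(all_tensors))))  # ... but only 1 unique tensor after dedup
```

Since `dict.fromkeys` preserves insertion order, the deduplication also keeps the original tensor ordering, so the generated `tensor_{i}` keys stay stable.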