
[wip][poc] make group offloading work with disk/nvme transfers #11682

Draft · wants to merge 10 commits into base: main

93 changes: 87 additions & 6 deletions src/diffusers/hooks/group_offloading.py
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple, Union

import safetensors.torch
import torch

from ..utils import get_logger, is_accelerate_available
@@ -59,6 +61,7 @@ def __init__(
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -72,7 +75,26 @@ def __init__(
self.record_stream = record_stream
self.onload_self = onload_self
self.low_cpu_mem_usage = low_cpu_mem_usage
self.cpu_param_dict = self._init_cpu_param_dict()

self.offload_to_disk_path = offload_to_disk_path
self._is_offloaded_to_disk = False

if self.offload_to_disk_path:
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")

all_tensors = []
for module in self.modules:
all_tensors.extend(list(module.parameters()))
all_tensors.extend(list(module.buffers()))
all_tensors.extend(self.parameters)
all_tensors.extend(self.buffers)
all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates
Member:

(nit): will there be duplicates? I cannot think of a quick example, so maybe we can remove

Member Author:

There shouldn't really be, but I kept it to prevent edge cases, after reading something similar elsewhere.
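
One hypothetical case where duplicates could appear is weight tying, where the same `Parameter` object is reachable from two modules; `dict.fromkeys` drops the repeat while preserving order. A minimal sketch (illustrative, not from the PR):

```python
import torch.nn as nn

embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight  # weight tying: both modules share one Parameter

all_tensors = list(embed.parameters()) + list(lm_head.parameters())
deduped = list(dict.fromkeys(all_tensors))  # identity-based, order-preserving dedup
print(len(all_tensors), len(deduped))  # 2 1
```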


self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
self.cpu_param_dict = {}
else:
self.cpu_param_dict = self._init_cpu_param_dict()

if self.stream is None and self.record_stream:
raise ValueError("`record_stream` cannot be True when `stream` is None.")
Expand Down Expand Up @@ -124,6 +146,30 @@ def onload_(self):
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None

if self.offload_to_disk_path:
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()

with context:
if self.stream is not None:
# Load to CPU, pin, and async copy to device for overlapping transfer and compute
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_cpu_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
Comment on lines +156 to +162
Member:

I think a cleaner approach would be to provide a callable to map_location (assuming we were using torch.load instead of safetensors), which could pin each tensor and move it to the device. Do we know if there is an equivalent to passing a callable with safetensors? If not, this is okay too.

Member Author:

Do we know if there would be other alternatives to this code path? If not, I think it's better as is. From skimming through the documentation of safetensors, I couldn't find any equivalent of map_location.
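
For reference, a rough sketch of the `torch.load`-based alternative discussed above, assuming a hypothetical pickled checkpoint and that storages expose `pin_memory()`/`cuda()` as in recent PyTorch (the PR itself keeps safetensors):

```python
import torch

checkpoint_path = "group_0.pt"  # hypothetical pickled counterpart of the safetensors file


def pin_then_move(storage, location):
    # torch.load invokes this once per storage; returning a storage overrides the
    # default device mapping, so we can pin on the host and then copy to the GPU.
    return storage.pin_memory().cuda()


state_dict = torch.load(checkpoint_path, map_location=pin_then_move)
```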

else:
# Load directly to the target device (synchronous)
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
return

if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
@@ -169,6 +215,26 @@ def onload_(self):
@torch.compiler.disable()
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.offload_to_disk_path:
# TODO: we can potentially optimize this code path by checking if _all_ the desired
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
# we perform a write.
# Check if the file has been saved in this session or if it already exists on disk.
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
tensors_to_save = {
key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
}
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)

# The group is now considered offloaded to disk for the rest of the session.
self._is_offloaded_to_disk = True

# We do this to free up the RAM which is still holding on to the tensor data.
for tensor_obj in self.tensor_to_key.keys():
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
Member:

What's the reason for this to be different from the non-disk-offload counterpart? That is, is there a reason we're not doing buffer.data.to(self.offload_device, non_blocking=self.non_blocking)?

Member Author:

So, we first free up the memory of the accelerator with:

key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()

However, since we're also optimizing for RAM usage (can be made clearer through documentation I believe), we need to free up the RAM that is holding the tensor data. After the data has been safely written from RAM to the disk, this step replaces the large data tensor in RAM with a memory-less placeholder. This allows the memory to be released.
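
A standalone sketch of that save-then-release pattern, with illustrative names (`offload_group_to_disk`, `onload_group_from_disk`) rather than the PR's actual API:

```python
import os

import safetensors.torch
import torch
import torch.nn as nn


def offload_group_to_disk(tensors: dict, path: str) -> None:
    """Persist a group's tensors to disk, then swap their RAM storage for placeholders."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    # 1) Write the actual data to disk (moved to CPU and made contiguous for safetensors).
    safetensors.torch.save_file(
        {name: t.data.to("cpu").contiguous() for name, t in tensors.items()}, path
    )
    # 2) Replace each tensor's storage with an uninitialized placeholder of the same
    #    shape/dtype so the original (large) storage can be released.
    for t in tensors.values():
        t.data = torch.empty_like(t.data, device="cpu")


def onload_group_from_disk(tensors: dict, path: str, device: str = "cpu") -> None:
    loaded = safetensors.torch.load_file(path, device=device)
    for name, t in tensors.items():
        t.data = loaded[name]


# Usage: offload a small layer's parameters and bring them back.
layer = nn.Linear(8, 8)
group = dict(layer.named_parameters())
offload_group_to_disk(group, "offload_dir/group_0.safetensors")
onload_group_from_disk(group, "offload_dir/group_0.safetensors")
```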

return

torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
@@ -205,13 +271,12 @@ class GroupOffloadingHook(ModelHook):

_is_stateful = False

def __init__(
self,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
) -> None:
def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
self.group = group
self.next_group = next_group
# map param/buffer name -> file path
self.param_to_path: Dict[str, str] = {}
self.buffer_to_path: Dict[str, str] = {}

def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
if self.group.offload_leader == module:
@@ -358,6 +423,7 @@ def apply_group_offloading(
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
offload_type: str = "block_level",
offload_to_disk_path: Optional[str] = None,
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
@@ -401,6 +467,8 @@ def apply_group_offloading(
offload_type (`str`, defaults to "block_level"):
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
"block_level".
offload_to_disk_path (`str`, *optional*):
The path to the directory where offloaded parameters will be stored.
num_blocks_per_group (`int`, *optional*):
The number of blocks per group when using offload_type="block_level". This is required when using
offload_type="block_level".
@@ -418,6 +486,8 @@
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.

(TODO: include example with `offload_to_disk_path`)
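
A sketch of what that example might look like, mirroring the existing CogVideoX example below (the model id and disk path are illustrative):

```python
>>> import torch
>>> from diffusers import CogVideoXTransformer3DModel
>>> from diffusers.hooks import apply_group_offloading

>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
... )
>>> apply_group_offloading(
...     transformer,
...     onload_device=torch.device("cuda"),
...     offload_device=torch.device("cpu"),
...     offload_type="leaf_level",
...     offload_to_disk_path="path/to/disk_offload_dir",
...     use_stream=True,
... )
```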

Example:
```python
>>> from diffusers import CogVideoXTransformer3DModel
@@ -458,6 +528,7 @@ def apply_group_offloading(
num_blocks_per_group=num_blocks_per_group,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
@@ -468,6 +539,7 @@
module=module,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
@@ -481,6 +553,7 @@ def _apply_group_offloading_block_level(
module: torch.nn.Module,
num_blocks_per_group: int,
offload_device: torch.device,
offload_to_disk_path: Optional[str],
onload_device: torch.device,
non_blocking: bool,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
@@ -496,6 +569,7 @@ def _apply_group_offloading_block_level(
The module to which group offloading is applied.
offload_device (`torch.device`):
The device to which the group of modules are offloaded. This should typically be the CPU.
offload_to_disk_path (`str`, *optional*):
The path to the directory where offloaded parameters will be stored.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
non_blocking (`bool`):
@@ -535,6 +609,7 @@
modules=current_modules,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=non_blocking,
@@ -567,6 +642,7 @@
modules=unmatched_modules,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
@@ -586,6 +662,7 @@ def _apply_group_offloading_leaf_level(
module: torch.nn.Module,
offload_device: torch.device,
onload_device: torch.device,
offload_to_disk_path: Optional[str],
non_blocking: bool,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
@@ -604,6 +681,7 @@
The device to which the group of modules are offloaded. This should typically be the CPU.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
offload_to_disk_path (`str`, *optional*):
The path to the directory where offloaded parameters will be stored.
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
@@ -629,6 +707,7 @@
modules=[submodule],
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=submodule,
onload_leader=submodule,
non_blocking=non_blocking,
@@ -675,6 +754,7 @@
onload_device=onload_device,
offload_leader=parent_module,
onload_leader=parent_module,
offload_to_disk_path=offload_to_disk_path,
parameters=parameters,
buffers=buffers,
non_blocking=non_blocking,
@@ -693,6 +773,7 @@
modules=[],
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=None,
18 changes: 10 additions & 8 deletions src/diffusers/models/modeling_utils.py
@@ -543,6 +543,7 @@ def enable_group_offload(
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
offload_type: str = "block_level",
offload_to_disk_path: Optional[str] = None,
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
@@ -588,15 +589,16 @@
f"open an issue at https://github.com/huggingface/diffusers/issues."
)
apply_group_offloading(
self,
onload_device,
offload_device,
offload_type,
num_blocks_per_group,
non_blocking,
use_stream,
record_stream,
module=self,
onload_device=onload_device,
offload_device=offload_device,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
non_blocking=non_blocking,
use_stream=use_stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
)

def save_pretrained(
30 changes: 30 additions & 0 deletions tests/models/test_modeling_common.py
@@ -15,6 +15,7 @@

import copy
import gc
import glob
import inspect
import json
import os
@@ -1693,6 +1694,35 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
_ = model(**inputs_dict)[0]

@parameterized.expand([(False, "block_level"), (True, "leaf_level")])
@require_torch_accelerator
@torch.no_grad()
def test_group_offloading_with_disk(self, record_stream, offload_type):
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)

if not getattr(model, "_supports_group_offloading", True):
return

torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.eval()
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
with tempfile.TemporaryDirectory() as tmpdir:
model.enable_group_offload(
torch_device,
offload_type=offload_type,
offload_to_disk_path=tmpdir,
use_stream=True,
record_stream=record_stream,
**additional_kwargs,
)
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
assert has_safetensors
_ = model(**inputs_dict)[0]

def test_auto_model(self, expected_max_diff=5e-5):
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)