diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 565f8f1ff860..1ea60c3f3342 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple, Union

+import safetensors.torch
 import torch

 from ..utils import get_logger, is_accelerate_available
@@ -59,6 +61,7 @@ def __init__(
         record_stream: Optional[bool] = False,
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
+        offload_to_disk_path: Optional[str] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -72,7 +75,26 @@ def __init__(
         self.record_stream = record_stream
         self.onload_self = onload_self
         self.low_cpu_mem_usage = low_cpu_mem_usage
-        self.cpu_param_dict = self._init_cpu_param_dict()
+
+        self.offload_to_disk_path = offload_to_disk_path
+        self._is_offloaded_to_disk = False
+
+        if self.offload_to_disk_path:
+            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
+
+            all_tensors = []
+            for module in self.modules:
+                all_tensors.extend(list(module.parameters()))
+                all_tensors.extend(list(module.buffers()))
+            all_tensors.extend(self.parameters)
+            all_tensors.extend(self.buffers)
+            all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates
+
+            self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
+            self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
+            self.cpu_param_dict = {}
+        else:
+            self.cpu_param_dict = self._init_cpu_param_dict()

         if self.stream is None and self.record_stream:
             raise ValueError("`record_stream` cannot be True when `stream` is None.")
@@ -124,6 +146,30 @@ def onload_(self):
         context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
         current_stream = torch_accelerator_module.current_stream() if self.record_stream else None

+        if self.offload_to_disk_path:
+            if self.stream is not None:
+                # Wait for previous Host->Device transfer to complete
+                self.stream.synchronize()
+
+            with context:
+                if self.stream is not None:
+                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
+                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
+                    for key, tensor_obj in self.key_to_tensor.items():
+                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
+                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                        if self.record_stream:
+                            tensor_obj.data.record_stream(current_stream)
+                else:
+                    # Load directly to the target device (synchronous)
+                    onload_device = (
+                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                    )
+                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                    for key, tensor_obj in self.key_to_tensor.items():
+                        tensor_obj.data = loaded_tensors[key]
+            return
+
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
@@ -169,6 +215,26 @@ def onload_(self):
     @torch.compiler.disable()
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
+        if self.offload_to_disk_path:
+            # TODO: we can potentially optimize this code path by checking if _all_ the desired
+            # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
+            # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+            # we perform a write.
+            # Check if the file has been saved in this session or if it already exists on disk.
+            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+                tensors_to_save = {
+                    key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
+                }
+                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+            # The group is now considered offloaded to disk for the rest of the session.
+            self._is_offloaded_to_disk = True
+
+            # We do this to free up the RAM which is still holding on to the tensor data.
+            for tensor_obj in self.tensor_to_key.keys():
+                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+            return

         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
@@ -205,13 +271,12 @@ class GroupOffloadingHook(ModelHook):

     _is_stateful = False

-    def __init__(
-        self,
-        group: ModuleGroup,
-        next_group: Optional[ModuleGroup] = None,
-    ) -> None:
+    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
         self.group = group
         self.next_group = next_group
+        # map param/buffer name -> file path
+        self.param_to_path: Dict[str, str] = {}
+        self.buffer_to_path: Dict[str, str] = {}

     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
@@ -358,6 +423,7 @@ def apply_group_offloading(
     onload_device: torch.device,
     offload_device: torch.device = torch.device("cpu"),
     offload_type: str = "block_level",
+    offload_to_disk_path: Optional[str] = None,
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
@@ -401,6 +467,8 @@ def apply_group_offloading(
         offload_type (`str`, defaults to "block_level"):
             The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
             "block_level".
+        offload_to_disk_path (`str`, *optional*):
+            The path to the directory where offloaded parameters will be stored.
         num_blocks_per_group (`int`, *optional*):
             The number of blocks per group when using offload_type="block_level". This is required when using
             offload_type="block_level".
@@ -418,6 +486,8 @@ def apply_group_offloading(
         option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
         the CPU memory is a bottleneck but may counteract the benefits of using streams.

+    (TODO: include example with `offload_to_disk_path`)
+
     Example:
         ```python
         >>> from diffusers import CogVideoXTransformer3DModel
@@ -458,6 +528,7 @@ def apply_group_offloading(
             num_blocks_per_group=num_blocks_per_group,
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
             non_blocking=non_blocking,
             stream=stream,
             record_stream=record_stream,
@@ -468,6 +539,7 @@ def apply_group_offloading(
             module=module,
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
             non_blocking=non_blocking,
             stream=stream,
             record_stream=record_stream,
@@ -481,6 +553,7 @@ def _apply_group_offloading_block_level(
     module: torch.nn.Module,
     num_blocks_per_group: int,
     offload_device: torch.device,
+    offload_to_disk_path: Optional[str],
    onload_device: torch.device,
     non_blocking: bool,
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
@@ -496,6 +569,7 @@ def _apply_group_offloading_block_level(
             The module to which group offloading is applied.
         offload_device (`torch.device`):
             The device to which the group of modules are offloaded. This should typically be the CPU.
+        offload_to_disk_path (`str`, *optional*): The path to the directory where offloaded parameters will be stored.
         onload_device (`torch.device`):
             The device to which the group of modules are onloaded.
         non_blocking (`bool`):
@@ -535,6 +609,7 @@ def _apply_group_offloading_block_level(
             modules=current_modules,
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
             offload_leader=current_modules[-1],
             onload_leader=current_modules[0],
             non_blocking=non_blocking,
@@ -567,6 +642,7 @@ def _apply_group_offloading_block_level(
         modules=unmatched_modules,
         offload_device=offload_device,
         onload_device=onload_device,
+        offload_to_disk_path=offload_to_disk_path,
         offload_leader=module,
         onload_leader=module,
         parameters=parameters,
@@ -586,6 +662,7 @@ def _apply_group_offloading_leaf_level(
     module: torch.nn.Module,
     offload_device: torch.device,
     onload_device: torch.device,
+    offload_to_disk_path: Optional[str],
     non_blocking: bool,
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
@@ -604,6 +681,7 @@ def _apply_group_offloading_leaf_level(
             The device to which the group of modules are offloaded. This should typically be the CPU.
         onload_device (`torch.device`):
             The device to which the group of modules are onloaded.
+        offload_to_disk_path (`str`, *optional*): The path to the directory where offloaded parameters will be stored.
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.
@@ -629,6 +707,7 @@ def _apply_group_offloading_leaf_level(
             modules=[submodule],
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
             offload_leader=submodule,
             onload_leader=submodule,
             non_blocking=non_blocking,
@@ -675,6 +754,7 @@ def _apply_group_offloading_leaf_level(
             onload_device=onload_device,
             offload_leader=parent_module,
             onload_leader=parent_module,
+            offload_to_disk_path=offload_to_disk_path,
             parameters=parameters,
             buffers=buffers,
             non_blocking=non_blocking,
@@ -693,6 +773,7 @@ def _apply_group_offloading_leaf_level(
             modules=[],
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
             offload_leader=module,
             onload_leader=module,
             parameters=None,
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 1e9e28471d89..beaea4805050 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -543,6 +543,7 @@ def enable_group_offload(
         onload_device: torch.device,
         offload_device: torch.device = torch.device("cpu"),
         offload_type: str = "block_level",
+        offload_to_disk_path: Optional[str] = None,
         num_blocks_per_group: Optional[int] = None,
         non_blocking: bool = False,
         use_stream: bool = False,
@@ -588,15 +589,16 @@ def enable_group_offload(
                 f"open an issue at https://github.com/huggingface/diffusers/issues."
             )
         apply_group_offloading(
-            self,
-            onload_device,
-            offload_device,
-            offload_type,
-            num_blocks_per_group,
-            non_blocking,
-            use_stream,
-            record_stream,
+            module=self,
+            onload_device=onload_device,
+            offload_device=offload_device,
+            offload_type=offload_type,
+            num_blocks_per_group=num_blocks_per_group,
+            non_blocking=non_blocking,
+            use_stream=use_stream,
+            record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            offload_to_disk_path=offload_to_disk_path,
         )

     def save_pretrained(
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 511fa4bfa9ea..53da8828341a 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -15,6 +15,7 @@

 import copy
 import gc
+import glob
 import inspect
 import json
 import os
@@ -1693,6 +1694,35 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, offload_ty
         model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
         _ = model(**inputs_dict)[0]

+    @parameterized.expand([(False, "block_level"), (True, "leaf_level")])
+    @require_torch_accelerator
+    @torch.no_grad()
+    def test_group_offloading_with_disk(self, record_stream, offload_type):
+        torch.manual_seed(0)
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        if not getattr(model, "_supports_group_offloading", True):
+            return
+
+        torch.manual_seed(0)
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.eval()
+        additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.enable_group_offload(
+                torch_device,
+                offload_type=offload_type,
+                offload_to_disk_path=tmpdir,
+                use_stream=True,
+                record_stream=record_stream,
+                **additional_kwargs,
+            )
+            has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
+            assert has_safetensors
+            _ = model(**inputs_dict)[0]
+
     def test_auto_model(self, expected_max_diff=5e-5):
         if self.forward_requires_fresh_args:
             model = self.model_class(**self.init_dict)
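For reference, here is a minimal usage sketch of the new `offload_to_disk_path` option, in the spirit of the `(TODO: include example with offload_to_disk_path)` placeholder above. It is not part of the diff; the checkpoint ID, devices, and target directory are illustrative, and the call mirrors the `enable_group_offload` signature added in `modeling_utils.py`:

```python
import torch
from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Each offload group is serialized once to a `group_<id>.safetensors` file under
# `offload_to_disk_path` and loaded back onto the accelerator when the group is
# needed, so CPU RAM only has to hold the groups currently in flight.
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    offload_to_disk_path="/path/to/offload_dir",  # hypothetical directory
    use_stream=True,
)
```

With `use_stream=True`, the hook follows the streamed path added in `onload_()`: tensors are read from the safetensors file to pinned CPU memory and copied to the device with `non_blocking` transfers, keeping data movement overlapped with compute.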