Commit 9a147b8

a-r-r-o-w and stevhliu authored
Module Group Offloading (#10503)
* update
* fix
* non_blocking; handle parameters and buffers
* update
* Group offloading with cuda stream prefetching (#10516)
* cuda stream prefetch
* remove breakpoints
* update
* copy model hook implementation from pab
* update; ~very workaround based implementation but it seems to work as expected; needs cleanup and rewrite
* more workarounds to make it actually work
* cleanup
* rewrite
* update
* make sure to sync current stream before overwriting with pinned params not doing so will lead to erroneous computations on the GPU and cause bad results
* better check
* update
* remove hook implementation to not deal with merge conflict
* re-add hook changes
* why use more memory when less memory do trick
* why still use slightly more memory when less memory do trick
* optimise
* add model tests
* add pipeline tests
* update docs
* add layernorm and groupnorm
* address review comments
* improve tests; add docs
* improve docs
* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* apply suggestions from code review
* update tests
* apply suggestions from review
* enable_group_offloading -> enable_group_offload for naming consistency
* raise errors if multiple offloading strategies used; add relevant tests
* handle .to() when group offload applied
* refactor some repeated code
* remove unintentional change from merge conflict
* handle .cuda()

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
1 parent ab42820 commit 9a147b8


44 files changed: +1239 / -4 lines

docs/source/en/api/utilities.md

Lines changed: 4 additions & 0 deletions
@@ -45,3 +45,7 @@ Utility and helper functions for working with 🤗 Diffusers.
 ## apply_layerwise_casting

 [[autodoc]] hooks.layerwise_casting.apply_layerwise_casting
+
+## apply_group_offloading
+
+[[autodoc]] hooks.group_offloading.apply_group_offloading

docs/source/en/optimization/memory.md

Lines changed: 40 additions & 0 deletions
@@ -158,6 +158,46 @@ In order to properly offload models after they're called, it is required to run

 </Tip>

+## Group offloading
+
+Group offloading is the middle ground between sequential and model offloading. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`), which uses less memory than model-level offloading. It is also faster than sequential-level offloading because the number of device synchronizations is reduced.
+
+To enable group offloading, call the [`~ModelMixin.enable_group_offload`] method on the model if it is a Diffusers model implementation. For any other model implementation, use [`~hooks.group_offloading.apply_group_offloading`]:
+
+```python
+import torch
+from diffusers import CogVideoXPipeline
+from diffusers.hooks import apply_group_offloading
+from diffusers.utils import export_to_video
+
+# Load the pipeline
+onload_device = torch.device("cuda")
+offload_device = torch.device("cpu")
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+
+# We can utilize the enable_group_offload method for Diffusers model implementations
+pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
+
+# For any other model implementations, the apply_group_offloading function can be used
+apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
+apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")
+
+prompt = (
+    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+    "atmosphere of this unique musical performance."
+)
+video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+# This utilized about 14.79 GB. It can be further reduced by using tiling and using leaf_level offloading throughout the pipeline.
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+export_to_video(video, "output.mp4", fps=8)
+```
+
+Group offloading (for CUDA devices with support for asynchronous data transfer streams) overlaps data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams. The next layer to be executed is loaded onto the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Group offloading also supports leaf-level offloading (equivalent to sequential CPU offloading) but can be made much faster when using streams.
+
 ## FP8 layerwise weight-casting

 PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
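The stream-based prefetching described in the new docs trades a little extra peak memory for lower wall-clock time. A rough way to see both effects is to time the block-level variant with and without `use_stream`. The sketch below is not part of the commit; it reuses the calls shown in the example above, and the prompt, step count, and the assumption that the offloaded weights fit in CPU RAM are all illustrative.

```python
import time

import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_group_offloading


def run_once(use_stream: bool) -> None:
    # Reload per run so each module only ever has one offloading strategy applied to it.
    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    pipe.transformer.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="block_level",
        num_blocks_per_group=2,
        use_stream=use_stream,
    )
    apply_group_offloading(pipe.text_encoder, onload_device=torch.device("cuda"), offload_type="block_level", num_blocks_per_group=2)
    apply_group_offloading(pipe.vae, onload_device=torch.device("cuda"), offload_type="leaf_level")

    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    pipe(prompt="a panda strumming a tiny acoustic guitar", guidance_scale=6, num_inference_steps=10)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"use_stream={use_stream}: {elapsed:.1f}s, peak {peak_gb:.2f} GB")


for use_stream in (False, True):
    run_once(use_stream)
```

If the behaviour matches the description above, the streamed run should hide most of the transfer latency at a slightly higher peak than the non-streamed one.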

src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@


 if is_torch_available():
+    from .group_offloading import apply_group_offloading
     from .hooks import HookRegistry, ModelHook
     from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
     from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
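With this re-export in place, the package-level and module-level import paths should resolve to the same function. A small sketch, independent of any checkpoint:

```python
# Public path added by this commit's re-export:
from diffusers.hooks import apply_group_offloading

# Direct module path, as used internally by modeling_utils.py in this commit:
from diffusers.hooks.group_offloading import apply_group_offloading as apply_group_offloading_direct

assert apply_group_offloading is apply_group_offloading_direct
```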

src/diffusers/hooks/group_offloading.py

Lines changed: 678 additions & 0 deletions
Large diffs are not rendered by default.
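The new module's contents are not rendered here, but its public entry point is exercised elsewhere in this commit. Below is a hedged sketch of calling `apply_group_offloading` directly on a model that is not a Diffusers `ModelMixin`; the keyword arguments mirror the `memory.md` example above, and the `T5EncoderModel` checkpoint path is an assumption (CogVideoX's text encoder).

```python
import torch
from transformers import T5EncoderModel

# Re-exported from diffusers.hooks by this commit's __init__.py change.
from diffusers.hooks import apply_group_offloading

# Assumption: CogVideoX-5b ships a T5 text encoder under the "text_encoder" subfolder.
text_encoder = T5EncoderModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="text_encoder", torch_dtype=torch.bfloat16
)

# block_level: groups of `num_blocks_per_group` ModuleList/Sequential blocks are moved
# between the offload and onload devices around their forward passes.
apply_group_offloading(
    text_encoder,
    onload_device=torch.device("cuda"),
    offload_type="block_level",
    num_blocks_per_group=2,
)

# leaf_level (closest to sequential CPU offloading) would be:
# apply_group_offloading(text_encoder, onload_device=torch.device("cuda"), offload_type="leaf_level")
```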

src/diffusers/models/autoencoders/autoencoder_oobleck.py

Lines changed: 1 addition & 0 deletions
@@ -317,6 +317,7 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
     """

     _supports_gradient_checkpointing = False
+    _supports_group_offloading = False

     @register_to_config
     def __init__(

src/diffusers/models/autoencoders/consistency_decoder_vae.py

Lines changed: 2 additions & 0 deletions
@@ -68,6 +68,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
     ```
     """

+    _supports_group_offloading = False
+
     @register_to_config
     def __init__(
         self,

src/diffusers/models/autoencoders/vq_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class VQModel(ModelMixin, ConfigMixin):
7272
"""
7373

7474
_skip_layerwise_casting_patterns = ["quantize"]
75+
_supports_group_offloading = False
7576

7677
@register_to_config
7778
def __init__(
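The three models above opt out of the feature by setting `_supports_group_offloading = False`. Combined with the guard added to `modeling_utils.py` further down, calling [`~ModelMixin.enable_group_offload`] on such a model raises a `ValueError` rather than silently misbehaving. A minimal sketch of that path (constructing `VQModel` with its default config is only for illustration):

```python
import torch
from diffusers import VQModel

model = VQModel()  # VQModel sets _supports_group_offloading = False in this commit

try:
    model.enable_group_offload(onload_device=torch.device("cuda"))
except ValueError as err:
    # "VQModel does not support group offloading. ..."
    print(err)
```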

src/diffusers/models/modeling_utils.py

Lines changed: 91 additions & 1 deletion
@@ -34,7 +34,7 @@
 from typing_extensions import Self

 from .. import __version__
-from ..hooks import apply_layerwise_casting
+from ..hooks import apply_group_offloading, apply_layerwise_casting
 from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
 from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
@@ -87,7 +87,17 @@


 def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
+    from ..hooks.group_offloading import _get_group_onload_device
+
+    try:
+        # Try to get the onload device from the group offloading hook
+        return _get_group_onload_device(parameter)
+    except ValueError:
+        pass
+
     try:
+        # If the onload device is not available due to no group offloading hooks, try to get the device
+        # from the first parameter or buffer
         parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
         return next(parameters_and_buffers).device
     except StopIteration:
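A side effect of this change worth noting: since `ModelMixin.device` resolves through `get_parameter_device`, a group-offloaded model should now report its onload device even while its weights sit on the offload device between forward passes. A hedged sketch, where the `AutoencoderKL` checkpoint is just a convenient small model:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
vae.enable_group_offload(onload_device=torch.device("cuda"), offload_type="leaf_level")

# The group offloading hook answers first, so `.device` reflects where computation will happen ...
print(vae.device)                     # expected: cuda
# ... while the parameters themselves can still be on the offload device.
print(next(vae.parameters()).device)  # typically: cpu (between forward passes)
```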
@@ -166,6 +176,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _no_split_modules = None
     _keep_in_fp32_modules = None
     _skip_layerwise_casting_patterns = None
+    _supports_group_offloading = True

     def __init__(self):
         super().__init__()
@@ -437,6 +448,55 @@ def enable_layerwise_casting(
             self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
         )

+    def enable_group_offload(
+        self,
+        onload_device: torch.device,
+        offload_device: torch.device = torch.device("cpu"),
+        offload_type: str = "block_level",
+        num_blocks_per_group: Optional[int] = None,
+        non_blocking: bool = False,
+        use_stream: bool = False,
+    ) -> None:
+        r"""
+        Activates group offloading for the current model.
+
+        See [`~hooks.group_offloading.apply_group_offloading`] for more information.
+
+        Example:
+
+        ```python
+        >>> from diffusers import CogVideoXTransformer3DModel
+
+        >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+        ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
+        ... )
+
+        >>> transformer.enable_group_offload(
+        ...     onload_device=torch.device("cuda"),
+        ...     offload_device=torch.device("cpu"),
+        ...     offload_type="leaf_level",
+        ...     use_stream=True,
+        ... )
+        ```
+        """
+        if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
+            msg = (
+                "Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
+                "forward pass is executed with tiling enabled. Please make sure to either:\n"
+                "1. Run a forward pass with small input shapes.\n"
+                "2. Or, run a forward pass with tiling disabled (can still use small dummy inputs)."
+            )
+            logger.warning(msg)
+        if not self._supports_group_offloading:
+            raise ValueError(
+                f"{self.__class__.__name__} does not support group offloading. Please make sure to set the boolean attribute "
+                f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
+                f"open an issue at https://github.com/huggingface/diffusers/issues."
+            )
+        apply_group_offloading(
+            self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
+        )
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
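The tiling warning in `enable_group_offload` suggests a practical warm-up for autoencoders that combine `use_stream=True` with tiling: run one small, untiled forward pass first. The sketch below only illustrates that suggestion; the checkpoint and the dummy input shape are arbitrary assumptions, not part of the commit.

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)

# Warm-up: one small forward pass with tiling still disabled, as the warning recommends.
with torch.no_grad():
    dummy = torch.randn(1, 3, 64, 64, device="cuda", dtype=torch.float16)
    vae.encode(dummy)

# Tiling can then be enabled for the real, large inputs.
vae.enable_tiling()
```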
@@ -1170,6 +1230,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
     # Adapted from `transformers`.
     @wraps(torch.nn.Module.cuda)
     def cuda(self, *args, **kwargs):
+        from ..hooks.group_offloading import _is_group_offload_enabled
+
         # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
             if getattr(self, "is_loaded_in_8bit", False):
@@ -1182,13 +1244,34 @@ def cuda(self, *args, **kwargs):
                    "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
                )
+
+        # Checks if group offloading is enabled
+        if _is_group_offload_enabled(self):
+            logger.warning(
+                f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.cuda()` is not supported."
+            )
+            return self
+
         return super().cuda(*args, **kwargs)

     # Adapted from `transformers`.
     @wraps(torch.nn.Module.to)
     def to(self, *args, **kwargs):
+        from ..hooks.group_offloading import _is_group_offload_enabled
+
+        device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
         dtype_present_in_args = "dtype" in kwargs

+        # Try converting arguments to torch.device in case they are passed as strings
+        for arg in args:
+            if not isinstance(arg, str):
+                continue
+            try:
+                torch.device(arg)
+                device_arg_or_kwarg_present = True
+            except RuntimeError:
+                pass
+
         if not dtype_present_in_args:
             for arg in args:
                 if isinstance(arg, torch.dtype):
@@ -1213,6 +1296,13 @@ def to(self, *args, **kwargs):
                    "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
                )
+
+        if _is_group_offload_enabled(self) and device_arg_or_kwarg_present:
+            logger.warning(
+                f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."
+            )
+            return self
+
         return super().to(*args, **kwargs)

     # Taken from `transformers`.
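With the `.cuda()` and `.to()` overrides above, explicit device moves on a group-offloaded model are intercepted: a warning is logged and the model is returned unchanged, so the offloading hooks are not broken by a stray `.to("cuda")`. A hedged sketch of the resulting behaviour (the checkpoint is again an arbitrary small model):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
vae.enable_group_offload(onload_device=torch.device("cuda"), offload_type="leaf_level")

# Both calls log "... is group offloaded and moving it using `.cuda()`/`.to()` is not supported."
# and return `vae` without moving any parameters.
vae.cuda()
vae.to("cuda")

# A dtype-only `.to()` carries no device argument, so it is not intercepted by the new check.
vae.to(torch.float16)
```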

src/diffusers/models/transformers/dit_transformer_2d.py

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):

     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
     _supports_gradient_checkpointing = True
+    _supports_group_offloading = False

     @register_to_config
     def __init__(

src/diffusers/models/transformers/hunyuan_transformer_2d.py

Lines changed: 1 addition & 0 deletions
@@ -245,6 +245,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
     """

     _skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"]
+    _supports_group_offloading = False

     @register_to_config
     def __init__(
