[docs] Quantization + torch.compile + offloading #11703

Draft · wants to merge 2 commits into main
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -180,6 +180,8 @@
title: Caching
- local: optimization/memory
title: Reduce memory usage
- local: optimization/speed-memory-optims
title: Compile and offloading
- local: optimization/xformers
title: xFormers
- local: optimization/tome
8 changes: 4 additions & 4 deletions docs/source/en/optimization/memory.md
@@ -17,7 +17,7 @@ Modern diffusion models like [Flux](../api/pipelines/flux) and [Wan](../api/pipe
This guide will show you how to reduce your memory usage.

> [!TIP]
> Keep in mind these techniques may need to be adjusted depending on the model! For example, a transformer-based diffusion model may not benefit equally from these inference speed optimizations as a UNet-based model.
> Keep in mind these techniques may need to be adjusted depending on the model. For example, a transformer-based diffusion model may not benefit equally from these memory optimizations as a UNet-based model.

## Multiple GPUs

@@ -145,7 +145,7 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G
```

> [!WARNING]
> [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support slicing.
> The [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] classes don't support slicing.

## VAE tiling

@@ -219,7 +219,7 @@ from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipline.enable_model_cpu_offload()
pipeline.enable_model_cpu_offload()

pipeline(
prompt="An astronaut riding a horse on Mars",
@@ -493,7 +493,7 @@ with torch.inference_mode():
## Memory-efficient attention

> [!TIP]
> Memory-efficient attention optimizes for memory usage *and* [inference speed](./fp16#scaled-dot-product-attention!
> Memory-efficient attention optimizes for memory usage *and* [inference speed](./fp16#scaled-dot-product-attention)!

The Transformers attention mechanism is memory-intensive, especially for long sequences, so you can try using different and more memory-efficient attention types.

143 changes: 143 additions & 0 deletions docs/source/en/optimization/speed-memory-optims.md
@@ -0,0 +1,143 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Compile and offloading

When optimizing models, you often face trade-offs between [inference speed](./fp16) and [memory usage](./memory). For instance, while [caching](./cache) can boost inference speed, it comes at the cost of increased memory consumption because it stores the intermediate attention layer outputs.

A more balanced optimization strategy combines [torch.compile](./fp16#torchcompile) with various offloading methods. This approach not only accelerates inference but also helps lower memory usage.

The table below compares combinations of optimization strategies and their impact on latency and memory usage.

| combination | latency | memory usage |
|---|---|---|
| quantization, torch.compile | | |
| quantization, torch.compile, model CPU offloading | | |
| quantization, torch.compile, group offloading | | |

This guide will show you how to compile and offload a model.
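
To see how these combinations behave on your own hardware, you can time a pipeline call and record peak GPU memory. The helper below is a minimal sketch rather than a Diffusers API; the `benchmark` function, the warmup call, and the step count are illustrative assumptions.

```py
import time
import torch

def benchmark(pipeline, prompt, num_inference_steps=28):
    # warmup call so torch.compile's one-time compilation isn't measured
    pipeline(prompt, num_inference_steps=num_inference_steps)

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    image = pipeline(prompt, num_inference_steps=num_inference_steps).images[0]
    torch.cuda.synchronize()

    latency = time.perf_counter() - start
    peak_memory_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"latency: {latency:.2f}s, peak memory: {peak_memory_gb:.2f}GB")
    return image
```

Call it with any of the pipelines configured in the sections below to compare the combinations on your machine.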

## Quantization and torch.compile

> [!TIP]
> The quantization backend, such as [bitsandbytes](../quantization/bitsandbytes#torchcompile), must be compatible with torch.compile. Refer to the quantization [overview](https://huggingface.co/docs/transformers/quantization/overview#overview) table to see which backends support torch.compile.
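
For example, [torchao](../quantization/torchao) is another backend that works with torch.compile. The snippet below is a sketch of swapping it in (it assumes torchao is installed, and `int8wo` weight-only quantization is just one possible `quant_type`).

```py
from diffusers.quantizers import PipelineQuantizationConfig

# sketch: torchao-based config as a torch.compile-compatible alternative
# assumes `pip install torchao`; pass this as quantization_config when loading the pipeline
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="torchao",
    quant_kwargs={"quant_type": "int8wo"},
    components_to_quantize=["transformer"],
)
```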

Start by [quantizing](../quantization/overview) a model to reduce the memory required for storage and [compiling](./fp16#torchcompile) it to accelerate inference.

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# quantize
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
components_to_quantize=["transformer", "text_encoder_2"],
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
).to("cuda")

# compile
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer.compile(mode="max-autotune", fullgraph=True)
pipeline("""
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
).images[0]
```

## Quantization, torch.compile, and offloading

In addition to quantization and torch.compile, try offloading if you need to reduce memory usage further. Offloading moves layers or model components to the GPU when they're needed for computation and keeps them on the CPU otherwise.

<hfoptions id="offloading">
<hfoption id="model CPU offloading">

[Model CPU offloading](./memory#model-offloading) moves an individual pipeline component, like the transformer model, to the GPU when it is needed for computation. Otherwise, it is offloaded to the CPU.

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# quantize
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
components_to_quantize=["transformer", "text_encoder_2"],
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
).to("cuda")

# model CPU offloading
pipeline.enable_model_cpu_offload()

# compile
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer.compile(mode="max-autotune", fullgraph=True)
pipeline(
"cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain"
).images[0]
```

</hfoption>
<hfoption id="group offloading">

[Group offloading](./memory#group-offloading) moves the internal layers of an individual pipeline component, like the transformer model, to the GPU for computation and offloads them when they're not required. At the same time, it uses the [CUDA stream](./memory#cuda-stream) feature to prefetch the next layer for execution.

By overlapping computation and data transfer, it is faster than model CPU offloading while also saving memory.

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.quantizers import PipelineQuantizationConfig

# quantize
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
components_to_quantize=["transformer", "text_encoder_2"],
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
).to("cuda")

# group offloading
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
pipeline.vae.enable_group_offload(onload_device=onload_device, offload_type="leaf_level", use_stream=True)
apply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type="leaf_level", use_stream=True)
apply_group_offloading(pipeline.text_encoder_2, onload_device=onload_device, offload_type="leaf_level", use_stream=True)

# compile
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer.compile(mode="max-autotune", fullgraph=True)
pipeline(
"cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain"
).images[0]
```

</hfoption>
</hfoptions>