Commit c7f02c2

inference
1 parent 425a725 commit c7f02c2

3 files changed: +193 -76 lines changed

docs/source/en/_toctree.yml

Lines changed: 1 addition & 1 deletion
@@ -178,7 +178,7 @@
     title: Quantization Methods
 - sections:
   - local: optimization/fp16
-    title: Speed up inference
+    title: Accelerate inference
   - local: optimization/memory
     title: Reduce memory usage
   - local: optimization/torch2.0

docs/source/en/optimization/fp16.md

Lines changed: 186 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -10,120 +10,231 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
1010
specific language governing permissions and limitations under the License.
1111
-->
1212

13-
# Speed up inference
13+
# Accelerate inference
1414

15-
There are several ways to optimize Diffusers for inference speed, such as reducing the computational burden by lowering the data precision or using a lightweight distilled model. There are also memory-efficient attention implementations, [xFormers](xformers) and [scaled dot product attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) in PyTorch 2.0, that reduce memory usage which also indirectly speeds up inference. Different speed optimizations can be stacked together to get the fastest inference times.
15+
Diffusion models are slow at inference because generation is an iterative process where noise is gradually refined into an image or video over a certain number of "steps". To speedup this process, you can try experimenting with different [schedulers](../api/schedulers/overview), reduce the precision of the model weights for faster computations, use more memory-efficient attention mechanisms, and more.
1616

17-
> [!TIP]
18-
> Optimizing for inference speed or reduced memory usage can lead to improved performance in the other category, so you should try to optimize for both whenever you can. This guide focuses on inference speed, but you can learn more about lowering memory usage in the [Reduce memory usage](memory) guide.
17+
Combine and use these techniques together to make inference faster than using any single technique on its own.
18+
19+
This guide will go over how to accelerate inference.
1920

20-
The inference times below are obtained from generating a single 512x512 image from the prompt "a photo of an astronaut riding a horse on mars" with 50 DDIM steps on a NVIDIA A100.
21+
## Model data type
2122

22-
| setup | latency | speed-up |
23-
|----------|---------|----------|
24-
| baseline | 5.27s | x1 |
25-
| tf32 | 4.14s | x1.27 |
26-
| fp16 | 3.51s | x1.50 |
27-
| combined | 3.41s | x1.54 |
23+
The precision and data type of the model weights affect inference speed because a higher precision requires more memory to load and more time to perform the computations. PyTorch loads model weights in float32 or full precision by default, so changing the data type is a simple way to quickly get faster inference.
2824

29-
## TensorFloat-32
25+
<hfoptions id="dtypes">
26+
<hfoption id="bfloat16">
3027

31-
On Ampere and later CUDA devices, matrix multiplications and convolutions can use the [TensorFloat-32 (tf32)](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) mode for faster, but slightly less accurate computations. By default, PyTorch enables tf32 mode for convolutions but not matrix multiplications. Unless your network requires full float32 precision, we recommend enabling tf32 for matrix multiplications. It can significantly speed up computations with typically negligible loss in numerical accuracy.
28+
bfloat16 is similar to float16 but it is more robust to numerical errors. Hardware support for bfloat16 varies, but most modern GPUs are capable of supporting bfloat16.
3229

33-
```python
30+
```py
3431
import torch
32+
from diffusers import StableDiffusionXLPipeline
3533

36-
torch.backends.cuda.matmul.allow_tf32 = True
37-
```
34+
pipeline = StableDiffusionXLPipeline.from_pretrained(
35+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
36+
).to("cuda")
3837

39-
Learn more about tf32 in the [Mixed precision training](https://huggingface.co/docs/transformers/en/perf_train_gpu_one#tf32) guide.
38+
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
39+
pipeline(prompt, num_inference_steps=30).images[0]
40+
```
4041

41-
## Half-precision weights
42+
</hfoption>
43+
<hfoption id="float16">
4244

43-
To save GPU memory and get more speed, set `torch_dtype=torch.float16` to load and run the model weights directly with half-precision weights.
45+
float16 is similar to bfloat16 but may be more prone to numerical errors.
4446

45-
```Python
47+
```py
4648
import torch
47-
from diffusers import DiffusionPipeline
49+
from diffusers import StableDiffusionXLPipeline
4850

49-
pipe = DiffusionPipeline.from_pretrained(
50-
"stable-diffusion-v1-5/stable-diffusion-v1-5",
51-
torch_dtype=torch.float16,
52-
use_safetensors=True,
53-
)
54-
pipe = pipe.to("cuda")
51+
pipeline = StableDiffusionXLPipeline.from_pretrained(
52+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
53+
).to("cuda")
54+
55+
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
56+
pipeline(prompt, num_inference_steps=30).images[0]
5557
```
5658

57-
> [!WARNING]
58-
> Don't use [torch.autocast](https://pytorch.org/docs/stable/amp.html#torch.autocast) in any of the pipelines as it can lead to black images and is always slower than pure float16 precision.
59+
</hfoption>
60+
<hfoption id="TensorFloat-32">
5961

60-
## Distilled model
62+
[TensorFloat-32 (tf32)](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) mode is supported on NVIDIA Ampere GPUs and it computes the convolution and matrix multiplication operations in tf32. Storage and other operations are kept in float32. This enables significantly faster computations when combined with bfloat16 or float16.
6163

62-
You could also use a distilled Stable Diffusion model and autoencoder to speed up inference. During distillation, many of the UNet's residual and attention blocks are shed to reduce the model size by 51% and improve latency on CPU/GPU by 43%. The distilled model is faster and uses less memory while generating images of comparable quality to the full Stable Diffusion model.
64+
PyTorch only enables tf32 mode for convolutions by default and you'll need to explicitly enable it for matrix multiplications.
6365

64-
> [!TIP]
65-
> Read the [Open-sourcing Knowledge Distillation Code and Weights of SD-Small and SD-Tiny](https://huggingface.co/blog/sd_distillation) blog post to learn more about how knowledge distillation training works to produce a faster, smaller, and cheaper generative model.
66+
```py
67+
import torch
68+
from diffusers import StableDiffusionXLPipeline
6669

67-
The inference times below are obtained from generating 4 images from the prompt "a photo of an astronaut riding a horse on mars" with 25 PNDM steps on a NVIDIA A100. Each generation is repeated 3 times with the distilled Stable Diffusion v1.4 model by [Nota AI](https://hf.co/nota-ai).
70+
torch.backends.cuda.matmul.allow_tf32 = True
71+
72+
pipeline = StableDiffusionXLPipeline.from_pretrained(
73+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
74+
).to("cuda")
6875

69-
| setup | latency | speed-up |
70-
|------------------------------|---------|----------|
71-
| baseline | 6.37s | x1 |
72-
| distilled | 4.18s | x1.52 |
73-
| distilled + tiny autoencoder | 3.83s | x1.66 |
76+
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
77+
pipeline(prompt, num_inference_steps=30).images[0]
78+
```
7479

75-
Let's load the distilled Stable Diffusion model and compare it against the original Stable Diffusion model.
80+
Refer to the [mixed precision training](https://huggingface.co/docs/transformers/en/perf_train_gpu_one#mixed-precision) docs for more details.
81+
82+
</hfoption>
83+
</hfoptions>
+
+## Scaled dot product attention
+
+[Scaled dot product attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) implements several attention backends: [FlashAttention](https://github.com/Dao-AILab/flash-attention), [xFormers](https://github.com/facebookresearch/xformers), and a native C++ implementation. It automatically selects the most optimal backend for your hardware.
+
+SDPA is enabled by default if you're using PyTorch >= 2.0, and no additional changes to your code are required. If you'd prefer to choose a backend yourself, you can experiment with the other implementations. The example below uses the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to enable efficient attention.

 ```py
-from diffusers import StableDiffusionPipeline
+from torch.nn.attention import SDPBackend, sdpa_kernel
 import torch
+from diffusers import StableDiffusionXLPipeline

-distilled = StableDiffusionPipeline.from_pretrained(
-    "nota-ai/bk-sdm-small", torch_dtype=torch.float16, use_safetensors=True,
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
 ).to("cuda")
-prompt = "a golden vase with different flowers"
-generator = torch.manual_seed(2023)
-image = distilled("a golden vase with different flowers", num_inference_steps=25, generator=generator).images[0]
-image
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
+    image = pipeline(prompt, num_inference_steps=30).images[0]
 ```

-<div class="flex gap-4">
-  <div>
-    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/original_sd.png"/>
-    <figcaption class="mt-2 text-center text-sm text-gray-500">original Stable Diffusion</figcaption>
-  </div>
-  <div>
-    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/distilled_sd.png"/>
-    <figcaption class="mt-2 text-center text-sm text-gray-500">distilled Stable Diffusion</figcaption>
-  </div>
-</div>
+## torch.compile

-### Tiny AutoEncoder
+[torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) accelerates inference by compiling PyTorch code and operations into optimized kernels. Diffusers typically compiles the more compute-intensive models like the UNet, transformer, or VAE.

-To speed inference up even more, replace the autoencoder with a [distilled version](https://huggingface.co/sayakpaul/taesdxl-diffusers) of it.
+Enable the following compiler settings for maximum speed (refer to the [full list](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py) for more options).

 ```py
 import torch
-from diffusers import AutoencoderTiny, StableDiffusionPipeline
+from diffusers import StableDiffusionXLPipeline
+
+torch._inductor.config.conv_1x1_as_mm = True
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.epilogue_fusion = False
+torch._inductor.config.coordinate_descent_check_all_directions = True
+```
+
+Load and compile the UNet and VAE. There are several different modes you can choose from, but `"max-autotune"` optimizes for the fastest speed by compiling to a CUDA graph. CUDA graphs effectively reduce the overhead by launching multiple GPU operations through a single CPU operation.

-distilled = StableDiffusionPipeline.from_pretrained(
-    "nota-ai/bk-sdm-small", torch_dtype=torch.float16, use_safetensors=True,
+> [!TIP]
+> With PyTorch 2.3.1, you can control the caching behavior of torch.compile. This is particularly beneficial for compilation modes like `"max-autotune"` which performs a grid-search over several compilation flags to find the optimal configuration. Learn more in the [Compile Time Caching in torch.compile](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) tutorial.
+
+Changing the memory layout to [channels_last](./memory#torchchannels_last) also optimizes memory and inference speed.
+
+```py
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
 ).to("cuda")
-distilled.vae = AutoencoderTiny.from_pretrained(
-    "sayakpaul/taesd-diffusers", torch_dtype=torch.float16, use_safetensors=True,
+pipeline.unet.to(memory_format=torch.channels_last)
+pipeline.vae.to(memory_format=torch.channels_last)
+pipeline.unet = torch.compile(pipeline.unet,
+    mode="max-autotune",
+    fullgraph=True
+)
+pipeline.vae.decode = torch.compile(
+    pipeline.vae.decode,
+    mode="max-autotune",
+    fullgraph=True
+)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+pipeline(prompt, num_inference_steps=30).images[0]
+```
+
+Compilation is slow the first time, but once compiled, it is significantly faster. Try to only use the compiled pipeline on the same type of inference operations. Calling the compiled pipeline on a different image size retriggers compilation which is slow and inefficient.
+
+### Graph breaks
+
+It is important to specify `fullgraph=True` in torch.compile to ensure there are no graph breaks in the underlying model. This allows you to take advantage of torch.compile without any performance degradation. For the UNet and VAE, this changes how you access the return variables.
+
+```diff
+- latents = unet(
+-   latents, timestep=timestep, encoder_hidden_states=prompt_embeds
+-).sample
+
++ latents = unet(
++   latents, timestep=timestep, encoder_hidden_states=prompt_embeds, return_dict=False
++)[0]
+```
+
+### GPU sync
+
+The `step()` function is [called](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L1228) on the scheduler each time after the denoiser makes a prediction, and the `sigmas` variable is [indexed](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/schedulers/scheduling_euler_discrete.py#L476). When placed on the GPU, it introduces latency because of the communication sync between the CPU and GPU. It becomes more evident when the denoiser has already been compiled.
+
+In general, the `sigmas` should [stay on the CPU](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240) to avoid the communication sync and latency.
+
+## Dynamic quantization
+
+[Dynamic quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) improves inference speed by reducing precision to enable faster math operations. This particular type of quantization determines how to scale the activations based on the data at runtime rather than using a fixed scaling factor. As a result, the scaling factor is more accurately aligned with the data.
+
+The example below applies [dynamic int8 quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) to the UNet and VAE with the [torchao](../quantization/torchao) library.
+
+Configure the compiler flags for maximum speed.
+
+```py
+import torch
+from torchao import apply_dynamic_quant
+from diffusers import StableDiffusionXLPipeline
+
+torch._inductor.config.conv_1x1_as_mm = True
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.epilogue_fusion = False
+torch._inductor.config.coordinate_descent_check_all_directions = True
+torch._inductor.config.force_fuse_int_mm_with_mul = True
+torch._inductor.config.use_mixed_mm = True
+```
+
+Use the [dynamic_quant_filter_fn](https://github.com/huggingface/diffusion-fast/blob/0f169640b1db106fe6a479f78c1ed3bfaeba3386/utils/pipeline_utils.py#L16) to filter out the linear layers in the UNet and VAE that don't benefit from dynamic quantization.
+
+```py
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
 ).to("cuda")

-prompt = "a golden vase with different flowers"
-generator = torch.manual_seed(2023)
-image = distilled("a golden vase with different flowers", num_inference_steps=25, generator=generator).images[0]
-image
+apply_dynamic_quant(pipeline.unet, dynamic_quant_filter_fn)
+apply_dynamic_quant(pipeline.vae, dynamic_quant_filter_fn)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+pipeline(prompt, num_inference_steps=30).images[0]
 ```

-<div class="flex justify-center">
-  <div>
-    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/distilled_sd_vae.png" />
-    <figcaption class="mt-2 text-center text-sm text-gray-500">distilled Stable Diffusion + Tiny AutoEncoder</figcaption>
-  </div>
-</div>
+## Fused projection matrices
+
+> [!WARNING]
+> The [fuse_qkv_projections](https://github.com/huggingface/diffusers/blob/58431f102cf39c3c8a569f32d71b2ea8caa461e1/src/diffusers/pipelines/pipeline_utils.py#L2034) method is experimental and support is mostly limited to Stable Diffusion pipelines. Take a look at this [PR](https://github.com/huggingface/diffusers/pull/6179) to learn more about how to enable it for other pipelines.
+
+An input is projected into three subspaces, represented by the projection matrices Q, K, and V, in an attention block. These projections are typically calculated separately, but you can horizontally combine these into a single matrix and perform the projection in a single step. It increases the size of the matrix multiplications of the input projections and also improves the impact of quantization.
+
+```py
+pipeline.fuse_qkv_projections()
+```
+
+## Distilled models
+
+Another option for accelerating inference is to use a smaller distilled model if it's available. During distillation, many of the UNet's residual and attention blocks are discarded to reduce model size and improve latency. A distilled model is faster and uses less memory without compromising quality compared to a full-sized model.
+
+> [!TIP]
+> Read [Open-sourcing Knowledge Distillation Code and Weights of SD-Small and SD-Tiny](https://huggingface.co/blog/sd_distillation) to learn more about how knowledge distillation training works to produce a faster, smaller, and cheaper generative model.
+
+The example below uses a distilled Stable Diffusion XL model and VAE.
+
+```py
+import torch
+from diffusers import DiffusionPipeline, AutoencoderTiny
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "segmind/SSD-1B", torch_dtype=torch.float16
+)
+pipeline.vae = AutoencoderTiny.from_pretrained(
+    "madebyollin/taesdxl", torch_dtype=torch.float16
+)
+pipeline = pipeline.to("cuda")

-More tiny autoencoder models for other Stable Diffusion models, like Stable Diffusion 3, are available from [madebyollin](https://huggingface.co/madebyollin).
+prompt = "slice of delicious New York-style cheesecake topped with berries, mint, chocolate crumble"
+pipeline(prompt, num_inference_steps=50).images[0]
+```
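
The new introduction above mentions experimenting with different schedulers, but the updated page doesn't include a scheduler example. As a minimal sketch, not part of this commit's diff, the snippet below swaps in `DPMSolverMultistepScheduler` so fewer steps are needed; the SDXL checkpoint and step count are illustrative assumptions.

```py
import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

# Reuse the checkpoint's scheduler config so the noise schedule stays consistent,
# then swap in a multistep DPM-Solver scheduler that converges in fewer steps.
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# Fewer inference steps are usually sufficient with DPM-Solver style schedulers.
pipeline(prompt, num_inference_steps=20).images[0]
```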

docs/source/en/optimization/memory.md

Lines changed: 6 additions & 0 deletions
@@ -125,6 +125,9 @@ pipeline(["An astronaut riding a horse on Mars"]*32).images[0]
 print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
 ```

+> [!WARNING]
+> [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support slicing.
+
 ## VAE tiling

 VAE tiling saves memory by dividing an image into smaller overlapping tiles instead of processing the entire image at once. This also reduces peak memory usage because the GPU is only processing a tile at a time. Unlike sliced VAE, tiled VAE maintains some context between tiles because they overlap which can generate more coherent images.
@@ -147,6 +150,9 @@ pipeline(prompt, image=init_image, strength=0.5).images[0]
 print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
 ```

+> [!WARNING]
+> [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support tiling.
+
 ## CPU offloading

 CPU offloading selectively moves weights from the GPU to the CPU. When a component is required, it is transferred to the GPU and when it isn't required, it is moved to the CPU. This method works on submodules rather than whole models. It saves memory by avoiding storing the entire model on the GPU.
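
The warnings added above sit next to the VAE tiling and CPU offloading sections of memory.md. As a minimal sketch, not part of this commit's diff, this is how those memory savers are typically enabled on a pipeline whose VAE supports them; the SDXL checkpoint is an illustrative assumption.

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Decode latents tile by tile to lower peak VAE memory; per the new warnings,
# AutoencoderKLWan and AsymmetricAutoencoderKL don't support slicing or tiling.
pipeline.enable_vae_tiling()

# Offload submodules to the CPU and move them to the GPU only when needed,
# so the whole model never has to fit in GPU memory at once.
pipeline.enable_sequential_cpu_offload()

prompt = "An astronaut riding a horse on Mars"
pipeline(prompt).images[0]
```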
