You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
@@ -10,120 +10,231 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
10
10
specific language governing permissions and limitations under the License.
11
11
-->
12
12
13
-
# Speed up inference
13
+
# Accelerate inference
14
14
15
-
There are several ways to optimize Diffusers for inference speed, such as reducing the computational burden by lowering the data precision or using a lightweight distilled model. There are also memory-efficient attention implementations, [xFormers](xformers) and [scaled dot product attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) in PyTorch 2.0, that reduce memory usage which also indirectly speeds up inference. Different speed optimizations can be stacked together to get the fastest inference times.
15
+
Diffusion models are slow at inference because generation is an iterative process where noise is gradually refined into an image or video over a certain number of "steps". To speedup this process, you can try experimenting with different [schedulers](../api/schedulers/overview), reduce the precision of the model weights for faster computations, use more memory-efficient attention mechanisms, and more.
16
16
17
-
> [!TIP]
18
-
> Optimizing for inference speed or reduced memory usage can lead to improved performance in the other category, so you should try to optimize for both whenever you can. This guide focuses on inference speed, but you can learn more about lowering memory usage in the [Reduce memory usage](memory) guide.
17
+
Combine and use these techniques together to make inference faster than using any single technique on its own.
18
+
19
+
This guide will go over how to accelerate inference.
19
20
20
-
The inference times below are obtained from generating a single 512x512 image from the prompt "a photo of an astronaut riding a horse on mars" with 50 DDIM steps on a NVIDIA A100.
21
+
## Model data type
21
22
22
-
| setup | latency | speed-up |
23
-
|----------|---------|----------|
24
-
| baseline | 5.27s | x1 |
25
-
| tf32 | 4.14s | x1.27 |
26
-
| fp16 | 3.51s | x1.50 |
27
-
| combined | 3.41s | x1.54 |
23
+
The precision and data type of the model weights affect inference speed because a higher precision requires more memory to load and more time to perform the computations. PyTorch loads model weights in float32 or full precision by default, so changing the data type is a simple way to quickly get faster inference.
28
24
29
-
## TensorFloat-32
25
+
<hfoptionsid="dtypes">
26
+
<hfoptionid="bfloat16">
30
27
31
-
On Ampere and later CUDA devices, matrix multiplications and convolutions can use the [TensorFloat-32 (tf32)](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) mode for faster, but slightly less accurate computations. By default, PyTorch enables tf32 mode for convolutions but not matrix multiplications. Unless your network requires full float32 precision, we recommend enabling tf32 for matrix multiplications. It can significantly speed up computations with typically negligible loss in numerical accuracy.
28
+
bfloat16 is similar to float16 but it is more robust to numerical errors. Hardware support for bfloat16 varies, but most modern GPUs are capable of supporting bfloat16.
> Don't use [torch.autocast](https://pytorch.org/docs/stable/amp.html#torch.autocast) in any of the pipelines as it can lead to black images and is always slower than pure float16 precision.
59
+
</hfoption>
60
+
<hfoptionid="TensorFloat-32">
59
61
60
-
## Distilled model
62
+
[TensorFloat-32 (tf32)](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) mode is supported on NVIDIA Ampere GPUs and it computes the convolution and matrix multiplication operations in tf32. Storage and other operations are kept in float32. This enables significantly faster computations when combined with bfloat16 or float16.
61
63
62
-
You could also use a distilled Stable Diffusion model and autoencoder to speed up inference. During distillation, many of the UNet's residual and attention blocks are shed to reduce the model size by 51% and improve latency on CPU/GPU by 43%. The distilled model is faster and uses less memory while generating images of comparable quality to the full Stable Diffusion model.
64
+
PyTorch only enables tf32 mode for convolutions by default and you'll need to explicitly enable it for matrix multiplications.
63
65
64
-
> [!TIP]
65
-
> Read the [Open-sourcing Knowledge Distillation Code and Weights of SD-Small and SD-Tiny](https://huggingface.co/blog/sd_distillation) blog post to learn more about how knowledge distillation training works to produce a faster, smaller, and cheaper generative model.
66
+
```py
67
+
import torch
68
+
from diffusers import StableDiffusionXLPipeline
66
69
67
-
The inference times below are obtained from generating 4 images from the prompt "a photo of an astronaut riding a horse on mars" with 25 PNDM steps on a NVIDIA A100. Each generation is repeated 3 times with the distilled Stable Diffusion v1.4 model by [Nota AI](https://hf.co/nota-ai).
Let's load the distilled Stable Diffusion model and compare it against the original Stable Diffusion model.
80
+
Refer to the [mixed precision training](https://huggingface.co/docs/transformers/en/perf_train_gpu_one#mixed-precision) docs for more details.
81
+
82
+
</hfoption>
83
+
</hfoptions>
84
+
85
+
## Scaled dot product attention
86
+
87
+
[Scaled dot product attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) implements several attention backends, [FlashAttention](https://github.com/Dao-AILab/flash-attention), [xFormers](https://github.com/facebookresearch/xformers), and a native C++ implementation. It automatically selects the most optimal backend for your hardware.
88
+
89
+
SDPA is enabled by default if you're using PyTorch >= 2.0 and no additional changes are required to your code. You could try experimenting with other attention backends though if you'd like to choose your own. The example below uses the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to enable efficient attention.
[torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) accelerates inference by compiling PyTorch code and operations into optimized kernels. Diffusers typically compiles the more compute-intensive models like the UNet, transformer, or VAE.
102
109
103
-
To speed inference up even more, replace the autoencoder with a [distilled version](https://huggingface.co/sayakpaul/taesdxl-diffusers) of it.
110
+
Enable the following compiler settings for maximum speed (refer to the [full list](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py) for more options).
104
111
105
112
```py
106
113
import torch
107
-
from diffusers import AutoencoderTiny, StableDiffusionPipeline
Load and compile the UNet and VAE. There are several different modes you can choose from, but `"max-autotune"` optimizes for the fastest speed by compiling to a CUDA graph. CUDA graphs effectively reduces the overhead by launching multiple GPU operations through a single CPU operation.
> With PyTorch 2.3.1, you can control the caching behavior of torch.compile. This is particularly beneficial for compilation modes like `"max-autotune"` which performs a grid-search over several compilation flags to find the optimal configuration. Learn more in the [Compile Time Caching in torch.compile](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) tutorial.
126
+
127
+
Changing the memory layout to [channels_last](./memory#torchchannels_last) also optimizes memory and inference speed.
Compilation is slow the first time, but once compiled, it is significantly faster. Try to only use the compiled pipeline on the same type of inference operations. Calling the compiled pipeline on a different image size retriggers compilation which is slow and inefficient.
150
+
151
+
### Graph breaks
152
+
153
+
It is important to specify `fullgraph=True` in torch.compile to ensure there are no graph breaks in the underlying model. This allows you to take advantage of torch.compile without any performance degradation. For the UNet and VAE, this changes how you access the return variables.
The `step()` function is [called](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L1228) on the scheduler each time after the denoiser makes a prediction, and the `sigmas` variable is [indexed](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/schedulers/scheduling_euler_discrete.py#L476). When placed on the GPU, it introduces latency because of the communication sync between the CPU and GPU. It becomes more evident when the denoiser has already been compiled.
168
+
169
+
In general, the `sigmas` should [stay on the CPU](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240) to avoid the communication sync and latency.
170
+
171
+
## Dynamic quantization
172
+
173
+
[Dynamic quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) improves inference speed by reducing precision to enable faster math operations. This particular type of quantization determines how to scale the activations based on the data at runtime rather than using a fixed scaling factor. As a result, the scaling factor is more accurately aligned with the data.
174
+
175
+
The example below applies [dynamic int8 quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) to the UNet and VAE with the [torchao](../quantization/torchao) library.
Filter out some linear layers in the UNet and VAE which don't benefit from dynamic quantization with the [dynamic_quant_filter_fn](https://github.com/huggingface/diffusion-fast/blob/0f169640b1db106fe6a479f78c1ed3bfaeba3386/utils/pipeline_utils.py#L16).
> The [fuse_qkv_projections](https://github.com/huggingface/diffusers/blob/58431f102cf39c3c8a569f32d71b2ea8caa461e1/src/diffusers/pipelines/pipeline_utils.py#L2034) method is experimental and support is limited to mostly Stable Diffusion pipelines. Take a look at this [PR](https://github.com/huggingface/diffusers/pull/6179) to learn more about how to enable it for other pipelines
210
+
211
+
An input is projected into three subspaces, represented by the projection matrices Q, K, and V, in an attention block. These projections are typically calculated separately, but you can horizontally combine these into a single matrix and perform the projection in a single step. It increases the size of the matrix multiplications of the input projections and also improves the impact of quantization.
212
+
213
+
```py
214
+
pipeline.fuse_qkv_projections()
215
+
```
216
+
217
+
## Distilled models
218
+
219
+
Another option for accelerating inference is to use a smaller distilled model if it's available. During distillation, many of the UNet's residual and attention blocks are discarded to reduce model size and improve latency. A distilled model is faster and uses less memory without compromising quality compared to a full-sized model.
220
+
221
+
> [!TIP]
222
+
> Read [Open-sourcing Knowledge Distillation Code and Weights of SD-Small and SD-Tiny](https://huggingface.co/blog/sd_distillation) to learn more about how knowledge distillation training works to produce a faster, smaller, and cheaper generative model.
223
+
224
+
The example below uses a distilled Stable Diffusion XL model and VAE.
225
+
226
+
```py
227
+
import torch
228
+
from diffusers import DiffusionPipeline, AutoencoderTiny
229
+
230
+
pipeline = DiffusionPipeline.from_pretrained(
231
+
"segmind/SSD-1B", torch_dtype=torch.float16
232
+
)
233
+
pipeline.vae = AutoencoderTiny.from_pretrained(
234
+
"madebyollin/taesdxl", torch_dtype=torch.float16
235
+
)
236
+
pipeline = pipeline.to("cuda")
128
237
129
-
More tiny autoencoder models for other Stable Diffusion models, like Stable Diffusion 3, are available from [madebyollin](https://huggingface.co/madebyollin).
238
+
prompt ="slice of delicious New York-style cheesecake topped with berries, mint, chocolate crumble"
> [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support slicing.
130
+
128
131
## VAE tiling
129
132
130
133
VAE tiling saves memory by dividing an image into smaller overlapping tiles instead of processing the entire image at once. This also reduces peak memory usage because the GPU is only processing a tile at a time. Unlike sliced VAE, tiled VAE maintains some context between tiles because they overlap which can generate more coherent images.
> [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support tiling.
155
+
150
156
## CPU offloading
151
157
152
158
CPU offloading selectively moves weights from the GPU to the CPU. When a component is required, it is transferred to the GPU and when it isn't required, it is moved to the CPU. This method works on submodules rather than whole models. It saves memory by avoiding storing the entire model on the GPU.
0 commit comments