docs/source/en/optimization/memory.md: 19 additions & 19 deletions
@@ -12,9 +12,9 @@ specific language governing permissions and limitations under the License.
# Reduce memory usage
-Modern diffusion models like [Flux](../api/pipelines/flux) and [Wan](../api/pipelines/wan) have billions of parameters that take up a lot of memory on your hardware for inference. This poses a challenge because common GPUs often don't have sufficient memory.
+Modern diffusion models like [Flux](../api/pipelines/flux) and [Wan](../api/pipelines/wan) have billions of parameters that take up a lot of memory on your hardware for inference. This is challenging because common GPUs often don't have sufficient memory.
-To overcome these memory constraints, you can use a second GPU (if available), offload some of the pipeline components to the CPU, and more. This guide will show you how to reduce your memory usage.
+To overcome the memory limitations, you can use more than one GPU (if available), offload some of the pipeline components to the CPU, and more. This guide will show you how to reduce your memory usage.
## Multiple GPUs
@@ -26,9 +26,9 @@ pip install -U accelerate
### Sharded checkpoints
-Loading large checkpoints in several shards in useful because shards are loaded one at a time. This keeps memory usage low, only requiring enough memory for the model size and the largest shard size. We recommend sharding when the fp32 checkpoint is greater than 5GB. The default shard size is 5GB.
+Loading large checkpoints in several shards is useful because the shards are loaded one at a time. This keeps memory usage low, only requiring enough memory for the model size and the largest shard size. We recommend sharding when the fp32 checkpoint is greater than 5GB. The default shard size is 5GB.
-You can shard a checkpoint in [`~DiffusionPipeline.save_pretrained`] with the `max_shard_size` parameter.
+Shard a checkpoint in [`~DiffusionPipeline.save_pretrained`] with the `max_shard_size` parameter.
-> Device placement is an experimental feature and the API may change. Only the `balanced` strategy is supported at the moment, and we plan to support additional mapping strategies in the future.
+> Device placement is an experimental feature and the API may change. Only the `balanced` strategy is supported at the moment. We plan to support additional mapping strategies in the future.
-The `device_map` parameter allows you to control how the model components in a pipeline are distributed across your devices. The `balanced` device placement strategy evenly splits the pipeline across all available devices.
+The `device_map` parameter controls how the model components in a pipeline are distributed across devices. The `balanced` device placement strategy evenly splits the pipeline across all available devices.
-Diffusers uses the maxmium memory of all devices, but if they don't fit on the GPUs, then you'll need to use a single GPU and offload to the CPU with the methods below.
+Diffusers uses the maximum memory of all devices by default, but if they don't fit on the GPUs, then you'll need to use a single GPU and offload to the CPU with the methods below.
-- [`~DiffusionPipeline.enable_model_cpu_offload`] only works on a single GPU and a model may not fit on it
+- [`~DiffusionPipeline.enable_model_cpu_offload`] only works on a single GPU but a very large model may not fit on it
- [`~DiffusionPipeline.enable_sequential_cpu_offload`] may work but it is extremely slow and also limited to a single GPU
Use the [`~DiffusionPipeline.reset_device_map`] method to reset the `device_map`. This is necessary if you want to use methods like `.to()`, [`~DiffusionPipeline.enable_sequential_cpu_offload`], and [`~DiffusionPipeline.enable_model_cpu_offload`] on a pipeline that was device-mapped.
@@ -108,9 +108,9 @@ pipeline.reset_device_map
## Sliced VAE
-Sliced VAE saves memory by processing an image in smaller non-overlapping "slices" instead of processing the entire image at once. This reduces peak memory usage because the GPU is only processing one slice at a time.
+Sliced VAE saves memory by processing an image in smaller non-overlapping "slices" instead of processing the entire image at once. This reduces peak memory usage because the GPU is only processing a small slice at a time.
-Call [`~DiffusionPipeline.enable_vae_slicing`] to sliced VAE. You can expect a small increase in performance when decoding multi-image batches and no performance impact for single-image batches.
+Call [`~StableDiffusionPipeline.enable_vae_slicing`] to enable sliced VAE. You can expect a small increase in performance when decoding multi-image batches and no performance impact for single-image batches.
VAE tiling saves memory by dividing an image into smaller overlapping tiles instead of processing the entire image at once. This also reduces peak memory usage because the GPU is only processing a tile at a time. Unlike sliced VAE, tiled VAE maintains some context between tiles because they overlap, which can generate more coherent images.
-Call [`~DiffusionPipeline.enable_vae_tiling`] to enable VAE tiling. The generate image may have some tone variation from tile-to-tile because they're decoded separately, but there shouldn't be any obvious seams between the tiles. Tiling is disabled for images that are 512x512 or smaller.
+Call [`~StableDiffusionPipeline.enable_vae_tiling`] to enable VAE tiling. The generated image may have some tone variation from tile-to-tile because they're decoded separately, but there shouldn't be any obvious seams between the tiles. Tiling is disabled for images that are 512x512 or smaller.
-CPU offloading selectively moves weights from the GPU to the CPU to reduce memory usage. When a component is required, it is transferred to the GPU and when it isn't required, it is moved to the CPU. This method works on submodules rather than whole models.
+CPU offloading selectively moves weights from the GPU to the CPU. When a component is required, it is transferred to the GPU and when it isn't required, it is moved to the CPU. This method works on submodules rather than whole models. It saves memory by avoiding storing the entire model on the GPU.
CPU offloading dramatically reduces memory usage, but it is also extremely slow because submodules are passed back and forth multiple times between devices.
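As a rough sketch of submodule-level offloading (the SDXL checkpoint and prompt are placeholders):

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
# Don't move the pipeline to "cuda" first; submodules are shuttled to the GPU on demand
pipeline.enable_sequential_cpu_offload()

image = pipeline("an astronaut riding a horse").images[0]
```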
@@ -214,7 +214,7 @@ Group offloading moves groups of internal layers ([torch.nn.ModuleList](https://
> [!WARNING]
> Group offloading may not work with all models if the forward implementation contains weight-dependent device casting of inputs because it may clash with group offloading's device casting mechanism.
-Call [`ModelMixin.enable_group_offload`] to enable it for standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead.
+Call [`~ModelMixin.enable_group_offload`] to enable it for standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead.
The `offload_type` parameter can be set to `block_level` or `leaf_level`.
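A rough sketch of both entry points. The CogVideoX-5b checkpoint, the keyword arguments, and `num_blocks_per_group=2` are assumptions here; check the current signatures of `enable_group_offload` and `apply_group_offloading` before relying on them:

```py
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_group_offloading

# Placeholder checkpoint; other pipelines work similarly
pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

# Diffusers components that inherit from ModelMixin expose enable_group_offload directly
pipeline.transformer.enable_group_offload(
    onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level"
)

# Generic torch.nn.Module components (here, the T5 text encoder) go through apply_group_offloading
apply_group_offloading(
    pipeline.text_encoder,
    onload_device=onload_device,
    offload_type="block_level",
    num_blocks_per_group=2,
)
```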
@@ -331,7 +331,7 @@ apply_layerwise_casting(
## torch.channels_last
-[torch.channels_last](https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html) flips how tensors are stored from `batch size, channels, height, width` to `batch size, heigh, width, channels`. This aligns the tensors with how the hardware sequentially accesses the tensors stored in memory and avoids skipping around in memory to access the pixel values.
+[torch.channels_last](https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html) flips how tensors are stored from `(batch size, channels, height, width)` to `(batch size, height, width, channels)`. This aligns the tensors with how the hardware sequentially accesses the tensors stored in memory and avoids skipping around in memory to access the pixel values.
Not all operators currently support the channels-last format and may result in worse performance, but it is still worth trying.
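A minimal sketch of switching a pipeline's UNet to the channels-last memory format; the SDXL checkpoint is a placeholder:

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Switch the UNet's tensors to channels-last memory layout
pipeline.unet.to(memory_format=torch.channels_last)
```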
@@ -454,15 +454,15 @@ with torch.inference_mode():
The Transformers attention mechanism is memory-intensive, especially for long sequences, so you can try using different and more memory-efficient attention types.
-By default, if PyTorch >= 2.0 is installed, the PyTorch [scaled dot-product attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) is used. You don't need to make any additional changes to your code.
+By default, if PyTorch >= 2.0 is installed, [scaled dot-product attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) is used. You don't need to make any additional changes to your code.
SDPA supports [FlashAttention](https://github.com/Dao-AILab/flash-attention) and [xFormers](https://github.com/facebookresearch/xformers) as well as a native C++ PyTorch implementation. It automatically selects the optimal implementation based on your input.
-You can also explicitly use xFormers with the [`~ModelMixin.enable_xformers_memory_efficient_attention`] method.
+You can explicitly use xFormers with the [`~ModelMixin.enable_xformers_memory_efficient_attention`] method.