Commit 759e663

DarkLight1337 authored and yangw-dev committed

[Doc] Improve OOM troubleshooting (vllm-project#16704)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Yang Wang <elainewy@meta.com>
1 parent c2bb60e commit 759e663

File tree

2 files changed: +54, -4 lines


docs/source/getting_started/troubleshooting.md

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ To isolate the model downloading and loading issue, you can use the `--load-form
 
 ## Out of memory
 
-If the model is too large to fit in a single GPU, you will get an out-of-memory (OOM) error. Consider [using tensor parallelism](#distributed-serving) to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). You can convert the model checkpoint to a sharded checkpoint using <gh-file:examples/offline_inference/save_sharded_state.py>. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism.
+If the model is too large to fit in a single GPU, you will get an out-of-memory (OOM) error. Consider adopting [these options](#reducing-memory-usage) to reduce the memory consumption.
 
 ## Generation quality changed
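
The paragraph removed in this hunk pointed readers at tensor parallelism. A minimal sketch of that route, assuming a host with at least two GPUs; the model name and `tensor_parallel_size=2` are illustrative values rather than part of this diff:

```python
from vllm import LLM

# Split a model that does not fit on one GPU across two GPUs.
# Requires at least two visible GPUs; each worker then holds
# roughly 1 / tensor_parallel_size of the weights.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          tensor_parallel_size=2)
```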

docs/source/serving/offline_inference.md

Lines changed: 53 additions & 3 deletions
@@ -59,6 +59,8 @@ model = LLM(
 
 Our [list of supported models](#supported-models) shows the model architectures that are recognized by vLLM.
 
+(reducing-memory-usage)=
+
 ### Reducing memory usage
 
 Large models might cause your machine to run out of memory (OOM). Here are some options that help alleviate this problem.
@@ -81,6 +83,12 @@ before initializing vLLM. Otherwise, you may run into an error like `RuntimeErro
 To control which devices are used, please instead set the `CUDA_VISIBLE_DEVICES` environment variable.
 :::
 
+:::{note}
+With tensor parallelism enabled, each process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism).
+
+You can convert the model checkpoint to a sharded checkpoint using <gh-file:examples/offline_inference/save_sharded_state.py>. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism.
+:::
+
 #### Quantization
 
 Quantized models take less memory at the cost of lower precision.
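
The quantization passage above carries no example within this hunk's context. A hedged sketch of loading a pre-quantized checkpoint; the checkpoint name is a placeholder, and `quantization="awq"` assumes an AWQ-format model:

```python
from vllm import LLM

# Placeholder checkpoint name: substitute an actual AWQ-quantized model.
# Quantized weights occupy far less GPU memory than the fp16 original.
llm = LLM(model="your-org/your-model-AWQ",
          quantization="awq")
```
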
@@ -103,23 +111,65 @@ llm = LLM(model="adept/fuyu-8b",
           max_num_seqs=2)
 ```
 
+#### Reduce CUDA Graphs
+
+By default, we optimize model inference using CUDA graphs which take up extra memory in the GPU.
+
+:::{important}
+CUDA graph capture takes up more memory in V1 than in V0.
+:::
+
+You can adjust `compilation_config` to achieve a better balance between inference speed and memory usage:
+
+```python
+from vllm import LLM
+from vllm.config import CompilationConfig, CompilationLevel
+
+llm = LLM(
+    model="meta-llama/Llama-3.1-8B-Instruct",
+    compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        # By default, it goes up to max_num_seqs
+        cudagraph_capture_sizes=[1, 2, 4, 8, 16],
+    ),
+)
+```
+
+You can disable graph capturing completely via the `enforce_eager` flag:
+
+```python
+from vllm import LLM
+
+llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
+          enforce_eager=True)
+```
+
 #### Adjust cache size
 
 If you run out of CPU RAM, try the following options:
 
 - (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB).
 - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
 
-#### Disable unused modalities
+#### Multi-modal input limits
 
-You can disable unused modalities (except for text) by setting its limit to zero.
+You can allow a smaller number of multi-modal items per prompt to reduce the memory footprint of the model:
+
+```python
+from vllm import LLM
+
+# Accept up to 3 images and 1 video per prompt
+llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
+          limit_mm_per_prompt={"image": 3, "video": 1})
+```
 
+You can go a step further and disable unused modalities completely by setting its limit to zero.
 For example, if your application only accepts image input, there is no need to allocate any memory for videos.
 
 ```python
 from vllm import LLM
 
-# Accept images but not videos
+# Accept any number of images but no videos
 llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
           limit_mm_per_prompt={"video": 0})
 ```
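
The cache-size bullets in this hunk name two environment variables but show no usage. A short sketch of setting them before the engine is constructed; the numeric values are illustrative, not defaults:

```python
import os

# Set before importing/constructing vLLM so the values are picked up.
os.environ["VLLM_MM_INPUT_CACHE_GIB"] = "2"   # multi-modal input cache (default 4 GiB)
os.environ["VLLM_CPU_KVCACHE_SPACE"] = "2"    # CPU-backend KV cache, in GiB (default 4 GiB)

from vllm import LLM

llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
          limit_mm_per_prompt={"image": 3, "video": 0})
```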
