Commit 34dad19

[Bugfix] set default cuda_graph_sizes to min(self.max_num_seqs * 2, 512) (#20628)
Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
1 parent 6db31e7 commit 34dad19

1 file changed: +12 −4 lines changed


vllm/config.py

Lines changed: 12 additions & 4 deletions
@@ -2147,11 +2147,12 @@ class SchedulerConfig:
     NOTE: This will be replaced by speculative config in the future; it is
     present to enable correctness tests until then."""
 
-    cuda_graph_sizes: list[int] = field(default_factory=lambda: [512])
-    """Cuda graph capture sizes, default is 512.
-    1. if one value is provided, then the capture list would follow the
+    cuda_graph_sizes: list[int] = field(default_factory=list)
+    """Cuda graph capture sizes
+    1. if none provided, then default set to [min(max_num_seqs * 2, 512)]
+    2. if one value is provided, then the capture list would follow the
     pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)]
-    2. more than one value (e.g. 1 2 128) is provided, then the capture list
+    3. more than one value (e.g. 1 2 128) is provided, then the capture list
     will follow the provided list."""
 
     delay_factor: float = 0.0
@@ -2316,6 +2317,13 @@ def __post_init__(self) -> None:
             self.max_num_partial_prefills, self.max_long_partial_prefills,
             self.long_prefill_token_threshold)
 
+        # NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)].
+        # This avoids OOM in tight memory scenarios with small max_num_seqs,
+        # and prevents capture of many large graphs (>512) that would greatly
+        # increase startup time with limited performance benefit.
+        if not self.cuda_graph_sizes:
+            self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)]
+
     @model_validator(mode='after')
     def _verify_args(self) -> Self:
         if (self.max_num_batched_tokens < self.max_model_len
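For context, here is a minimal standalone sketch of the size-resolution behavior the docstring describes, assuming the single-value default set in __post_init__ is later expanded by the same one-value pattern. The helper name resolve_cuda_graph_sizes is hypothetical and not part of vLLM; only the default expression and the [1, 2, 4] plus multiples-of-8 pattern come from the diff above.

# Hypothetical, self-contained sketch of the capture-size resolution
# described in the docstring above; not vLLM's actual code path.

def resolve_cuda_graph_sizes(cuda_graph_sizes: list[int],
                             max_num_seqs: int) -> list[int]:
    """Return the batch sizes to capture as CUDA graphs."""
    # Case 1 (this commit's change): no sizes provided -> default to a
    # single value, 2x the max batch size, capped at 512.
    if not cuda_graph_sizes:
        cuda_graph_sizes = [min(max_num_seqs * 2, 512)]

    # Case 2: one value expands to the documented pattern:
    # [1, 2, 4] plus every multiple of 8 up to and including that value.
    if len(cuda_graph_sizes) == 1:
        upper = cuda_graph_sizes[0]
        return [1, 2, 4] + [i for i in range(8, upper + 1, 8)]

    # Case 3: an explicit multi-value list is used as-is.
    return cuda_graph_sizes


# E.g. max_num_seqs=64 yields a default of [128], which expands to
# [1, 2, 4, 8, 16, ..., 128]; before this commit the default was
# always [512], regardless of max_num_seqs.
print(resolve_cuda_graph_sizes([], max_num_seqs=64))

Per the NOTE comment in the diff, the 512 cap bounds graph-capture time at startup, while the 2x multiplier keeps the default proportional to the workload so small-max_num_seqs deployments do not capture needlessly large graphs.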
