Commit 3482fd7
[Doc] Add engine args back in to the docs (#20674)

Authored by Harry Mellor
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
1 parent 77f77a9

File tree

14 files changed: +218 -40 lines

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -146,6 +146,7 @@ venv.bak/
 
 # mkdocs documentation
 /site
+docs/argparse
 docs/examples
 
 # mypy

docs/configuration/engine_args.md

Lines changed: 10 additions & 5 deletions
@@ -1,15 +1,20 @@
+---
+toc_depth: 3
+---
+
 # Engine Arguments
 
 Engine arguments control the behavior of the vLLM engine.
 
 - For [offline inference](../serving/offline_inference.md), they are part of the arguments to [LLM][vllm.LLM] class.
 - For [online serving](../serving/openai_compatible_server.md), they are part of the arguments to `vllm serve`.
 
-You can look at [EngineArgs][vllm.engine.arg_utils.EngineArgs] and [AsyncEngineArgs][vllm.engine.arg_utils.AsyncEngineArgs] to see the available engine arguments.
+The engine argument classes, [EngineArgs][vllm.engine.arg_utils.EngineArgs] and [AsyncEngineArgs][vllm.engine.arg_utils.AsyncEngineArgs], are a combination of the configuration classes defined in [vllm.config][]. Therefore, if you are interested in developer documentation, we recommend looking at these configuration classes as they are the source of truth for types, defaults and docstrings.
+
+## `EngineArgs`
 
-However, these classes are a combination of the configuration classes defined in [vllm.config][]. Therefore, we would recommend you read about them there where they are best documented.
+--8<-- "docs/argparse/engine_args.md"
 
-For offline inference you will have access to these configuration classes and for online serving you can cross-reference the configs with `vllm serve --help`, which has its arguments grouped by config.
+## `AsyncEngineArgs`
 
-!!! note
-    Additional arguments are available to the [AsyncLLMEngine][vllm.engine.async_llm_engine.AsyncLLMEngine] which is used for online serving. These can be found by running `vllm serve --help`
+--8<-- "docs/argparse/async_engine_args.md"
docs/mkdocs/hooks/generate_argparse.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import logging
+import sys
+from argparse import SUPPRESS, HelpFormatter
+from pathlib import Path
+from typing import Literal
+from unittest.mock import MagicMock, patch
+
+ROOT_DIR = Path(__file__).parent.parent.parent.parent
+ARGPARSE_DOC_DIR = ROOT_DIR / "docs/argparse"
+
+sys.path.insert(0, str(ROOT_DIR))
+sys.modules["aiohttp"] = MagicMock()
+sys.modules["blake3"] = MagicMock()
+sys.modules["vllm._C"] = MagicMock()
+
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs  # noqa: E402
+from vllm.utils import FlexibleArgumentParser  # noqa: E402
+
+logger = logging.getLogger("mkdocs")
+
+
+class MarkdownFormatter(HelpFormatter):
+    """Custom formatter that generates markdown for argument groups."""
+
+    def __init__(self, prog):
+        super().__init__(prog,
+                         max_help_position=float('inf'),
+                         width=float('inf'))
+        self._markdown_output = []
+
+    def start_section(self, heading):
+        if heading not in {"positional arguments", "options"}:
+            self._markdown_output.append(f"\n### {heading}\n\n")
+
+    def end_section(self):
+        pass
+
+    def add_text(self, text):
+        if text:
+            self._markdown_output.append(f"{text.strip()}\n\n")
+
+    def add_usage(self, usage, actions, groups, prefix=None):
+        pass
+
+    def add_arguments(self, actions):
+        for action in actions:
+
+            option_strings = f'`{"`, `".join(action.option_strings)}`'
+            self._markdown_output.append(f"#### {option_strings}\n\n")
+
+            if choices := action.choices:
+                choices = f'`{"`, `".join(str(c) for c in choices)}`'
+                self._markdown_output.append(
+                    f"Possible choices: {choices}\n\n")
+
+            self._markdown_output.append(f"{action.help}\n\n")
+
+            if (default := action.default) != SUPPRESS:
+                self._markdown_output.append(f"Default: `{default}`\n\n")
+
+    def format_help(self):
+        """Return the formatted help as markdown."""
+        return "".join(self._markdown_output)
+
+
+def create_parser(cls, **kwargs) -> FlexibleArgumentParser:
+    """Create a parser for the given class with markdown formatting.
+
+    Args:
+        cls: The class to create a parser for
+        **kwargs: Additional keyword arguments to pass to `cls.add_cli_args`.
+
+    Returns:
+        FlexibleArgumentParser: A parser with markdown formatting for the class.
+    """
+    parser = FlexibleArgumentParser()
+    parser.formatter_class = MarkdownFormatter
+    with patch("vllm.config.DeviceConfig.__post_init__"):
+        return cls.add_cli_args(parser, **kwargs)
+
+
+def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
+    logger.info("Generating argparse documentation")
+    logger.debug("Root directory: %s", ROOT_DIR.resolve())
+    logger.debug("Output directory: %s", ARGPARSE_DOC_DIR.resolve())
+
+    # Create the ARGPARSE_DOC_DIR if it doesn't exist
+    if not ARGPARSE_DOC_DIR.exists():
+        ARGPARSE_DOC_DIR.mkdir(parents=True)
+
+    # Create parsers to document
+    parsers = {
+        "engine_args": create_parser(EngineArgs),
+        "async_engine_args": create_parser(AsyncEngineArgs,
+                                           async_args_only=True),
+    }
+
+    # Generate documentation for each parser
+    for stem, parser in parsers.items():
+        doc_path = ARGPARSE_DOC_DIR / f"{stem}.md"
+        with open(doc_path, "w") as f:
+            f.write(parser.format_help())
+        logger.info("Argparse generated: %s", doc_path.relative_to(ROOT_DIR))
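
The hook piggybacks on argparse's help machinery: `ArgumentParser.format_help()` drives its formatter through the `add_usage`, `start_section`, `add_text` and `add_arguments` callbacks, so a `HelpFormatter` subclass can capture that metadata as markdown instead of console text. Below is a minimal, self-contained sketch of the same technique using only the standard library; the parser, flag and formatter names are illustrative, not part of vLLM.

import argparse
from argparse import SUPPRESS, HelpFormatter

class MiniMarkdownFormatter(HelpFormatter):
    """Collects argparse metadata as markdown instead of console text."""

    def __init__(self, prog):
        super().__init__(prog)
        self.out = []

    def add_usage(self, usage, actions, groups, prefix=None):
        pass  # usage lines are noise in generated docs

    def add_arguments(self, actions):
        for action in actions:
            flags = ", ".join(f"`{s}`" for s in action.option_strings)
            self.out.append(f"#### {flags}\n\n{action.help or ''}\n")
            if action.default is not SUPPRESS:
                self.out.append(f"Default: `{action.default}`\n")

    def format_help(self):
        return "\n".join(self.out)

parser = argparse.ArgumentParser(formatter_class=MiniMarkdownFormatter)
parser.add_argument("--model", default="facebook/opt-125m", help="Model name.")
print(parser.format_help())  # emits "#### `--model`" etc. as markdown

The real hook additionally mocks heavy modules (`aiohttp`, `blake3`, `vllm._C`) before importing `EngineArgs`, so the docs build stays cheap and dependency-light.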

docs/mkdocs/hooks/generate_examples.py

Lines changed: 1 addition & 1 deletion
@@ -161,8 +161,8 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
     for example in sorted(examples, key=lambda e: e.path.stem):
         example_name = f"{example.path.stem}.md"
         doc_path = EXAMPLE_DOC_DIR / example.category / example_name
-        logger.debug("Example generated: %s", doc_path.relative_to(ROOT_DIR))
         if not doc_path.parent.exists():
             doc_path.parent.mkdir(parents=True)
         with open(doc_path, "w+") as f:
             f.write(example.generate())
+        logger.debug("Example generated: %s", doc_path.relative_to(ROOT_DIR))
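
This one-line move fixes a misleading log: previously the debug message fired before the file existed, so a failed `mkdir` or `write` would still have reported the example as generated. A toy version of the corrected ordering (the path and message here are illustrative):

from pathlib import Path

doc_path = Path("out/example.md")  # illustrative output path
doc_path.parent.mkdir(parents=True, exist_ok=True)
with open(doc_path, "w+") as f:
    f.write("# generated\n")
# Only report success once the write has actually happened.
print(f"Example generated: {doc_path}")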
docs/mkdocs/overrides/partials/toc-item.html

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+<!-- Enables the use of toc_depth in document frontmatter https://github.com/squidfunk/mkdocs-material/issues/4827#issuecomment-1869812019 -->
+<li class="md-nav__item">
+  <a href="{{ toc_item.url }}" class="md-nav__link">
+    <span class="md-ellipsis">
+      {{ toc_item.title }}
+    </span>
+  </a>
+
+  <!-- Table of contents list -->
+  {% if toc_item.children %}
+    <nav class="md-nav" aria-label="{{ toc_item.title | striptags }}">
+      <ul class="md-nav__list">
+        {% for toc_item in toc_item.children %}
+          {% if not page.meta.toc_depth or toc_item.level <= page.meta.toc_depth %}
+            {% include "partials/toc-item.html" %}
+          {% endif %}
+        {% endfor %}
+      </ul>
+    </nav>
+  {% endif %}
+</li>

mkdocs.yaml

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,7 @@ site_url: https://docs.vllm.ai
 repo_url: https://github.com/vllm-project/vllm
 edit_uri: edit/main/docs/
 exclude_docs: |
+  argparse
   *.inc.md
   *.template.md
 theme:
@@ -47,6 +48,7 @@ theme:
 hooks:
   - docs/mkdocs/hooks/remove_announcement.py
   - docs/mkdocs/hooks/generate_examples.py
+  - docs/mkdocs/hooks/generate_argparse.py
   - docs/mkdocs/hooks/url_schemes.py
 
 # Required to stop api-autonav from raising an error

requirements/docs.txt

Lines changed: 15 additions & 0 deletions
@@ -7,3 +7,18 @@ mkdocs-awesome-nav
 python-markdown-math
 regex
 ruff
+
+# Required for argparse hook only
+-f https://download.pytorch.org/whl/cpu
+cachetools
+cloudpickle
+fastapi
+msgspec
+openai
+pillow
+psutil
+pybase64
+pydantic
+torch
+transformers
+zmq

vllm/engine/arg_utils.py

Lines changed: 26 additions & 16 deletions
@@ -12,8 +12,9 @@
 import warnings
 from dataclasses import MISSING, dataclass, fields, is_dataclass
 from itertools import permutations
-from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
-                    Type, TypeVar, Union, cast, get_args, get_origin)
+from typing import (TYPE_CHECKING, Annotated, Any, Callable, Dict, List,
+                    Literal, Optional, Type, TypeVar, Union, cast, get_args,
+                    get_origin)
 
 import regex as re
 import torch
@@ -33,20 +34,26 @@
                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
                         TaskOption, TokenizerMode, TokenizerPoolConfig,
                         VllmConfig, get_attr_docs, get_field)
-from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.plugins import load_general_plugins
 from vllm.reasoning import ReasoningParserManager
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
 from vllm.transformers_utils.utils import check_gguf_file
-from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
                         GiB_bytes, get_ip, is_in_ray_actor)
 
 # yapf: enable
 
+if TYPE_CHECKING:
+    from vllm.executor.executor_base import ExecutorBase
+    from vllm.model_executor.layers.quantization import QuantizationMethods
+    from vllm.usage.usage_lib import UsageContext
+else:
+    ExecutorBase = Any
+    QuantizationMethods = Any
+    UsageContext = Any
+
 logger = init_logger(__name__)
 
 # object is used to allow for special typing forms
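
The `if TYPE_CHECKING: ... else: Name = Any` dance keeps heavy modules off the import path of `arg_utils.py` while still letting runtime-evaluated annotations that mention `ExecutorBase`, `QuantizationMethods` or `UsageContext` resolve to something. A self-contained sketch of the pattern, with `decimal` standing in for a heavy dependency (the function is illustrative):

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from decimal import Decimal  # seen by type checkers only
else:
    # The name must still exist at runtime because annotations in
    # signatures are evaluated when the function is defined.
    Decimal = Any

def price_tag(amount: Decimal) -> str:
    return f"${amount}"

print(price_tag(42))  # runs without ever importing decimal

The same pattern is applied to `vllm/inputs/registry.py` further down in this commit.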
@@ -200,14 +207,17 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
         kwargs[name] = {"default": default, "help": help}
 
         # Set other kwargs based on the type hints
-        json_tip = """\n\nShould either be a valid JSON string or JSON keys
-        passed individually. For example, the following sets of arguments are
-        equivalent:\n\n
-        - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
-        - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n
-        Additionally, list elements can be passed individually using '+':
-        - `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n
-        - `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`\n\n"""
+        json_tip = """Should either be a valid JSON string or JSON keys
+        passed individually. For example, the following sets of arguments are
+        equivalent:
+
+        - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
+        - `--json-arg.key1 value1 --json-arg.key2.key3 value2`
+
+        Additionally, list elements can be passed individually using `+`:
+
+        - `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n
+        - `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`"""
         if dataclass_cls is not None:
 
             def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
@@ -219,7 +229,7 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                     raise argparse.ArgumentTypeError(repr(e)) from e
 
             kwargs[name]["type"] = parse_dataclass
-            kwargs[name]["help"] += json_tip
+            kwargs[name]["help"] += f"\n\n{json_tip}"
         elif contains_type(type_hints, bool):
             # Creates --no-<name> and --<name> flags
             kwargs[name]["action"] = argparse.BooleanOptionalAction
@@ -255,7 +265,7 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
             kwargs[name]["type"] = union_dict_and_str
         elif contains_type(type_hints, dict):
             kwargs[name]["type"] = parse_type(json.loads)
-            kwargs[name]["help"] += json_tip
+            kwargs[name]["help"] += f"\n\n{json_tip}"
         elif (contains_type(type_hints, str)
               or any(is_not_builtin(th) for th in type_hints)):
             kwargs[name]["type"] = str
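
The reworked `json_tip` asserts that the JSON-string and dotted-key spellings are equivalent. The actual expansion is done by vLLM's `FlexibleArgumentParser`; the sketch below hand-rolls a simplified version purely to illustrate the claimed equivalence (nested keys only, not the `+` list syntax):

import json

# The JSON-string spelling from the tip:
json_form = json.loads('{"key1": "value1", "key2": {"key3": "value2"}}')

# The dotted spelling: --json-arg.key1 value1 --json-arg.key2.key3 value2
dotted = {"json-arg.key1": "value1", "json-arg.key2.key3": "value2"}

expanded: dict = {}
for dotted_key, value in dotted.items():
    *parents, leaf = dotted_key.split(".")[1:]  # drop the argument name
    node = expanded
    for part in parents:
        node = node.setdefault(part, {})  # create nested dicts on the way down
    node[leaf] = value

assert expanded == json_form  # both spellings denote the same structure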
@@ -1545,7 +1555,6 @@ def _set_default_args_v0(self, model_config: ModelConfig) -> None:
         # Enable chunked prefill by default for long context (> 32K)
         # models to avoid OOM errors in initial memory profiling phase.
         elif use_long_context:
-            from vllm.platforms import current_platform
             is_gpu = current_platform.is_cuda()
             use_sliding_window = (model_config.get_sliding_window()
                                   is not None)
@@ -1653,6 +1662,7 @@ def _set_default_args_v1(self, usage_context: UsageContext,
         # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
         # throughput, see PR #17885 for more details.
         # So here we do an extra device name check to prevent such regression.
+        from vllm.usage.usage_lib import UsageContext
         if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
             # For GPUs like H100 and MI300x, use larger default values.
             default_max_num_batched_tokens = {

vllm/entrypoints/chat_utils.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,6 @@
 
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader import get_model_cls
 from vllm.model_executor.models import SupportsMultiModal
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
 from vllm.multimodal.utils import MediaConnector
@@ -524,6 +523,7 @@ def model_config(self) -> ModelConfig:
 
     @cached_property
     def model_cls(self):
+        from vllm.model_executor.model_loader import get_model_cls
         return get_model_cls(self.model_config)
 
     @property
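
Like the `arg_utils.py` hunks above, this change moves an import from module scope into the only method that needs it, so importing `chat_utils` no longer pays for the model loader up front; `@cached_property` then runs the deferred import once, at first access. A toy version of the pattern (class and stand-in dependency are illustrative):

from functools import cached_property

class ModelResolver:
    """Illustrative stand-in for the patched class."""

    @cached_property
    def heavy_helper(self):
        # Deferred import: executed on first attribute access only;
        # cached_property memoizes the result on the instance.
        from statistics import mean  # stand-in for a heavy dependency
        return mean

resolver = ModelResolver()
print(resolver.heavy_helper([1, 2, 3]))  # the import happens here, once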

vllm/inputs/registry.py

Lines changed: 14 additions & 7 deletions
@@ -13,14 +13,21 @@
 from vllm.jsontree import JSONTree, json_map_leaves
 from vllm.logger import init_logger
 from vllm.transformers_utils.processor import cached_processor_from_config
-from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import resolve_mm_processor_kwargs
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig
     from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
                                  MultiModalRegistry)
     from vllm.sequence import SequenceData
+    from vllm.transformers_utils.tokenizer import AnyTokenizer
+else:
+    ModelConfig = Any
+    MultiModalDataDict = Any
+    MultiModalPlaceholderDict = Any
+    MultiModalRegistry = Any
+    SequenceData = Any
+    AnyTokenizer = Any
 
 _T = TypeVar("_T")
 _C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
@@ -36,7 +43,7 @@ class InputContext:
     modify the inputs.
     """
 
-    model_config: "ModelConfig"
+    model_config: ModelConfig
     """The configuration of the model."""
 
     def get_hf_config(
@@ -200,9 +207,9 @@ class DummyData(NamedTuple):
     Note: This is only used in V0.
     """
 
-    seq_data: "SequenceData"
-    multi_modal_data: Optional["MultiModalDataDict"] = None
-    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None
+    seq_data: SequenceData
+    multi_modal_data: Optional[MultiModalDataDict] = None
+    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
 
 
 class InputRegistry:
@@ -212,9 +219,9 @@ class InputRegistry:
 
     def dummy_data_for_profiling(
         self,
-        model_config: "ModelConfig",
+        model_config: ModelConfig,
         seq_len: int,
-        mm_registry: "MultiModalRegistry",
+        mm_registry: MultiModalRegistry,
         is_encoder_data: bool = False,
     ) -> DummyData:
         """