Skip to content

Commit 6865fe0

Browse files
hmellorYikun
andauthored
Fix interaction between Optional and Annotated in CLI typing (#19093)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Yikun Jiang <yikun@apache.org>
1 parent e31446b commit 6865fe0

File tree

2 files changed

+34
-10
lines changed

2 files changed

+34
-10
lines changed

tests/engine/test_arg_utils.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
from argparse import ArgumentError, ArgumentTypeError
66
from contextlib import nullcontext
77
from dataclasses import dataclass, field
8-
from typing import Literal, Optional
8+
from typing import Annotated, Literal, Optional
99

1010
import pytest
1111

1212
from vllm.config import CompilationConfig, config
1313
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
14-
get_type, is_not_builtin, is_type,
15-
literal_to_kwargs, nullable_kvs,
14+
get_type, get_type_hints, is_not_builtin,
15+
is_type, literal_to_kwargs, nullable_kvs,
1616
optional_type, parse_type)
1717
from vllm.utils import FlexibleArgumentParser
1818

@@ -160,6 +160,18 @@ def test_is_not_builtin(type_hint, expected):
160160
assert is_not_builtin(type_hint) == expected
161161

162162

163+
@pytest.mark.parametrize(
164+
("type_hint", "expected"), [
165+
(Annotated[int, "annotation"], {int}),
166+
(Optional[int], {int, type(None)}),
167+
(Annotated[Optional[int], "annotation"], {int, type(None)}),
168+
(Optional[Annotated[int, "annotation"]], {int, type(None)}),
169+
],
170+
ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"])
171+
def test_get_type_hints(type_hint, expected):
172+
assert get_type_hints(type_hint) == expected
173+
174+
163175
def test_get_kwargs():
164176
kwargs = get_kwargs(DummyConfig)
165177
print(kwargs)

vllm/engine/arg_utils.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import regex as re
1717
import torch
18-
from pydantic import SkipValidation, TypeAdapter, ValidationError
18+
from pydantic import TypeAdapter, ValidationError
1919
from typing_extensions import TypeIs, deprecated
2020

2121
import vllm.envs as envs
@@ -151,17 +151,29 @@ def is_not_builtin(type_hint: TypeHint) -> bool:
151151
return type_hint.__module__ != "builtins"
152152

153153

154+
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
155+
"""Extract type hints from Annotated or Union type hints."""
156+
type_hints: set[TypeHint] = set()
157+
origin = get_origin(type_hint)
158+
args = get_args(type_hint)
159+
160+
if origin is Annotated:
161+
type_hints.update(get_type_hints(args[0]))
162+
elif origin is Union:
163+
for arg in args:
164+
type_hints.update(get_type_hints(arg))
165+
else:
166+
type_hints.add(type_hint)
167+
168+
return type_hints
169+
170+
154171
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
155172
cls_docs = get_attr_docs(cls)
156173
kwargs = {}
157174
for field in fields(cls):
158175
# Get the set of possible types for the field
159-
type_hints: set[TypeHint] = set()
160-
if get_origin(field.type) in {Union, Annotated}:
161-
predicate = lambda arg: not isinstance(arg, SkipValidation)
162-
type_hints.update(filter(predicate, get_args(field.type)))
163-
else:
164-
type_hints.add(field.type)
176+
type_hints: set[TypeHint] = get_type_hints(field.type)
165177

166178
# If the field is a dataclass, we can use the model_validate_json
167179
generator = (th for th in type_hints if is_dataclass(th))

0 commit comments

Comments
 (0)