Skip to content

Commit a2142f0

Browse files
Support non-string values in JSON keys from CLI (#19471)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 871d6b7 commit a2142f0

File tree

3 files changed

+48
-24
lines changed

3 files changed

+48
-24
lines changed

tests/test_config.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,32 +13,32 @@
1313
from vllm.platforms import current_platform
1414

1515

16-
class TestConfig1:
16+
class _TestConfig1:
1717
pass
1818

1919

2020
@dataclass
21-
class TestConfig2:
21+
class _TestConfig2:
2222
a: int
2323
"""docstring"""
2424

2525

2626
@dataclass
27-
class TestConfig3:
27+
class _TestConfig3:
2828
a: int = 1
2929

3030

3131
@dataclass
32-
class TestConfig4:
32+
class _TestConfig4:
3333
a: Union[Literal[1], Literal[2]] = 1
3434
"""docstring"""
3535

3636

3737
@pytest.mark.parametrize(("test_config", "expected_error"), [
38-
(TestConfig1, "must be a dataclass"),
39-
(TestConfig2, "must have a default"),
40-
(TestConfig3, "must have a docstring"),
41-
(TestConfig4, "must use a single Literal"),
38+
(_TestConfig1, "must be a dataclass"),
39+
(_TestConfig2, "must have a default"),
40+
(_TestConfig3, "must have a docstring"),
41+
(_TestConfig4, "must use a single Literal"),
4242
])
4343
def test_config(test_config, expected_error):
4444
with pytest.raises(Exception, match=expected_error):
@@ -57,23 +57,23 @@ def test_compile_config_repr_succeeds():
5757
assert 'inductor_passes' in val
5858

5959

60-
def test_get_field():
60+
@dataclass
61+
class _TestConfigFields:
62+
a: int
63+
b: dict = field(default_factory=dict)
64+
c: str = "default"
6165

62-
@dataclass
63-
class TestConfig:
64-
a: int
65-
b: dict = field(default_factory=dict)
66-
c: str = "default"
6766

67+
def test_get_field():
6868
with pytest.raises(ValueError):
69-
get_field(TestConfig, "a")
69+
get_field(_TestConfigFields, "a")
7070

71-
b = get_field(TestConfig, "b")
71+
b = get_field(_TestConfigFields, "b")
7272
assert isinstance(b, Field)
7373
assert b.default is MISSING
7474
assert b.default_factory is dict
7575

76-
c = get_field(TestConfig, "c")
76+
c = get_field(_TestConfigFields, "c")
7777
assert isinstance(c, Field)
7878
assert c.default == "default"
7979
assert c.default_factory is MISSING

tests/test_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,15 @@ def test_dict_args(parser):
272272
"val5",
273273
"--hf_overrides.key-7.key_8",
274274
"val6",
275+
# Test data type detection
276+
"--hf_overrides.key9",
277+
"100",
278+
"--hf_overrides.key10",
279+
"100.0",
280+
"--hf_overrides.key11",
281+
"true",
282+
"--hf_overrides.key12.key13",
283+
"null",
275284
]
276285
parsed_args = parser.parse_args(args)
277286
assert parsed_args.model_name == "something.something"
@@ -286,6 +295,12 @@ def test_dict_args(parser):
286295
"key-7": {
287296
"key_8": "val6",
288297
},
298+
"key9": 100,
299+
"key10": 100.0,
300+
"key11": True,
301+
"key12": {
302+
"key13": None,
303+
},
289304
}
290305

291306

vllm/utils.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,7 +1466,7 @@ def repl(match: re.Match) -> str:
14661466
pattern = re.compile(r"(?<=--)[^\.]*")
14671467

14681468
# Convert underscores to dashes and vice versa in argument names
1469-
processed_args = []
1469+
processed_args = list[str]()
14701470
for arg in args:
14711471
if arg.startswith('--'):
14721472
if '=' in arg:
@@ -1483,7 +1483,7 @@ def repl(match: re.Match) -> str:
14831483
else:
14841484
processed_args.append(arg)
14851485

1486-
def create_nested_dict(keys: list[str], value: str):
1486+
def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
14871487
"""Creates a nested dictionary from a list of keys and a value.
14881488
14891489
For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
@@ -1494,27 +1494,36 @@ def create_nested_dict(keys: list[str], value: str):
14941494
nested_dict = {key: nested_dict}
14951495
return nested_dict
14961496

1497-
def recursive_dict_update(original: dict, update: dict):
1497+
def recursive_dict_update(
1498+
original: dict[str, Any],
1499+
update: dict[str, Any],
1500+
):
14981501
"""Recursively updates a dictionary with another dictionary."""
14991502
for k, v in update.items():
15001503
if isinstance(v, dict) and isinstance(original.get(k), dict):
15011504
recursive_dict_update(original[k], v)
15021505
else:
15031506
original[k] = v
15041507

1505-
delete = set()
1506-
dict_args: dict[str, dict] = defaultdict(dict)
1508+
delete = set[int]()
1509+
dict_args = defaultdict[str, dict[str, Any]](dict)
15071510
for i, processed_arg in enumerate(processed_args):
15081511
if processed_arg.startswith("--") and "." in processed_arg:
15091512
if "=" in processed_arg:
1510-
processed_arg, value = processed_arg.split("=", 1)
1513+
processed_arg, value_str = processed_arg.split("=", 1)
15111514
if "." not in processed_arg:
15121515
# False positive, . was only in the value
15131516
continue
15141517
else:
1515-
value = processed_args[i + 1]
1518+
value_str = processed_args[i + 1]
15161519
delete.add(i + 1)
1520+
15171521
key, *keys = processed_arg.split(".")
1522+
try:
1523+
value = json.loads(value_str)
1524+
except json.decoder.JSONDecodeError:
1525+
value = value_str
1526+
15181527
# Merge all values with the same key into a single dict
15191528
arg_dict = create_nested_dict(keys, value)
15201529
recursive_dict_update(dict_args[key], arg_dict)

0 commit comments

Comments
 (0)