Commit e545a3a

22quinn authored and py-andy-c committed
[Core] Add update_config RPC method (vllm-project#20095)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
1 parent fccfb44 commit e545a3a

File tree: 7 files changed (+97, -9 lines)


tests/test_config.py (29 additions, 1 deletion)

@@ -7,7 +7,7 @@

 from vllm.compilation.backends import VllmBackend
 from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
-                         get_field)
+                         get_field, update_config)
 from vllm.model_executor.layers.pooler import PoolingType
 from vllm.platforms import current_platform

@@ -46,6 +46,34 @@ def test_get_field():
     assert c.default_factory is MISSING


+@dataclass
+class _TestNestedConfig:
+    a: _TestConfigFields = field(
+        default_factory=lambda: _TestConfigFields(a=0))
+
+
+def test_update_config():
+    # Simple update
+    config1 = _TestConfigFields(a=0)
+    new_config1 = update_config(config1, {"a": 42})
+    assert new_config1.a == 42
+    # Nonexistent field
+    with pytest.raises(AssertionError):
+        new_config1 = update_config(config1, {"nonexistent": 1})
+    # Nested update with dataclass
+    config2 = _TestNestedConfig()
+    new_inner_config = _TestConfigFields(a=1, c="new_value")
+    new_config2 = update_config(config2, {"a": new_inner_config})
+    assert new_config2.a == new_inner_config
+    # Nested update with dict
+    config3 = _TestNestedConfig()
+    new_config3 = update_config(config3, {"a": {"c": "new_value"}})
+    assert new_config3.a.c == "new_value"
+    # Nested update with invalid type
+    with pytest.raises(AssertionError):
+        new_config3 = update_config(config3, {"a": "new_value"})
+
+
 @pytest.mark.parametrize(
     ("model_id", "expected_runner_type", "expected_task"),
     [
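Note that `_TestConfigFields` is defined elsewhere in tests/test_config.py and is not part of this diff. Judging from the calls the new test makes (a positional `a`, plus a `c` field overridden by name), a rough sketch of its shape would be the following; the exact field names and defaults here are assumptions, not part of the commit:

from dataclasses import dataclass, field

# Hypothetical reconstruction of the pre-existing test helper, inferred from
# the usage in test_update_config above; the real definition is not shown here.
@dataclass
class _TestConfigFields:
    a: int                 # overridden directly via {"a": 42}
    b: dict = field(default_factory=dict)
    c: str = "default"     # overridden via the nested {"c": "new_value"} dict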

tests/v1/worker/test_gpu_model_runner.py (14 additions, 2 deletions)

@@ -434,16 +434,28 @@ def rnd_stride_order():
     assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)


+def test_update_config(model_runner):
+    # Simple update
+    model_runner.update_config({"load_config": {"load_format": "dummy"}})
+    assert model_runner.load_config.load_format == "dummy"
+    # Raise error on non-existing config
+    with pytest.raises(AssertionError):
+        model_runner.update_config({"do_not_exist_config": "dummy"})
+
+
 def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
     # In this test, model_runner loads model + weights in one go, while
     # model_runner_2 loads dummy weights first then load real weights inplace
     model_runner.load_model()
     original_load_format = model_runner_2.load_config.load_format
-    model_runner_2.load_config.load_format = "dummy"
+    model_runner_2.update_config({"load_config": {"load_format": "dummy"}})
     model_runner_2.load_model()  # Initial model loading with dummy weights
     assert str(model_runner.get_model().state_dict()) != str(
         model_runner_2.get_model().state_dict())
-    model_runner_2.load_config.load_format = original_load_format
+    model_runner_2.update_config(
+        {"load_config": {
+            "load_format": original_load_format
+        }})
     model_runner_2.load_model()  # Load real weights inplace
     assert str(model_runner.get_model().state_dict()) == str(
         model_runner_2.get_model().state_dict())

vllm/config.py (20 additions, 1 deletion)

@@ -71,6 +71,7 @@
     ConfigType = type[DataclassInstance]
     HfOverrides = Union[dict, Callable[[type], type]]
 else:
+    DataclassInstance = Any
     PlacementGroup = Any
     PretrainedConfig = Any
     ExecutorBase = Any

@@ -87,7 +88,7 @@
                        "vllm.model_executor.models")

 logger = init_logger(__name__)
-
+DataclassInstanceT = TypeVar("DataclassInstanceT", bound=DataclassInstance)
 ConfigT = TypeVar("ConfigT", bound=ConfigType)

 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",

@@ -5049,3 +5050,21 @@ class SpeechToTextConfig:
     @property
     def allow_audio_chunking(self) -> bool:
         return self.min_energy_split_window_size is not None
+
+
+def update_config(config: DataclassInstanceT,
+                  overrides: dict[str, Any]) -> DataclassInstanceT:
+    processed_overrides = {}
+    for field_name, value in overrides.items():
+        assert hasattr(
+            config, field_name), f"{type(config)} has no field `{field_name}`"
+        current_value = getattr(config, field_name)
+        if is_dataclass(current_value) and not is_dataclass(value):
+            assert isinstance(value, dict), (
+                f"Overrides to {type(config)}.{field_name} must be a dict"
+                f" or {type(current_value)}, but got {type(value)}")
+            value = update_config(
+                current_value,  # type: ignore[type-var]
+                value)
+        processed_overrides[field_name] = value
+    return replace(config, **processed_overrides)
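A quick sketch of the new helper's semantics (illustration only, not part of the diff): unknown field names fail the `hasattr` assertion, a dict override targeting a dataclass-typed field recurses, and the result is built with `dataclasses.replace`, so the input config is never mutated. Assuming `LoadConfig`'s default `load_format` is "auto":

from vllm.config import LoadConfig, update_config

old = LoadConfig()
new = update_config(old, {"load_format": "dummy"})

assert new.load_format == "dummy"
assert old.load_format == "auto"  # replace() returns a copy; old is untouched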

vllm/v1/worker/gpu_model_runner.py (11 additions, 1 deletion)

@@ -19,7 +19,7 @@
 from vllm.attention.layer import Attention
 from vllm.compilation.counter import compilation_counter
 from vllm.config import (CompilationLevel, VllmConfig,
-                         get_layers_from_vllm_config)
+                         get_layers_from_vllm_config, update_config)
 from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)

@@ -1728,6 +1728,16 @@ def propose_ngram_draft_token_ids(
             draft_token_ids.append(drafter_output.tolist())
         return draft_token_ids

+    def update_config(self, overrides: dict[str, Any]) -> None:
+        allowed_config_names = {"load_config", "model_config"}
+        for config_name, config_overrides in overrides.items():
+            assert config_name in allowed_config_names, \
+                f"Config `{config_name}` not supported. " \
+                f"Allowed configs: {allowed_config_names}"
+            config = getattr(self, config_name)
+            new_config = update_config(config, config_overrides)
+            setattr(self, config_name, new_config)
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
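One consequence of this copy-and-rebind design is worth noting: `update_config` returns a fresh dataclass and the runner reassigns only its own attribute, so any other reference to the old config object keeps the old values. A minimal, runnable sketch of that distinction using stand-in dataclasses (hypothetical names, for illustration only):

from dataclasses import dataclass, field

from vllm.config import update_config

@dataclass
class _Load:
    load_format: str = "auto"

@dataclass
class _Holder:
    load_config: _Load = field(default_factory=_Load)

holder = _Holder()
alias = holder.load_config  # a second reference to the same config object
holder.load_config = update_config(holder.load_config,
                                    {"load_format": "dummy"})

assert holder.load_config.load_format == "dummy"
assert alias.load_format == "auto"  # the old object was never mutated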

vllm/v1/worker/gpu_worker.py (4 additions, 1 deletion)

@@ -4,7 +4,7 @@
 import copy
 import gc
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import torch
 import torch.distributed

@@ -193,6 +193,9 @@ def load_model(self) -> None:
         with context:
             self.model_runner.load_model()

+    def update_config(self, overrides: dict[str, Any]) -> None:
+        self.model_runner.update_config(overrides)
+
     @torch.inference_mode()
     def determine_available_memory(self) -> int:
         """Profiles the peak memory usage of the model to determine how much

vllm/v1/worker/tpu_model_runner.py (15 additions, 2 deletions)

@@ -3,7 +3,7 @@
 import bisect
 import gc
 import time
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
 from unittest.mock import patch

 import numpy as np

@@ -18,7 +18,8 @@
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import ParallelConfig, VllmConfig, get_layers_from_vllm_config
+from vllm.config import (ParallelConfig, VllmConfig,
+                         get_layers_from_vllm_config, update_config)
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA

@@ -1111,6 +1112,18 @@ def concat_lists(input_lists):

         return model_runner_output

+    def update_config(self, overrides: dict[str, Any]) -> None:
+        # TODO: TPU config may need extra validation
+        # https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754
+        allowed_config_names = {"load_config", "model_config"}
+        for config_name, config_overrides in overrides.items():
+            assert config_name in allowed_config_names, \
+                f"Config `{config_name}` not supported. " \
+                f"Allowed configs: {allowed_config_names}"
+            config = getattr(self, config_name)
+            new_config = update_config(config, config_overrides)
+            setattr(self, config_name, new_config)
+
     def load_model(self) -> None:
         self.device = self.device_config.device
vllm/v1/worker/tpu_worker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
"""A TPU worker class."""
44
import os
5-
from typing import Optional
5+
from typing import Any, Optional
66

77
import torch
88
import torch.distributed
@@ -260,6 +260,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
260260
def load_model(self) -> None:
261261
self.model_runner.load_model()
262262

263+
def update_config(self, overrides: dict[str, Any]) -> None:
264+
self.model_runner.update_config(overrides)
265+
263266
def compile_or_warm_up_model(self) -> None:
264267
if not self.model_config.enforce_eager:
265268
self.model_runner.capture_model()
