
Commit 279815b

[Serve.llm] Add router replicas and batch size to llm config (#52655)
Signed-off-by: Gene Su <e870252314@gmail.com>
1 parent f23510c

6 files changed, +159 -6 lines changed

python/ray/llm/_internal/serve/configs/server_models.py

Lines changed: 14 additions & 0 deletions
@@ -44,6 +44,7 @@
     DEFAULT_MULTIPLEX_DOWNLOAD_TRIES,
     MAX_NUM_STOPPING_SEQUENCES,
     ENABLE_WORKER_PROCESS_SETUP_HOOK,
+    MODEL_RESPONSE_BATCH_TIMEOUT_MS,
 )
 from ray.llm._internal.serve.configs.prompt_formats import (
     Prompt,
@@ -223,6 +224,19 @@ class LLMConfig(BaseModelExtended):
         """,
     )
 
+    experimental_configs: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Experimental configurations for Ray Serve LLM. This is a "
+        "dictionary of key-value pairs. Currently supported keys are:\n"
+        "- `stream_batching_interval_ms`: Ray Serve LLM batches streaming "
+        "requests together. This config decides how long to wait for the "
+        "batch before processing the requests. Defaults to "
+        f"{MODEL_RESPONSE_BATCH_TIMEOUT_MS}.\n"
+        "- `num_router_replicas`: The number of replicas for the router. Ray "
+        "Serve will take the max value across all models. Defaults to 2 "
+        "router replicas per model replica.\n",
+    )
+
     _supports_vision: bool = PrivateAttr(False)
     _model_architecture: str = PrivateAttr("")
     _prompt_format: HuggingFacePromptFormat = PrivateAttr(
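
For reference, here is a minimal sketch of a config that populates the new field. The values are illustrative; the imports and the placeholder model id mirror the test files in this commit.

# Illustrative only: shows the new `experimental_configs` field in use.
from ray.llm._internal.serve.configs.server_models import (
    LLMConfig,
    ModelLoadingConfig,
)

llm_config = LLMConfig(
    model_loading_config=ModelLoadingConfig(
        model_id="llm_model_id",  # placeholder model id, as in the tests
    ),
    experimental_configs={
        # Wait up to 50 ms (example value) before flushing a batch of
        # streaming responses.
        "stream_batching_interval_ms": 50,
        # Request exactly 4 router replicas (example value); the router takes
        # the max of this key across all served models.
        "num_router_replicas": 4,
    },
)

assert llm_config.experimental_configs["num_router_replicas"] == 4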

python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py

Lines changed: 13 additions & 3 deletions
@@ -519,15 +519,25 @@ async def prepare_request(
         vllm_request = VLLMGenerationRequest(**request_params)
         return vllm_request
 
+    def _get_batch_interval_ms(self, stream: bool = True) -> int:
+        """Calculate the batching interval for responses."""
+        stream_batching_interval_ms = self.llm_config.experimental_configs.get(
+            "stream_batching_interval_ms"
+        )
+        if stream_batching_interval_ms is None:
+            stream_batching_interval_ms = MODEL_RESPONSE_BATCH_TIMEOUT_MS
+        return stream_batching_interval_ms if stream else None
+
     async def generate(
         self,
         request: GenerationRequest,
     ) -> AsyncGenerator[LLMRawResponse, None]:
-        batch_interval_ms = MODEL_RESPONSE_BATCH_TIMEOUT_MS if request.stream else None
-
+        # TODO (genesu): Response batching logic should be common to all
+        # engines and belong at the LLMServer level instead of the engine
+        # level here. Refactor the batching logic up.
         response_stream = LLMRawResponsesBatcher(
             self._generate(request),
-            interval_ms=batch_interval_ms,
+            interval_ms=self._get_batch_interval_ms(request.stream),
         )
         async for response in response_stream.stream():
             yield response
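
Read in isolation, `_get_batch_interval_ms` applies one rule: an explicitly configured `stream_batching_interval_ms` (including 0) wins, a missing key falls back to the constant default, and non-streaming requests get no interval at all. A self-contained sketch of that rule follows; the 100 ms default is an assumed placeholder for `MODEL_RESPONSE_BATCH_TIMEOUT_MS`, not the real constant.

from typing import Optional

# Assumed placeholder for MODEL_RESPONSE_BATCH_TIMEOUT_MS.
ASSUMED_DEFAULT_BATCH_TIMEOUT_MS = 100


def resolve_batch_interval_ms(
    experimental_configs: dict, stream: bool
) -> Optional[int]:
    # An explicit value (even 0) takes precedence; only a missing key
    # falls back to the default.
    interval = experimental_configs.get("stream_batching_interval_ms")
    if interval is None:
        interval = ASSUMED_DEFAULT_BATCH_TIMEOUT_MS
    # Non-streaming requests are not batched on an interval, signalled by None.
    return interval if stream else None


assert resolve_batch_interval_ms({}, stream=True) == 100
assert resolve_batch_interval_ms({"stream_batching_interval_ms": 0}, stream=True) == 0
assert resolve_batch_interval_ms({}, stream=False) is None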

python/ray/llm/_internal/serve/deployments/routers/router.py

Lines changed: 13 additions & 3 deletions
@@ -422,6 +422,7 @@ def as_deployment(
         min_replicas = RAYLLM_ROUTER_MIN_REPLICAS
         initial_replicas = RAYLLM_ROUTER_INITIAL_REPLICAS
         max_replicas = RAYLLM_ROUTER_MAX_REPLICAS
+        num_router_replicas = 0
 
         # Note (genesu): Based on our internal benchmark, we are currently bottleneck
         # by the router replicas during high concurrency situation. We are setting the
@@ -431,6 +432,11 @@ def as_deployment(
         model_initial_replicas = 0
         model_max_replicas = 0
         for llm_config in llm_configs:
+            num_router_replicas = max(
+                num_router_replicas,
+                llm_config.experimental_configs.get("num_router_replicas", 0),
+            )
+
             if "autoscaling_config" in llm_config.deployment_config:
                 autoscaling_config = llm_config.deployment_config[
                     "autoscaling_config"
@@ -448,11 +454,15 @@ def as_deployment(
                     or autoscaling_config.min_replicas
                 )
                 model_max_replicas += autoscaling_config.max_replicas
-        min_replicas = int(model_min_replicas * ROUTER_TO_MODEL_REPLICA_RATIO)
-        initial_replicas = int(
+        min_replicas = num_router_replicas or int(
+            model_min_replicas * ROUTER_TO_MODEL_REPLICA_RATIO
+        )
+        initial_replicas = num_router_replicas or int(
             model_initial_replicas * ROUTER_TO_MODEL_REPLICA_RATIO
         )
-        max_replicas = int(model_max_replicas * ROUTER_TO_MODEL_REPLICA_RATIO)
+        max_replicas = num_router_replicas or int(
+            model_max_replicas * ROUTER_TO_MODEL_REPLICA_RATIO
+        )
 
         ingress_cls = serve.ingress(fastapi_router_app)(cls)
         deployment_decorator = serve.deployment(
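
The net effect of the router change: the largest `num_router_replicas` requested across all model configs, when non-zero, overrides the ratio-derived min/initial/max replica counts; otherwise the counts remain the summed model replicas scaled by `ROUTER_TO_MODEL_REPLICA_RATIO`. Below is a simplified, self-contained sketch of that selection; the ratio of 2 is an assumed placeholder, and min/initial/max are collapsed into a single computation for brevity.

# Assumed placeholder for ROUTER_TO_MODEL_REPLICA_RATIO.
ASSUMED_ROUTER_TO_MODEL_REPLICA_RATIO = 2


def pick_router_replicas(experimental_configs_per_model, summed_model_replicas):
    # Take the max of any explicitly requested router replica count.
    requested = max(
        (cfg.get("num_router_replicas", 0) for cfg in experimental_configs_per_model),
        default=0,
    )
    # A non-zero explicit value wins; otherwise fall back to the ratio.
    return requested or int(
        summed_model_replicas * ASSUMED_ROUTER_TO_MODEL_REPLICA_RATIO
    )


# Mirrors the new router test below: requests of 3 and 5 yield 5 replicas.
assert pick_router_replicas([{"num_router_replicas": 3}, {"num_router_replicas": 5}], 1) == 5
# No explicit value: ratio-based count (1 model replica * 2).
assert pick_router_replicas([{}], 1) == 2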

python/ray/llm/tests/serve/cpu/configs/test_models.py

Lines changed: 26 additions & 0 deletions
@@ -248,6 +248,32 @@ def test_engine_config_cached(self):
         new_engine_config = llm_config.get_engine_config()
         assert new_engine_config is old_engine_config
 
+    def test_experimental_configs(self):
+        """Test that `experimental_configs` can be used."""
+        # Test that a valid dictionary can be used.
+        experimental_configs = {
+            "experimental_feature1": "value1",
+            "experimental_feature2": "value2",
+        }
+        llm_config = LLMConfig(
+            model_loading_config=ModelLoadingConfig(
+                model_id="llm_model_id",
+            ),
+            experimental_configs=experimental_configs,
+        )
+        assert llm_config.experimental_configs == experimental_configs
+
+        # Test that an invalid dictionary raises a validation error.
+        with pytest.raises(
+            pydantic.ValidationError,
+        ):
+            LLMConfig(
+                model_loading_config=ModelLoadingConfig(
+                    model_id="llm_model_id",
+                ),
+                experimental_configs={123: "value1"},
+            )
+
 
 if __name__ == "__main__":
     sys.exit(pytest.main(["-v", __file__]))

python/ray/llm/tests/serve/cpu/deployments/llm/vllm/test_vllm_engine.py

Lines changed: 37 additions & 0 deletions
@@ -18,6 +18,7 @@
 from ray.llm._internal.serve.configs.server_models import (
     LLMConfig,
     LLMRawResponse,
+    ModelLoadingConfig,
 )
 from ray.llm._internal.serve.configs.constants import MODEL_RESPONSE_BATCH_TIMEOUT_MS
 
@@ -195,6 +196,42 @@ def test_parse_sampling_params_json_mode(
         assert guided_json == sampling_params.response_format.json_schema
         assert getattr(parsed_params, "response_format", None) is None
 
+    def test_get_batch_interval_ms(self):
+        """Test that the batch interval is set correctly in the config."""
+
+        # Test with no stream_batching_interval_ms.
+        llm_config = LLMConfig(
+            model_loading_config=ModelLoadingConfig(
+                model_id="llm_model_id",
+            ),
+        )
+        vllm_engine = VLLMEngine(llm_config)
+        assert vllm_engine._get_batch_interval_ms() == MODEL_RESPONSE_BATCH_TIMEOUT_MS
+
+        # Test with a non-zero stream_batching_interval_ms.
+        llm_config = LLMConfig(
+            model_loading_config=ModelLoadingConfig(
+                model_id="llm_model_id",
+            ),
+            experimental_configs={
+                "stream_batching_interval_ms": 13,
+            },
+        )
+        vllm_engine = VLLMEngine(llm_config)
+        assert vllm_engine._get_batch_interval_ms() == 13
+
+        # Test with zero stream_batching_interval_ms.
+        llm_config = LLMConfig(
+            model_loading_config=ModelLoadingConfig(
+                model_id="llm_model_id",
+            ),
+            experimental_configs={
+                "stream_batching_interval_ms": 0,
+            },
+        )
+        vllm_engine = VLLMEngine(llm_config)
+        assert vllm_engine._get_batch_interval_ms() == 0
+
 
 TEXT_VALUE = "foo"
 FINAL_TEXT_VALUE = "bar"
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+import pytest
+import sys
+
+from ray.llm._internal.serve.configs.server_models import (
+    LLMConfig,
+    ModelLoadingConfig,
+)
+from ray.llm._internal.serve.deployments.routers.router import (
+    LLMRouter,
+)
+
+
+def test_router_with_num_router_replicas_config():
+    """Test the router with num_router_replicas config."""
+    # Test with no num_router_replicas config.
+    llm_configs = [
+        LLMConfig(
+            model_loading_config=ModelLoadingConfig(
+                model_id="llm_model_id",
+            ),
+        )
+    ]
+    llm_router_deployment = LLMRouter.as_deployment(llm_configs=llm_configs)
+    autoscaling_config = llm_router_deployment._deployment_config.autoscaling_config
+    assert autoscaling_config.min_replicas == 2
+    assert autoscaling_config.initial_replicas == 2
+    assert autoscaling_config.max_replicas == 2
+
+    # Test with num_router_replicas config on multiple llm configs.
+    llm_configs = [
+        LLMConfig(
+            model_loading_config=ModelLoadingConfig(
+                model_id="llm_model_id",
+            ),
+            experimental_configs={
+                "num_router_replicas": 3,
+            },
+        ),
+        LLMConfig(
+            model_loading_config=ModelLoadingConfig(
+                model_id="llm_model_id",
+            ),
+            experimental_configs={
+                "num_router_replicas": 5,
+            },
+        ),
+    ]
+    llm_router_deployment = LLMRouter.as_deployment(llm_configs=llm_configs)
+    autoscaling_config = llm_router_deployment._deployment_config.autoscaling_config
+    assert autoscaling_config.min_replicas == 5
+    assert autoscaling_config.initial_replicas == 5
+    assert autoscaling_config.max_replicas == 5
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(["-v", __file__]))
