
Commit 67d25ec

[Tests] Update online DP tests to verify that requests are balanced (#20157)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 363528d

File tree: 3 files changed (+170 −9 lines)


tests/v1/entrypoints/openai/test_completion.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def default_server_args():
                         ]])
 def server(default_server_args, request):
     if request.param:
-        default_server_args.extend(request.param)
+        default_server_args = default_server_args + request.param
     with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
         yield remote_server
 
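Side note (not part of the commit): the one-line change above matters because the `default_server_args` list is reused across the parametrized server runs, so `extend()` would accumulate flags from earlier parametrizations, while concatenation builds a fresh list each time. A minimal sketch with hypothetical flag values:

# Old behaviour: extend() mutates the shared list in place.
base_args = ["--max-model-len", "2048"]   # stands in for default_server_args
base_args.extend(["--enforce-eager"])
assert base_args == ["--max-model-len", "2048", "--enforce-eager"]  # leaks into later runs

# New behaviour: concatenation leaves the shared list untouched.
base_args = ["--max-model-len", "2048"]
combined = base_args + ["--enforce-eager"]
assert base_args == ["--max-model-len", "2048"]
assert combined == ["--max-model-len", "2048", "--enforce-eager"]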

tests/v1/entrypoints/openai/test_multi_api_servers.py

Lines changed: 126 additions & 0 deletions
@@ -2,10 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import asyncio
 import os
+import re
 
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests
 
 from tests.utils import RemoteOpenAIServer
 
@@ -14,6 +16,122 @@
 DP_SIZE = os.getenv("DP_SIZE", "1")
 
 
+def get_prometheus_metrics(
+        server: RemoteOpenAIServer) -> dict[str, dict[str, float]]:
+    """Fetch and parse Prometheus metrics from the /metrics endpoint.
+
+    Returns:
+        Dict mapping metric names to their values grouped by labels.
+        For example: {"vllm:request_success": {
+            "engine=0": 5.0, "engine=1": 3.0}
+        }
+    """
+    try:
+        response = requests.get(server.url_for("metrics"), timeout=10)
+        response.raise_for_status()
+
+        metrics: dict[str, dict[str, float]] = {}
+
+        # Regex patterns for Prometheus metrics
+        metric_with_labels = re.compile(
+            r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([\d\.\-\+e]+)$')
+        metric_simple = re.compile(
+            r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([\d\.\-\+e]+)$')
+
+        for line in response.text.split('\n'):
+            line = line.strip()
+            # Skip comments and empty lines
+            if not line or line.startswith('#'):
+                continue
+
+            # Try to match metric with labels first
+            match = metric_with_labels.match(line)
+            if match:
+                metric_name, labels_part, value_str = match.groups()
+                try:
+                    value = float(value_str)
+                    if metric_name not in metrics:
+                        metrics[metric_name] = {}
+                    metrics[metric_name][f'{{{labels_part}}}'] = value
+                except ValueError:
+                    continue
+            else:
+                # Try simple metric without labels
+                match = metric_simple.match(line)
+                if match:
+                    metric_name, value_str = match.groups()
+                    try:
+                        value = float(value_str)
+                        if metric_name not in metrics:
+                            metrics[metric_name] = {}
+                        metrics[metric_name][''] = value
+                    except ValueError:
+                        continue
+
+        return metrics
+    except Exception as e:
+        pytest.fail(f"Failed to fetch Prometheus metrics: {e}")
+        return {}
+
+
+def get_engine_request_counts(
+        metrics: dict[str, dict[str, float]]) -> dict[str, float]:
+    """Extract request counts per engine from Prometheus metrics.
+
+    Returns:
+        Dict mapping engine indices to request counts.
+        For example: {"0": 15.0, "1": 12.0}
+    """
+    engine_counts = {}
+
+    # Look for request success metrics with engine labels
+    success_metrics = metrics.get("vllm:request_success_total", {})
+    engine_pattern = re.compile(r'engine="([^"]*)"')
+
+    for labels, count in success_metrics.items():
+        # Extract engine ID from labels using regex
+        match = engine_pattern.search(labels)
+        if match:
+            engine_id = match.group(1)
+            if engine_id not in engine_counts:
+                engine_counts[engine_id] = 0.0
+            engine_counts[engine_id] += count
+
+    return engine_counts
+
+
+def check_request_balancing(server: RemoteOpenAIServer):
+    """Check request balancing via Prometheus metrics if DP_SIZE > 1.
+
+    Args:
+        server: The RemoteOpenAIServer instance
+    """
+    dp_size = int(DP_SIZE)
+    if dp_size <= 1:
+        return
+
+    # Get metrics after all requests are completed
+    metrics = get_prometheus_metrics(server)
+    engine_counts = get_engine_request_counts(metrics)
+
+    # Check that multiple engines received requests
+    engines_with_requests = [
+        engine for engine, count in engine_counts.items() if count > 0
+    ]
+    assert len(engines_with_requests) == dp_size, (
+        f"Expected requests to be distributed across multiple engines,"
+        f" but only engine(s) {engines_with_requests} received "
+        f"requests. Engine counts: {engine_counts}")
+
+    # Verify that the load is reasonably balanced
+    # (no engine should handle all requests)
+    total_requests = sum(engine_counts.values())
+
+    for count in engine_counts.values():
+        assert count > total_requests // (dp_size + 1), (
+            f"requests are imbalanced: {engine_counts}")
+
+
 @pytest.fixture(scope="module")
 def default_server_args():
     return [
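
For readers skimming the diff, a minimal sketch (not part of the commit; the sample /metrics payload and counts are made up, the regexes are the ones used above) of how these helpers turn Prometheus text output into per-engine counts:

import re

# Hypothetical /metrics excerpt with per-engine labels.
sample = '''
# HELP vllm:request_success_total Count of successfully processed requests.
vllm:request_success_total{engine="0",finished_reason="stop"} 7.0
vllm:request_success_total{engine="1",finished_reason="stop"} 5.0
'''

metric_with_labels = re.compile(
    r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([\d\.\-\+e]+)$')
engine_pattern = re.compile(r'engine="([^"]*)"')

engine_counts: dict[str, float] = {}
for line in sample.strip().split('\n'):
    line = line.strip()
    if not line or line.startswith('#'):
        continue
    match = metric_with_labels.match(line)
    if match and match.group(1) == "vllm:request_success_total":
        engine = engine_pattern.search(match.group(2))
        if engine:
            engine_counts[engine.group(1)] = (
                engine_counts.get(engine.group(1), 0.0) + float(match.group(3)))

print(engine_counts)  # {'0': 7.0, '1': 5.0}

check_request_balancing then requires every engine to exceed total // (dp_size + 1): with two engines and 12 successful requests in this made-up sample, each engine must report more than 4, which both do here.
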
@@ -50,6 +168,7 @@ async def client(server):
     [MODEL_NAME],
 )
 async def test_single_completion(client: openai.AsyncOpenAI,
+                                 server: RemoteOpenAIServer,
                                  model_name: str) -> None:
 
     async def make_request():
@@ -97,13 +216,17 @@ async def make_request():
     assert len(results) == num_requests
     assert all(completion is not None for completion in results)
 
+    # Check request balancing via Prometheus metrics if DP_SIZE > 1
+    check_request_balancing(server)
+
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
     [MODEL_NAME],
 )
 async def test_completion_streaming(client: openai.AsyncOpenAI,
+                                    server: RemoteOpenAIServer,
                                     model_name: str) -> None:
     prompt = "What is an LLM?"
 
@@ -170,3 +293,6 @@ async def make_streaming_request():
         results
     ) == num_requests, f"Expected {num_requests} results, got {len(results)}"
     assert all(results), "Not all streaming requests completed successfully."
+
+    # Check request balancing via Prometheus metrics if DP_SIZE > 1
+    check_request_balancing(server)

tests/v1/test_async_llm_dp.py

Lines changed: 43 additions & 8 deletions
@@ -4,24 +4,30 @@
 import asyncio
 import os
 from contextlib import ExitStack
+from dataclasses import dataclass
 from typing import Optional
 
 import pytest
 
 from vllm import SamplingParams
+from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.inputs import PromptType
 from vllm.platforms import current_platform
 from vllm.sampling_params import RequestOutputKind
 from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.engine.core_client import DPAsyncMPClient
+from vllm.v1.metrics.loggers import StatLoggerBase
+from vllm.v1.metrics.stats import IterationStats, SchedulerStats
+
+DP_SIZE = int(os.getenv("DP_SIZE", 2))
 
 engine_args = AsyncEngineArgs(
     model="ibm-research/PowerMoE-3b",
     enforce_eager=True,
     disable_log_requests=True,
     tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
-    data_parallel_size=int(os.getenv("DP_SIZE", 2)),
+    data_parallel_size=DP_SIZE,
 )
 
 if not current_platform.supports_v1(engine_args.create_model_config()):
@@ -74,12 +80,32 @@ async def generate(
 async def test_load(output_kind: RequestOutputKind,
                     data_parallel_backend: str):
 
+    stats_loggers = {}
+
+    @dataclass
+    class SimpleStatsLogger(StatLoggerBase):
+        init_count: int = 0
+        finished_req_count: int = 0
+
+        def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
+            stats_loggers[engine_index] = self
+
+        def record(self, scheduler_stats: Optional[SchedulerStats],
+                   iteration_stats: Optional[IterationStats]):
+            if iteration_stats:
+                self.finished_req_count += len(
+                    iteration_stats.finished_requests)
+
+        def log_engine_initialized(self):
+            self.init_count += 1
+
     with ExitStack() as after:
 
         prompt = "This is a test of data parallel"
 
         engine_args.data_parallel_backend = data_parallel_backend
-        engine = AsyncLLM.from_engine_args(engine_args)
+        engine = AsyncLLM.from_engine_args(engine_args,
+                                           stat_loggers=[SimpleStatsLogger])
         after.callback(engine.shutdown)
 
         NUM_REQUESTS = 100
@@ -92,12 +118,10 @@ async def test_load(output_kind: RequestOutputKind,
         for request_id in request_ids:
             tasks.append(
                 asyncio.create_task(
-                    generate(engine,
-                             request_id,
-                             prompt,
-                             output_kind,
-                             NUM_EXPECTED_TOKENS,
-                             data_parallel_rank=0)))
+                    generate(engine, request_id, prompt, output_kind,
+                             NUM_EXPECTED_TOKENS)))
+            # Short sleep to ensure that requests are distributed.
+            await asyncio.sleep(0.01)
         # Confirm that we got all the EXPECTED tokens from the requests.
         done, pending = await asyncio.wait(tasks,
                                            return_when=asyncio.FIRST_EXCEPTION)
@@ -122,3 +146,14 @@ async def test_load(output_kind: RequestOutputKind,
 
         assert not core_client.engines_running
         assert not core_client.reqs_in_flight
+
+        # Check that requests were distributed between the engines
+        print(f"Stats loggers after test: {stats_loggers}")
+        assert len(stats_loggers) == DP_SIZE
+        assert stats_loggers[0].init_count == 1
+
+        for sl in stats_loggers.values():
+            slogger: SimpleStatsLogger = sl
+
+            assert slogger.finished_req_count > NUM_REQUESTS // (
+                DP_SIZE + 1), f"requests are imbalanced: {stats_loggers}"
