
Commit 657f2f3

[DP] Support external DP Load Balancer mode (#19790)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent a1aafc8 commit 657f2f3
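
This commit adds an "external DP load balancer" deployment mode: each data-parallel rank runs as its own front-end server on its own port, with the rank pinned via --data-parallel-rank and --data-parallel-size-local set to 1, and request distribution across ranks is left to an external load balancer rather than to a single vLLM front end (the existing "internal DP" path exercised elsewhere in the pipeline). Below is a minimal sketch of the per-rank arguments, mirroring what the new test passes to RemoteOpenAIServer; rank_server_args is a hypothetical helper for illustration, and the serve entrypoint itself is not shown in this diff.

def rank_server_args(rank: int,
                     dp_size: int = 2,
                     tp_size: int = 1,
                     api_server_count: int = 1) -> list[str]:
    # Flag names are taken from tests/v1/test_external_lb_dp.py below.
    return [
        "--data-parallel-size", str(dp_size),          # total DP ranks across all servers
        "--data-parallel-rank", str(rank),             # this server's fixed rank
        "--data-parallel-size-local", "1",             # one local rank per server process
        "--tensor-parallel-size", str(tp_size),
        "--port", str(8000 + rank),                    # each rank listens on its own port
        "--api-server-count", str(api_server_count),   # optional API-server scale-out per rank
    ]

# An external load balancer would then spread traffic across
# http://<host>:8000, http://<host>:8001, ...; the test below instead opens
# one OpenAI client per rank and queries each endpoint directly.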

File tree: 11 files changed (+1250 −783 lines)

.buildkite/test-pipeline.yaml

Lines changed: 5 additions & 1 deletion
@@ -155,6 +155,7 @@ steps:
   - examples/offline_inference/rlhf_colocate.py
   - tests/examples/offline_inference/data_parallel.py
   - tests/v1/test_async_llm_dp.py
+  - tests/v1/test_external_lb_dp.py
   - tests/v1/engine/test_engine_core_client.py
   commands:
   # test with tp=2 and external_dp=2
@@ -163,8 +164,9 @@
   # test with tp=2 and pp=2
   - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
   # test with internal dp
-  - python3 ../examples/offline_inference/data_parallel.py
+  - python3 ../examples/offline_inference/data_parallel.py --enforce-eager
   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
+  - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
   - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
   - pytest -v -s distributed/test_utils.py
   - pytest -v -s compile/test_basic_correctness.py
@@ -682,10 +684,12 @@ steps:
   - vllm/worker/model_runner.py
   - entrypoints/llm/test_collective_rpc.py
   - tests/v1/test_async_llm_dp.py
+  - tests/v1/test_external_lb_dp.py
   - tests/v1/entrypoints/openai/test_multi_api_servers.py
   - vllm/v1/engine/
   commands:
   - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
+  - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
   - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
   - pytest -v -s entrypoints/llm/test_collective_rpc.py
   - pytest -v -s ./compile/test_basic_correctness.py

tests/v1/engine/test_engine_core_client.py

Lines changed: 3 additions & 3 deletions
@@ -26,8 +26,8 @@
 from vllm.v1.engine.core import EngineCore
 from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
                                         SyncMPClient)
+from vllm.v1.engine.utils import CoreEngineProcManager
 from vllm.v1.executor.abstract import Executor
-from vllm.v1.utils import CoreEngineProcManager
 
 from ...distributed.conftest import MockSubscriber
 from ...utils import create_new_process_for_each_test
@@ -563,7 +563,7 @@ def create_mock_executor(vllm_config):
         m.setenv("VLLM_USE_V1", "1")
         m.setenv("CUDA_VISIBLE_DEVICES", "")  # No CUDA devices
 
-        from vllm.v1.utils import EngineZmqAddresses
+        from vllm.v1.engine.utils import EngineZmqAddresses
 
         def mock_startup_handshake(self, handshake_socket, on_head_node,
                                    parallel_config):
@@ -580,7 +580,7 @@ def mock_startup_handshake(self, handshake_socket, on_head_node,
                                  trust_remote_code=True).create_engine_config()
         engine_core_proc = EngineCoreProc(
             vllm_config=vllm_config,
-            on_head_node=True,
+            local_client=True,
             handshake_address="tcp://127.0.0.1:12345",
             executor_class=mock_executor_class,
             log_stats=False,
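
The three edits above track relocations and a rename this commit depends on: CoreEngineProcManager and EngineZmqAddresses now live in vllm.v1.engine.utils rather than vllm.v1.utils, and EngineCoreProc takes local_client in place of the former on_head_node keyword. A small sketch of the updated call, using only the arguments visible in this hunk; vllm_config and mock_executor_class stand in for the objects the test builds, and the EngineCoreProc import path is assumed rather than shown in the diff.

from vllm.v1.engine.core import EngineCoreProc            # assumed import path
from vllm.v1.engine.utils import (CoreEngineProcManager,  # new module location
                                  EngineZmqAddresses)

def start_engine_core_proc(vllm_config, mock_executor_class):
    # `local_client=True` replaces the former `on_head_node=True`.
    return EngineCoreProc(
        vllm_config=vllm_config,
        local_client=True,
        handshake_address="tcp://127.0.0.1:12345",
        executor_class=mock_executor_class,
        log_stats=False,
    )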

tests/v1/test_external_lb_dp.py

Lines changed: 312 additions & 0 deletions
@@ -0,0 +1,312 @@ (new file)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
import threading
import time
from contextlib import AsyncExitStack

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio

from tests.utils import RemoteOpenAIServer
from vllm.platforms import Platform

MODEL_NAME = "ibm-research/PowerMoE-3b"

# Number of data parallel ranks for external LB testing
DP_SIZE = int(os.getenv("DP_SIZE", "2"))
# Default tensor parallel size to use
TP_SIZE = int(os.getenv("TP_SIZE", "1"))


class ExternalLBServerManager:
    """Manages data parallel vLLM server instances for external
    load balancer testing."""

    def __init__(self,
                 model_name: str,
                 dp_size: int,
                 api_server_count: int,
                 base_server_args: list,
                 tp_size: int = TP_SIZE):
        self.model_name = model_name
        self.dp_size = dp_size
        self.tp_size = tp_size
        self.api_server_count = api_server_count
        self.base_server_args = base_server_args
        self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
        self.server_threads: list[threading.Thread] = []

    def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
        """Start all server instances for external LB mode."""
        for rank in range(self.dp_size):
            # Create server args for this specific rank
            server_args = self.base_server_args.copy()

            # Add external LB specific arguments
            server_args.extend([
                "--data-parallel-size",
                str(self.dp_size),
                "--data-parallel-rank",
                str(rank),
                "--data-parallel-size-local",
                "1",
                "--tensor-parallel-size",
                str(self.tp_size),
                "--port",
                str(8000 + rank),  # Different port for each rank
                "--api-server-count",
                str(self.api_server_count),
            ])

            # Use a thread to start each server to allow parallel initialization
            def start_server(r: int, sargs: list[str]):
                try:
                    # Start the server
                    server = RemoteOpenAIServer(
                        self.model_name,
                        sargs,
                        auto_port=False,
                        env_dict={
                            "CUDA_VISIBLE_DEVICES":
                            ",".join(
                                str(Platform.device_id_to_physical_device_id(i))
                                for i in range(r * TP_SIZE, (r + 1) * TP_SIZE))
                        })
                    server.__enter__()
                    print(f"Server rank {r} started successfully with "
                          f"{self.api_server_count} API servers")
                    self.servers.append((server, sargs))
                except Exception as e:
                    print(f"Failed to start server rank {r}: {e}")
                    raise

            thread = threading.Thread(target=start_server,
                                      args=(rank, server_args))
            thread.start()

            self.server_threads.append(thread)

        # Wait for all servers to start
        for thread in self.server_threads:
            thread.join()

        # Give servers additional time to fully initialize and coordinate
        time.sleep(2)

        if len(self.servers) != self.dp_size:
            raise Exception("Servers failed to start")

        return self.servers

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop all server instances."""
        while self.servers:
            try:
                self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
            except Exception as e:
                print(f"Error stopping server: {e}")


@pytest.fixture(scope="module")
def default_server_args():
    return [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
    ]


@pytest.fixture(scope="module", params=[1, 4])
def servers(request, default_server_args):
    api_server_count = request.param
    with ExternalLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
                                 default_server_args) as server_list:
        yield server_list


@pytest_asyncio.fixture
async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
    # Create a client for each server
    async with AsyncExitStack() as stack:
        yield [
            await stack.enter_async_context(server.get_async_client())
            for server, _ in servers
        ]


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_external_lb_single_completion(
        clients: list[openai.AsyncOpenAI],
        servers: list[tuple[RemoteOpenAIServer, list[str]]],
        model_name: str) -> None:

    async def make_request(client: openai.AsyncOpenAI):
        completion = await client.completions.create(
            model=model_name,
            prompt="Hello, my name is",
            max_tokens=10,
            temperature=1.0)

        assert completion.id is not None
        assert completion.choices is not None and len(completion.choices) == 1

        choice = completion.choices[0]
        # The exact number of tokens can vary slightly with temperature=1.0,
        # so we check for a reasonable minimum length.
        assert len(choice.text) >= 1
        # Finish reason might not always be 'length' if the model finishes early
        # or due to other reasons, especially with high temperature.
        # So, we'll accept 'length' or 'stop'.
        assert choice.finish_reason in ("length", "stop")

        # Token counts can also vary, so we check they are positive.
        assert completion.usage.completion_tokens > 0
        assert completion.usage.prompt_tokens > 0
        assert completion.usage.total_tokens > 0
        return completion

    # Test single request to each server
    for i, client in enumerate(clients):
        result = await make_request(client)
        assert result is not None
        print(f"Server {i} handled single completion request successfully")

    await asyncio.sleep(0.5)

    # Send requests to all servers in round-robin fashion
    num_requests_per_server = 25  # Total 50 requests across 2 servers
    all_tasks = []

    for i, client in enumerate(clients):
        tasks = [make_request(client) for _ in range(num_requests_per_server)]
        all_tasks.extend(tasks)

    results = await asyncio.gather(*all_tasks)
    assert len(results) == num_requests_per_server * len(clients)
    assert all(completion is not None for completion in results)

    await asyncio.sleep(0.5)

    # Second burst of requests
    all_tasks = []
    for i, client in enumerate(clients):
        tasks = [make_request(client) for _ in range(num_requests_per_server)]
        all_tasks.extend(tasks)

    results = await asyncio.gather(*all_tasks)
    assert len(results) == num_requests_per_server * len(clients)
    assert all(completion is not None for completion in results)

    _, server_args = servers[0]
    api_server_count = (
        server_args.count('--api-server-count')
        and server_args[server_args.index('--api-server-count') + 1] or 1)
    print(
        f"Successfully completed external LB test with {len(clients)} servers "
        f"(API server count: {api_server_count})")


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_external_lb_completion_streaming(
        clients: list[openai.AsyncOpenAI],
        servers: list[tuple[RemoteOpenAIServer, list[str]]],
        model_name: str) -> None:
    prompt = "What is an LLM?"

    async def make_streaming_request(client: openai.AsyncOpenAI):
        # Perform a non-streaming request to get the expected full output
        single_completion = await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
        )
        single_output = single_completion.choices[0].text

        # Perform the streaming request
        stream = await client.completions.create(model=model_name,
                                                 prompt=prompt,
                                                 max_tokens=5,
                                                 temperature=0.0,
                                                 stream=True)
        chunks: list[str] = []
        finish_reason_count = 0
        last_chunk = None
        async for chunk in stream:
            chunks.append(chunk.choices[0].text)
            if chunk.choices[0].finish_reason is not None:
                finish_reason_count += 1
            last_chunk = chunk  # Keep track of the last chunk

        # finish reason should only return in the last block for OpenAI API
        assert finish_reason_count == 1, (
            "Finish reason should appear exactly once.")
        assert last_chunk is not None, (
            "Stream should have yielded at least one chunk.")
        assert last_chunk.choices[0].finish_reason == "length", (
            "Finish reason should be 'length'.")
        # Check that the combined text matches the non-streamed version.
        assert "".join(chunks) == single_output, (
            "Streamed output should match non-streamed output.")
        return True  # Indicate success for this request

    # Test single request to each server
    for i, client in enumerate(clients):
        result = await make_streaming_request(client)
        assert result is not None
        print(f"Server {i} handled single streaming request successfully")

    await asyncio.sleep(0.5)

    # Send streaming requests to all servers in round-robin fashion
    num_requests_per_server = 25  # Total 50 requests across 2 servers
    all_tasks = []

    for i, client in enumerate(clients):
        tasks = [
            make_streaming_request(client)
            for _ in range(num_requests_per_server)
        ]
        all_tasks.extend(tasks)

    results = await asyncio.gather(*all_tasks)
    assert len(results) == num_requests_per_server * len(clients)
    assert all(results), "Not all streaming requests completed successfully."

    await asyncio.sleep(0.5)

    # Second burst of streaming requests
    all_tasks = []
    for i, client in enumerate(clients):
        tasks = [
            make_streaming_request(client)
            for _ in range(num_requests_per_server)
        ]
        all_tasks.extend(tasks)

    results = await asyncio.gather(*all_tasks)
    assert len(results) == num_requests_per_server * len(clients)
    assert all(results), "Not all streaming requests completed successfully."

    _, server_args = servers[0]
    api_server_count = (
        server_args.count('--api-server-count')
        and server_args[server_args.index('--api-server-count') + 1] or 1)
    print(f"Successfully completed external LB streaming test with "
          f"{len(clients)} servers (API server count: {api_server_count})")
