Commit c16103f

jmswen authored and patrickvonplaten committed
Allow AsyncLLMEngine.generate to target a specific DP rank (vllm-project#19102)
Signed-off-by: Jon Swenson <jmswen@gmail.com>
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 8f4ffbd commit c16103f

10 files changed: +97 -5 lines changed

examples/online_serving/multi_instance_data_parallel.py

Lines changed: 58 additions & 0 deletions (new file)

```python
# SPDX-License-Identifier: Apache-2.0
import asyncio
from typing import Optional

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

"""
To run this example, run the following commands simultaneously with
different CUDA_VISIBLE_DEVICES:
python examples/online_serving/multi_instance_data_parallel.py

vllm serve ibm-research/PowerMoE-3b -dp 2 -dpr 1 \
    --data-parallel-address 127.0.0.1 --data-parallel-rpc-port 62300 \
    --data-parallel-size-local 1 --enforce-eager --headless

Once both instances have completed the handshake, this example will
send a request to the instance with DP rank 1.
"""


async def main():
    engine_args = AsyncEngineArgs(
        model="ibm-research/PowerMoE-3b",
        data_parallel_size=2,
        dtype="auto",
        max_model_len=2048,
        data_parallel_address="127.0.0.1",
        data_parallel_rpc_port=62300,
        data_parallel_size_local=1,
        enforce_eager=True,
    )

    engine_client = AsyncLLMEngine.from_engine_args(engine_args)

    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_tokens=100,
    )

    prompt = "Who won the 2004 World Series?"
    final_output: Optional[RequestOutput] = None
    async for output in engine_client.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id="abcdef",
            data_parallel_rank=1,
    ):
        final_output = output
    if final_output:
        print(final_output.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())
```
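A note on the two-process setup in this example: the vllm serve ... --headless command starts the second engine without its own API front-end, and -dp 2 -dpr 1 makes it rank 1 of a data-parallel group of size 2. The matching --data-parallel-address and --data-parallel-rpc-port values let it handshake with the AsyncLLMEngine process, which hosts rank 0 locally (data_parallel_size_local=1). Once the handshake completes, generate(..., data_parallel_rank=1) is routed across to the headless instance rather than served locally.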

tests/tokenization/test_detokenize.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -70,7 +70,8 @@ def _run_incremental_decode(tokenizer,
         None,
         0.0,
         None,
-        cache_salt=None)
+        cache_salt=None,
+        data_parallel_rank=None)

     if fast is None:
         detokenizer = IncrementalDetokenizer.from_new_request(
```

tests/v1/engine/test_engine_core.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -42,6 +42,7 @@ def make_request() -> EngineCoreRequest:
         arrival_time=time.time(),
         lora_request=None,
         cache_salt=None,
+        data_parallel_rank=None,
     )
```
tests/v1/engine/test_engine_core_client.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -56,6 +56,7 @@ def make_request(
         arrival_time=time.time(),
         lora_request=None,
         cache_salt=None,
+        data_parallel_rank=None,
     )
```

tests/v1/engine/test_output_processor.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -59,6 +59,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
         eos_token_id=None,
         lora_request=None,
         cache_salt=None,
+        data_parallel_rank=None,
         sampling_params=SamplingParams(
             skip_special_tokens=False,
             spaces_between_special_tokens=False,
@@ -406,6 +407,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
         eos_token_id=None,
         lora_request=None,
         cache_salt=None,
+        data_parallel_rank=None,
         sampling_params=SamplingParams(
             skip_special_tokens=False,
             spaces_between_special_tokens=False,
@@ -569,6 +571,7 @@ def test_stop_token(include_stop_str_in_output: bool,
         eos_token_id=eos_token_id,
         lora_request=None,
         cache_salt=None,
+        data_parallel_rank=None,
         sampling_params=SamplingParams(
             skip_special_tokens=False,
             spaces_between_special_tokens=False,
@@ -666,6 +669,7 @@ def test_stop_string(include_stop_str_in_output: bool,
         eos_token_id=None,
         lora_request=None,
         cache_salt=None,
+        data_parallel_rank=None,
         sampling_params=SamplingParams(
             skip_special_tokens=False,
             spaces_between_special_tokens=False,
@@ -780,6 +784,7 @@ def test_iteration_stats(dummy_test_vectors):
         eos_token_id=None,
         lora_request=None,
         cache_salt=None,
+        data_parallel_rank=None,
         sampling_params=SamplingParams(),
     ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
 ]
```

vllm/engine/async_llm_engine.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -442,6 +442,7 @@ async def add_request_async(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> None:
         ...

@@ -456,6 +457,7 @@ async def add_request_async(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> None:
         ...

@@ -473,6 +475,7 @@ async def add_request_async(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> None:
@@ -902,6 +905,7 @@ def add_request(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
             RequestOutput, PoolingRequestOutput], None]]:
         ...
@@ -917,6 +921,7 @@ def add_request(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
             RequestOutput, PoolingRequestOutput], None]]:
         ...
@@ -935,6 +940,7 @@ async def add_request(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
@@ -967,6 +973,7 @@ async def add_request(
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
             priority=priority,
+            data_parallel_rank=data_parallel_rank,
         )

         return stream.generator()
@@ -980,6 +987,7 @@ async def generate(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request.

@@ -999,7 +1007,8 @@ async def generate(
                 for generation, if any.
             priority: The priority of the request.
                 Only applicable with priority scheduling.
-
+            data_parallel_rank: The (global) data parallel rank that must
+                handle this request. Only applicable if DP is enabled.
         Yields:
             The output `RequestOutput` objects from the LLMEngine
             for the request.
@@ -1057,6 +1066,7 @@ async def generate(
                 trace_headers=trace_headers,
                 prompt_adapter_request=prompt_adapter_request,
                 priority=priority,
+                data_parallel_rank=data_parallel_rank,
         ):
             yield LLMEngine.validate_output(output, RequestOutput)
     except asyncio.CancelledError:
```
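The docstring addition above is the whole user-facing contract: pass data_parallel_rank to AsyncLLMEngine.generate and the request is pinned to that engine. A minimal sketch of a helper built on this API (the generate_on_rank name, the request-id scheme, and the max_tokens value are illustrative, not part of the commit):

```python
from typing import Optional

from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams


async def generate_on_rank(engine_client, prompt: str,
                           rank: int) -> Optional[RequestOutput]:
    """Stream one request pinned to a specific DP rank and return
    the final RequestOutput."""
    final: Optional[RequestOutput] = None
    async for output in engine_client.generate(
            prompt=prompt,
            sampling_params=SamplingParams(max_tokens=64),
            request_id=f"pinned-{rank}",  # illustrative ID scheme
            data_parallel_rank=rank,      # new parameter in this commit
    ):
        final = output
    return final
```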

vllm/v1/engine/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -55,6 +55,7 @@ class EngineCoreRequest(
     arrival_time: float
     lora_request: Optional[LoRARequest]
     cache_salt: Optional[str]
+    data_parallel_rank: Optional[int]

     # Index of the client, used to ensure outputs are sent back to the same
     # client for this request when scaling out the front-end.
```
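The new field is declared without a default, which is presumably why every test above that constructs an EngineCoreRequest directly had to add data_parallel_rank=None: callers building the struct by hand must now supply the field explicitly.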

vllm/v1/engine/async_llm.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -229,6 +229,7 @@ async def add_request(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> RequestOutputCollector:
         """Add new request to the AsyncLLM."""

@@ -245,7 +246,7 @@ async def add_request(
         prompt_str, request = self.processor.process_inputs(
             request_id, prompt, params, arrival_time, lora_request,
             tokenization_kwargs, trace_headers, prompt_adapter_request,
-            priority)
+            priority, data_parallel_rank)

         if params.n == 1:
             await self._add_request(request, prompt_str, None, 0, queue)
@@ -291,6 +292,7 @@ async def generate(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> AsyncGenerator[RequestOutput, None]:
         """
         Main function called by the API server to kick off a request
@@ -321,6 +323,7 @@ async def generate(
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
             priority=priority,
+            data_parallel_rank=data_parallel_rank,
         )

         # The output_handler task pushes items into the queue.
```
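In the v1 engine the parameter takes the same path end to end: AsyncLLM.generate forwards it to add_request, which passes it to Processor.process_inputs so the rank is stamped onto the EngineCoreRequest; the core client then reads that field to pick the target engine (both shown below).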

vllm/v1/engine/core_client.py

Lines changed: 12 additions & 2 deletions
```diff
@@ -982,7 +982,16 @@ async def run_engine_stats_update_task():
         resources.stats_update_task = asyncio.create_task(
             run_engine_stats_update_task())

-    def get_core_engine_for_request(self) -> CoreEngine:
+    def get_core_engine_for_request(self,
+                                    dp_rank: Optional[int] = None
+                                    ) -> CoreEngine:
+        if dp_rank is not None:
+            # engines are already in rank order
+            if dp_rank < 0 or dp_rank >= len(self.core_engines):
+                raise ValueError(f"Requested DP rank {dp_rank} is out of "
+                                 f"range [0, {len(self.core_engines)})")
+            return self.core_engines[dp_rank]
+
         if not self.lb_engines:
             return self.core_engines[0]
         # TODO use P2C alg for larger DP sizes
@@ -1018,7 +1027,8 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
         request.current_wave = self.current_wave
         request.client_index = self.client_index

-        chosen_engine = self.get_core_engine_for_request()
+        chosen_engine = self.get_core_engine_for_request(
+            request.data_parallel_rank)
         self.reqs_in_flight[request.request_id] = chosen_engine

         to_await = self._send_input(EngineCoreRequestType.ADD, request,
```
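The routing rule in get_core_engine_for_request is easy to state in isolation: an explicit rank indexes directly into the rank-ordered engine list, with a bounds check, and otherwise the client falls back to its load-balancing path. A self-contained sketch of just that selection logic (select_engine, the TypeVar, and the list arguments are stand-ins for the client's internal state, and the final fallback is simplified relative to the real load balancer):

```python
from typing import Optional, Sequence, TypeVar

E = TypeVar("E")  # stands in for CoreEngine


def select_engine(core_engines: Sequence[E],
                  lb_engines: Sequence[E],
                  dp_rank: Optional[int] = None) -> E:
    if dp_rank is not None:
        # Engines are kept in rank order, so the DP rank doubles as an
        # index; reject anything outside [0, len(core_engines)).
        if dp_rank < 0 or dp_rank >= len(core_engines):
            raise ValueError(f"Requested DP rank {dp_rank} is out of "
                             f"range [0, {len(core_engines)})")
        return core_engines[dp_rank]
    if not lb_engines:
        return core_engines[0]
    # Real client: load-balanced choice among lb_engines (P2C is a
    # TODO upstream); the simplest stand-in is the first engine.
    return lb_engines[0]
```

One consequence worth noting: a pinned request bypasses load balancing entirely, so callers that pin every request take on responsibility for spreading load across ranks themselves.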

vllm/v1/engine/processor.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -212,6 +212,7 @@ def process_inputs(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> tuple[Optional[str], EngineCoreRequest]:

         # TODO(woosuk): Support pooling models.
@@ -328,6 +329,7 @@ def process_inputs(
             arrival_time=arrival_time,
             lora_request=lora_request,
             cache_salt=decoder_inputs.get("cache_salt"),
+            data_parallel_rank=data_parallel_rank,
         )

     def _validate_model_inputs(self,
```
