Skip to content

Commit 7353492

Browse files
authored
[Core] Raise when non-multi-instance DP clients target a DP rank (#19227)
Signed-off-by: Jon Swenson <jmswen@gmail.com>
1 parent 7661e92 commit 7353492

File tree

6 files changed

+77
-12
lines changed

6 files changed

+77
-12
lines changed

tests/async_engine/test_async_llm_engine.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -384,3 +384,25 @@ async def test_delayed_generator(async_engine, stop):
384384
assert final_output is not None
385385
assert len(final_output.outputs[0].token_ids) == 10
386386
assert final_output.finished
387+
388+
389+
@pytest.mark.asyncio(scope="module")
390+
async def test_invalid_argument(async_engine):
391+
scheduler_config = await async_engine.get_scheduler_config()
392+
393+
if scheduler_config.num_scheduler_steps != 1:
394+
pytest.skip("no need to test this one with multistep")
395+
396+
sampling_params = SamplingParams(
397+
temperature=0,
398+
min_tokens=10,
399+
max_tokens=10,
400+
)
401+
402+
# Targeting specific DP rank only supported in v1 multi-instance DP
403+
with pytest.raises(ValueError):
404+
async for _ in async_engine.generate("test",
405+
sampling_params,
406+
request_id=uid(),
407+
data_parallel_rank=0):
408+
pass

tests/v1/engine/test_async_llm.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -250,3 +250,32 @@ async def test_customize_loggers(monkeypatch):
250250
assert len(engine.stat_loggers) == 1
251251
assert len(engine.stat_loggers[0]) == 1
252252
engine.stat_loggers[0][0].log.assert_called_once()
253+
254+
255+
@pytest.mark.asyncio(scope="module")
256+
async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
257+
with monkeypatch.context() as m, ExitStack() as after:
258+
m.setenv("VLLM_USE_V1", "1")
259+
260+
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
261+
after.callback(engine.shutdown)
262+
263+
sampling_params = SamplingParams(max_tokens=100,
264+
output_kind=RequestOutputKind.DELTA,
265+
temperature=1.0,
266+
seed=33)
267+
268+
# Test with valid DP rank.
269+
async for _ in engine.generate(request_id="request-34",
270+
prompt=TEXT_PROMPT,
271+
sampling_params=sampling_params,
272+
data_parallel_rank=0):
273+
pass
274+
275+
# Test with out-of-range DP rank.
276+
with pytest.raises(ValueError):
277+
async for _ in engine.generate(request_id="request-35",
278+
prompt=TEXT_PROMPT,
279+
sampling_params=sampling_params,
280+
data_parallel_rank=1):
281+
pass

tests/v1/test_async_llm_dp.py

Lines changed: 16 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -29,12 +29,14 @@
2929
allow_module_level=True)
3030

3131

32-
async def generate(engine: AsyncLLM,
33-
request_id: str,
34-
prompt: PromptType,
35-
output_kind: RequestOutputKind,
36-
max_tokens: int,
37-
prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
32+
async def generate(
33+
engine: AsyncLLM,
34+
request_id: str,
35+
prompt: PromptType,
36+
output_kind: RequestOutputKind,
37+
max_tokens: int,
38+
prompt_logprobs: Optional[int] = None,
39+
data_parallel_rank: Optional[int] = None) -> tuple[int, str]:
3840
# Ensure generate doesn't complete too fast for cancellation test.
3941
await asyncio.sleep(0.2)
4042

@@ -46,7 +48,8 @@ async def generate(engine: AsyncLLM,
4648
prompt_logprobs=prompt_logprobs)
4749
async for out in engine.generate(request_id=request_id,
4850
prompt=prompt,
49-
sampling_params=sampling_params):
51+
sampling_params=sampling_params,
52+
data_parallel_rank=data_parallel_rank):
5053

5154
num_tokens = len(out.outputs[0].token_ids)
5255
if output_kind == RequestOutputKind.DELTA:
@@ -89,8 +92,12 @@ async def test_load(output_kind: RequestOutputKind,
8992
for request_id in request_ids:
9093
tasks.append(
9194
asyncio.create_task(
92-
generate(engine, request_id, prompt, output_kind,
93-
NUM_EXPECTED_TOKENS)))
95+
generate(engine,
96+
request_id,
97+
prompt,
98+
output_kind,
99+
NUM_EXPECTED_TOKENS,
100+
data_parallel_rank=0)))
94101
# Confirm that we got all the EXPECTED tokens from the requests.
95102
done, pending = await asyncio.wait(tasks,
96103
return_when=asyncio.FIRST_EXCEPTION)

vllm/engine/async_llm_engine.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -494,6 +494,10 @@ async def add_request_async(
494494
if arrival_time is None:
495495
arrival_time = time.time()
496496

497+
if data_parallel_rank is not None:
498+
raise ValueError("Targeting data_parallel_rank only supported "
499+
"in v1 client.")
500+
497501
if (isinstance(prompt, dict)
498502
and prompt.get("prompt_embeds", None) is not None
499503
and not prompt.get("prompt_token_ids", None)):

vllm/v1/engine/core_client.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1000,9 +1000,6 @@ def get_core_engine_for_request(self,
10001000
) -> CoreEngine:
10011001
if dp_rank is not None:
10021002
# engines are already in rank order
1003-
if dp_rank < 0 or dp_rank >= len(self.core_engines):
1004-
raise ValueError(f"Requested DP rank {dp_rank} is out of "
1005-
f"range [0, {len(self.core_engines)})")
10061003
return self.core_engines[dp_rank]
10071004

10081005
if not self.lb_engines:

vllm/v1/engine/processor.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -226,6 +226,12 @@ def process_inputs(
226226
if prompt_adapter_request is not None:
227227
raise ValueError("V1 does not support prompt_adapter_request.")
228228

229+
data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
230+
if data_parallel_rank is not None and not (0 <= data_parallel_rank <
231+
data_parallel_size):
232+
raise ValueError(f"data_parallel_rank {data_parallel_rank} "
233+
f"is out of range [0, {data_parallel_size}).")
234+
229235
if arrival_time is None:
230236
arrival_time = time.time()
231237

0 commit comments

Comments (0)