
Commit fdc5b43

[Bugfix]: Fix final_res_batch list index out of range error (#21055)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
1 parent c5b8b59 commit fdc5b43
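
When a completions request expands to zero prompts (the new test sends {"prompt_embeds": []}), final_res_batch stays empty, and the unconditional final_res_batch[0] access in request_output_to_completion_response raised IndexError. A minimal sketch of the failure mode and the guard this commit applies (simplified illustration, not the actual vLLM code):

# Sketch only: reproduces the failure mode and the guard, outside of vLLM.
final_res_batch: list = []  # zero prompts -> zero results

# Before the fix (conceptually):
#     kv_transfer_params = final_res_batch[0].kv_transfer_params  # IndexError
# After the fix: only index into the batch when it is non-empty.
kv_transfer_params = None
if final_res_batch:
    kv_transfer_params = final_res_batch[0].kv_transfer_params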

File tree

2 files changed (+78, -40 lines)


tests/v1/entrypoints/openai/test_completion.py

Lines changed: 17 additions & 1 deletion

@@ -7,6 +7,7 @@
 import pytest
 import pytest_asyncio
 import regex as re
+import requests
 from openai import BadRequestError

 from tests.utils import RemoteOpenAIServer
@@ -26,7 +27,8 @@ def default_server_args():
         "2048",
         "--max-num-seqs",
         "128",
-        "--enforce-eager"
+        "--enforce-eager",
+        "--enable-prompt-tokens-details",
     ]


@@ -679,3 +681,17 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
             prompt=prompt,
             extra_body={"guided_grammar": invalid_simplified_sql_grammar},
         )
+
+
+@pytest.mark.asyncio
+async def test_completion_with_empty_prompt_embeds(
+        client: openai.AsyncOpenAI) -> None:
+    """Test completion with empty prompt embeds."""
+    payload: dict[str, list] = {"prompt_embeds": []}
+    headers: dict[str, str] = {"Content-Type": "application/json"}
+    # base_url = http://localhost:8000/v1/completions
+    response = requests.post(f"{client.base_url}completions",
+                             headers=headers,
+                             json=payload)
+    assert response.status_code == 200, (
+        f"Expected status code 200, got {response.status_code}. ")

vllm/entrypoints/openai/serving_completion.py

Lines changed: 61 additions & 39 deletions

@@ -60,20 +60,25 @@ def __init__(
         enable_prompt_tokens_details: bool = False,
         enable_force_include_usage: bool = False,
     ):
-        super().__init__(engine_client=engine_client,
-                         model_config=model_config,
-                         models=models,
-                         request_logger=request_logger,
-                         return_tokens_as_token_ids=return_tokens_as_token_ids,
-                         enable_force_include_usage=enable_force_include_usage)
+        super().__init__(
+            engine_client=engine_client,
+            model_config=model_config,
+            models=models,
+            request_logger=request_logger,
+            return_tokens_as_token_ids=return_tokens_as_token_ids,
+            enable_force_include_usage=enable_force_include_usage,
+        )
         self.enable_prompt_tokens_details = enable_prompt_tokens_details
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
         if self.default_sampling_params:
             source = self.model_config.generation_config
             source = "model" if source == "auto" else source
-            logger.info("Using default completion sampling params from %s: %s",
-                        source, self.default_sampling_params)
+            logger.info(
+                "Using default completion sampling params from %s: %s",
+                source,
+                self.default_sampling_params,
+            )

     async def create_completion(
         self,
@@ -172,23 +177,28 @@ async def create_completion(
                     max_model_len=self.max_model_len,
                     request=request,
                     input_length=input_length,
-                    default_sampling_params=self.default_sampling_params)
+                    default_sampling_params=self.default_sampling_params,
+                )

                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
                         max_tokens, self.default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
-                        max_tokens, self.model_config.logits_processor_pattern,
-                        self.default_sampling_params)
+                        max_tokens,
+                        self.model_config.logits_processor_pattern,
+                        self.default_sampling_params,
+                    )

                 request_id_item = f"{request_id}-{i}"

-                self._log_inputs(request_id_item,
-                                 request_prompts[i],
-                                 params=sampling_params,
-                                 lora_request=lora_request,
-                                 prompt_adapter_request=prompt_adapter_request)
+                self._log_inputs(
+                    request_id_item,
+                    request_prompts[i],
+                    params=sampling_params,
+                    lora_request=lora_request,
+                    prompt_adapter_request=prompt_adapter_request,
+                )

                 trace_headers = (None if raw_request is None else await
                                  self._get_trace_headers(raw_request.headers))
@@ -245,7 +255,8 @@ async def create_completion(
                 num_prompts=num_prompts,
                 tokenizer=tokenizer,
                 request_metadata=request_metadata,
-                enable_force_include_usage=self.enable_force_include_usage)
+                enable_force_include_usage=self.enable_force_include_usage,
+            )

         # Non-streaming response
         final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
@@ -321,10 +332,10 @@ async def completion_stream_generator(

         stream_options = request.stream_options
         if stream_options:
-            include_usage = stream_options.include_usage or \
-                            enable_force_include_usage
-            include_continuous_usage = include_usage and \
-                                       stream_options.continuous_usage_stats
+            include_usage = (stream_options.include_usage
+                             or enable_force_include_usage)
+            include_continuous_usage = (include_usage and
+                                        stream_options.continuous_usage_stats)
         else:
             include_usage, include_continuous_usage = False, False

@@ -370,7 +381,8 @@ async def completion_stream_generator(
                         # echo the prompt and first token
                         delta_text = prompt_text + output.text
                         delta_token_ids = [
-                            *prompt_token_ids, *output.token_ids
+                            *prompt_token_ids,
+                            *output.token_ids,
                         ]
                         out_logprobs = [
                             *(prompt_logprobs or []),
@@ -383,8 +395,8 @@ async def completion_stream_generator(
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

-                    if not delta_text and not delta_token_ids \
-                        and not previous_num_tokens[i]:
+                    if (not delta_text and not delta_token_ids
+                            and not previous_num_tokens[i]):
                         # Chunked prefill case, don't return empty chunks
                         continue

@@ -420,7 +432,8 @@ async def completion_stream_generator(
                             finish_reason=finish_reason,
                             stop_reason=stop_reason,
                         )
-                    ])
+                    ],
+                )
                 if include_continuous_usage:
                     prompt_tokens = num_prompt_tokens[prompt_idx]
                     completion_tokens = previous_num_tokens[i]
@@ -438,7 +451,8 @@ async def completion_stream_generator(
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens)
+                total_tokens=total_prompt_tokens + total_completion_tokens,
+            )

            if self.enable_prompt_tokens_details and num_cached_tokens:
                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
@@ -452,8 +466,8 @@ async def completion_stream_generator(
                 choices=[],
                 usage=final_usage_info,
             )
-            final_usage_data = (final_usage_chunk.model_dump_json(
-                exclude_unset=False, exclude_none=True))
+            final_usage_data = final_usage_chunk.model_dump_json(
+                exclude_unset=False, exclude_none=True)
             yield f"data: {final_usage_data}\n\n"

         # report to FastAPI middleware aggregate usage across all choices
@@ -478,8 +492,10 @@ def request_output_to_completion_response(
         choices: list[CompletionResponseChoice] = []
         num_prompt_tokens = 0
         num_generated_tokens = 0
-
+        kv_transfer_params = None
+        last_final_res = None
         for final_res in final_res_batch:
+            last_final_res = final_res
             prompt_token_ids = final_res.prompt_token_ids
             assert prompt_token_ids is not None
             prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
@@ -548,19 +564,22 @@ def request_output_to_completion_response(
             total_tokens=num_prompt_tokens + num_generated_tokens,
         )

-        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
+        if (self.enable_prompt_tokens_details and last_final_res
+                and last_final_res.num_cached_tokens):
             usage.prompt_tokens_details = PromptTokenUsageInfo(
-                cached_tokens=final_res.num_cached_tokens)
+                cached_tokens=last_final_res.num_cached_tokens)

         request_metadata.final_usage_info = usage
-
+        if final_res_batch:
+            kv_transfer_params = final_res_batch[0].kv_transfer_params
         return CompletionResponse(
             id=request_id,
             created=created_time,
             model=model_name,
             choices=choices,
             usage=usage,
-            kv_transfer_params=final_res_batch[0].kv_transfer_params)
+            kv_transfer_params=kv_transfer_params,
+        )

     def _create_completion_logprobs(
         self,
@@ -579,8 +598,9 @@ def _create_completion_logprobs(

         last_token_len = 0

-        should_return_as_token_id = return_as_token_id if \
-            return_as_token_id is not None else self.return_tokens_as_token_ids
+        should_return_as_token_id = (return_as_token_id
+                                     if return_as_token_id is not None else
+                                     self.return_tokens_as_token_ids)
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
@@ -612,10 +632,12 @@ def _create_completion_logprobs(
                 out_top_logprobs.append({
                     # Convert float("-inf") to the
                     # JSON-serializable float that OpenAI uses
-                    self._get_decoded_token(top_lp[1],
-                                            top_lp[0],
-                                            tokenizer,
-                                            return_as_token_id=should_return_as_token_id):
+                    self._get_decoded_token(
+                        top_lp[1],
+                        top_lp[0],
+                        tokenizer,
+                        return_as_token_id=should_return_as_token_id,
+                    ):
                     max(top_lp[1].logprob, -9999.0)
                     for i, top_lp in enumerate(step_top_logprobs.items())
                     if num_output_top_logprobs >= i