Commit 19c8630

[Frontend] Support cache_salt in /v1/completions and /v1/responses (#20981)
Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
1 parent f29fd8a commit 19c8630

4 files changed: +77 -4 lines changed

vllm/entrypoints/openai/api_server.py

Lines changed: 1 addition & 0 deletions

@@ -1540,6 +1540,7 @@ async def init_app_state(
         state.openai_serving_models,
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
         enable_force_include_usage=args.enable_force_include_usage,
     ) if "generate" in model_config.supported_tasks else None
     state.openai_serving_pooling = OpenAIServingPooling(
vllm/entrypoints/openai/protocol.py

Lines changed: 49 additions & 3 deletions

@@ -290,6 +290,15 @@ class ResponsesRequest(OpenAIBaseModel):
             "default: 0). Any priority other than 0 will raise an error "
             "if the served model does not use priority scheduling."),
     )
+    cache_salt: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the prefix cache will be salted with the provided "
+            "string to prevent an attacker to guess prompts in multi-user "
+            "environments. The salt should be random, protected from "
+            "access by 3rd parties, and long enough to be "
+            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
+            "to 256 bit). Not supported by vLLM engine V0."))
     # --8<-- [end:responses-extra-params]

     _DEFAULT_SAMPLING_PARAMS = {

@@ -351,6 +360,19 @@ def validate_prompt(cls, data):
             raise ValueError("prompt template is not supported")
         return data

+    @model_validator(mode="before")
+    def check_cache_salt_support(cls, data):
+        if data.get("cache_salt") is not None:
+            if not envs.VLLM_USE_V1:
+                raise ValueError(
+                    "Parameter 'cache_salt' is not supported with "
+                    "this instance of vLLM, which uses engine V0.")
+            if not isinstance(data["cache_salt"],
+                              str) or not data["cache_salt"]:
+                raise ValueError("Parameter 'cache_salt' must be a "
+                                 "non-empty string if provided.")
+        return data
+

 class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation

@@ -1004,6 +1026,16 @@ class CompletionRequest(OpenAIBaseModel):
             " as strings of the form 'token_id:{token_id}' so that tokens "
             "that are not JSON-encodable can be identified."))

+    cache_salt: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the prefix cache will be salted with the provided "
+            "string to prevent an attacker to guess prompts in multi-user "
+            "environments. The salt should be random, protected from "
+            "access by 3rd parties, and long enough to be "
+            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
+            "to 256 bit). Not supported by vLLM engine V0."))
+
     kv_transfer_params: Optional[dict[str, Any]] = Field(
         default=None,
         description="KVTransfer parameters used for disaggregated serving.")

@@ -1180,6 +1212,20 @@ def validate_prompt_and_prompt_embeds(cls, data):
                 "At least one of `prompt` or `prompt_embeds` must be set.")
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_cache_salt_support(cls, data):
+        if data.get("cache_salt") is not None:
+            if not envs.VLLM_USE_V1:
+                raise ValueError(
+                    "Parameter 'cache_salt' is not supported with "
+                    "this instance of vLLM, which uses engine V0.")
+            if not isinstance(data["cache_salt"],
+                              str) or not data["cache_salt"]:
+                raise ValueError("Parameter 'cache_salt' must be a "
+                                 "non-empty string if provided.")
+        return data
+

 class EmbeddingCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation

@@ -1971,7 +2017,7 @@ class TranscriptionRequest(OpenAIBaseModel):
     """

     stream: Optional[bool] = False
-    """When set, it will enable output to be streamed in a similar fashion 
+    """When set, it will enable output to be streamed in a similar fashion
     as the Chat Completion endpoint.
     """
     # --8<-- [start:transcription-extra-params]

@@ -2233,9 +2279,9 @@ class TranslationRequest(OpenAIBaseModel):
     """

     stream: Optional[bool] = False
-    """Custom field not present in the original OpenAI definition. When set, 
+    """Custom field not present in the original OpenAI definition. When set,
     it will enable output to be streamed in a similar fashion as the Chat
-    Completion endpoint. 
+    Completion endpoint.
     """
     # Flattened stream option to simplify form data.
     stream_include_usage: Optional[bool] = False
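
With cache_salt now accepted by both ResponsesRequest and CompletionRequest, a client can salt its prefix-cache entries per user or tenant by adding the field to the body of a /v1/completions or /v1/responses request. A minimal client-side sketch (the base URL and model name are placeholders, not part of this commit); note that the salt should be generated once per user or conversation and then reused, otherwise salted prefixes can never hit the cache:

import secrets

import requests

BASE_URL = "http://localhost:8000"  # assumed local vLLM OpenAI-compatible server (V1 engine)

# 32 random bytes -> 43-character URL-safe base64 string (256 bits), matching the
# recommendation in the field description above. Generate once and reuse per user/tenant.
CACHE_SALT = secrets.token_urlsafe(32)

resp = requests.post(
    f"{BASE_URL}/v1/completions",
    json={
        "model": "my-model",                    # placeholder model name
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "cache_salt": CACHE_SALT,               # new extra field from this commit
    },
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])

The same field can be sent to /v1/responses; an empty string or a non-string value is rejected by check_cache_salt_support, as is any use of the salt on an engine-V0 deployment.
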

vllm/entrypoints/openai/serving_completion.py

Lines changed: 17 additions & 0 deletions

@@ -23,6 +23,7 @@
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
                                               ErrorResponse,
+                                              PromptTokenUsageInfo,
                                               RequestResponseMetadata,
                                               UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (

@@ -56,6 +57,7 @@ def __init__(
         *,
         request_logger: Optional[RequestLogger],
         return_tokens_as_token_ids: bool = False,
+        enable_prompt_tokens_details: bool = False,
         enable_force_include_usage: bool = False,
     ):
         super().__init__(engine_client=engine_client,

@@ -64,6 +66,7 @@ def __init__(
                          request_logger=request_logger,
                          return_tokens_as_token_ids=return_tokens_as_token_ids,
                          enable_force_include_usage=enable_force_include_usage)
+        self.enable_prompt_tokens_details = enable_prompt_tokens_details
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
         if self.default_sampling_params:

@@ -313,6 +316,8 @@ async def completion_stream_generator(
         previous_num_tokens = [0] * num_choices * num_prompts
         has_echoed = [False] * num_choices * num_prompts
         num_prompt_tokens = [0] * num_prompts
+        num_cached_tokens = None
+        first_iteration = True

         stream_options = request.stream_options
         if stream_options:

@@ -328,6 +333,10 @@
                 prompt_token_ids = res.prompt_token_ids
                 prompt_logprobs = res.prompt_logprobs

+                if first_iteration:
+                    num_cached_tokens = res.num_cached_tokens
+                    first_iteration = False
+
                 if res.prompt is not None:
                     prompt_text = res.prompt
                 else:

@@ -431,6 +440,10 @@
                 completion_tokens=total_completion_tokens,
                 total_tokens=total_prompt_tokens + total_completion_tokens)

+            if self.enable_prompt_tokens_details and num_cached_tokens:
+                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
+                    cached_tokens=num_cached_tokens)
+
             if include_usage:
                 final_usage_chunk = CompletionStreamResponse(
                     id=request_id,

@@ -535,6 +548,10 @@ def request_output_to_completion_response(
             total_tokens=num_prompt_tokens + num_generated_tokens,
         )

+        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
+            usage.prompt_tokens_details = PromptTokenUsageInfo(
+                cached_tokens=final_res.num_cached_tokens)
+
         request_metadata.final_usage_info = usage

         return CompletionResponse(
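
When the server is started with --enable-prompt-tokens-details, both the streaming and non-streaming paths above attach PromptTokenUsageInfo to the usage block of /v1/completions responses. A hedged end-to-end sketch (placeholder URL and model; actual cached_tokens values depend on prefix-cache configuration and block size): send the same salted prompt twice and inspect the second response.

import requests

BASE_URL = "http://localhost:8000"  # assumed server launched with --enable-prompt-tokens-details

payload = {
    "model": "my-model",                                # placeholder model name
    "prompt": "Explain prefix caching in one sentence.",
    "max_tokens": 16,
    "cache_salt": "reuse-the-same-salt-as-before",      # illustrative salt value
}

# The first request warms the (salted) prefix cache; the second one can hit it.
requests.post(f"{BASE_URL}/v1/completions", json=payload).raise_for_status()
usage = requests.post(f"{BASE_URL}/v1/completions", json=payload).json()["usage"]

# Populated by request_output_to_completion_response when the flag is enabled.
print(usage.get("prompt_tokens_details"))  # e.g. {"cached_tokens": 16}
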

vllm/entrypoints/openai/serving_engine.py

Lines changed: 10 additions & 1 deletion

@@ -226,7 +226,7 @@ def __init__(

     def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
         """
-        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the 
+        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
         given tokenizer.
         """
         async_tokenizer = self._async_tokenizer_pool.get(tokenizer)

@@ -811,6 +811,12 @@ async def _preprocess_completion(
                 prompt_token_ids=request_prompt_text["prompt_token_ids"])
             for request_prompt_text in request_prompts_text
         ]
+        cache_salt = request.cache_salt if (
+            hasattr(request, "cache_salt")
+            and request.cache_salt is not None) else None
+        if cache_salt:
+            for prompt_text in engine_prompts_text:
+                prompt_text["cache_salt"] = cache_salt

         # This check is equivalent to simply checking if
         # `request_prompts_embeds` is empty, but it's difficult to propagate

@@ -828,6 +834,9 @@ async def _preprocess_completion(
                 prompt_embeds=request_prompt_embeds["prompt_embeds"])
             for request_prompt_embeds in request_prompts_embeds
         ]
+        if cache_salt:
+            for prompt_embed in engine_prompts_embeds:
+                prompt_embed["cache_salt"] = cache_salt

         request_prompts = request_prompts_embeds + request_prompts_text
         engine_prompts = engine_prompts_embeds + engine_prompts_text
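
The propagation itself is plain dictionary bookkeeping: the request-level salt, if present, is copied onto every engine prompt (both the token-id prompts and the prompt-embeds prompts) so the engine can mix it into the prefix-cache hash. A standalone mirror of that logic, using bare dicts instead of vLLM's prompt types:

from typing import Any, Optional

def attach_cache_salt(engine_prompts: list[dict[str, Any]],
                      cache_salt: Optional[str]) -> None:
    """Copy a request-level cache_salt onto every engine prompt dict."""
    if not cache_salt:
        return
    for prompt in engine_prompts:
        prompt["cache_salt"] = cache_salt

# Example: one tokenized prompt and one embeddings prompt, salted identically.
prompts = [{"prompt_token_ids": [1, 2, 3]}, {"prompt_embeds": "<tensor>"}]
attach_cache_salt(prompts, "per-user-random-salt")
assert all(p["cache_salt"] == "per-user-random-salt" for p in prompts)
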
