
Commit 99c5593

Merge branch 'main' of github.com:character-tech/vllm into cached_tokens_completions

Signed-off-by: Amog Kamsetty <amogkamsetty@gmail.com>
2 parents: 0efe79e + 6c7451c

3 files changed, 62 insertions(+), 1 deletion(-)

vllm/entrypoints/openai/protocol.py

Lines changed: 48 additions & 0 deletions
@@ -875,6 +875,13 @@ class CompletionRequest(OpenAIBaseModel):
         description="KVTransfer parameters used for disaggregated serving.")

     # --8<-- [end:completion-extra-params]
+    accumulate: Optional[bool] = Field(
+        default=None,
+        description=(
+            "Special kind of echo where the response carries the "
+            "accumulated text so far instead of the delta."),
+    )
+    # doc: end-completion-extra-params

     # Default sampling parameters for completion requests
     _DEFAULT_SAMPLING_PARAMS: dict = {
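Since `accumulate` is an extension to the standard completions payload rather than an OpenAI field, a client would pass it through the `extra_body` hook of the official `openai` Python client. A minimal sketch, assuming a vLLM OpenAI-compatible server at localhost:8000 and a placeholder model name:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

stream = client.completions.create(
    model="my-model",              # placeholder model name
    prompt="Hello, my name is",
    max_tokens=16,
    stream=True,
    # vLLM-specific extension params travel in extra_body
    extra_body={"accumulate": True},
)

for chunk in stream:
    # with accumulate=True each chunk carries the full text so far,
    # so the last chunk already holds the complete completion
    print(chunk.choices[0].text)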
@@ -1323,6 +1330,47 @@ class PoolingResponse(OpenAIBaseModel):
     usage: UsageInfo


+class ClassificationRequest(OpenAIBaseModel):
+    model: Optional[str] = None
+    input: Union[list[str], str]
+    truncate_prompt_tokens: Optional[int] = None
+    user: Optional[str] = None
+
+    # --8<-- [start:classification-pooling-params]
+    additional_data: Optional[Any] = None
+    # --8<-- [end:classification-pooling-params]
+
+    # --8<-- [start:classification-extra-params]
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."),
+    )
+
+    # --8<-- [end:classification-extra-params]
+
+    def to_pooling_params(self):
+        return PoolingParams(additional_data=self.additional_data)
+
+
+class ClassificationData(OpenAIBaseModel):
+    index: int
+    label: Optional[str]
+    probs: list[float]
+    num_classes: int
+
+
+class ClassificationResponse(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
+    object: str = "list"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    data: list[ClassificationData]
+    usage: UsageInfo
+
+
 class ScoreResponseData(OpenAIBaseModel):
     index: int
     object: str = "score"
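For orientation, a small sketch (not part of the commit) of how the new classification models compose. The model name, labels, and token counts are illustrative, and `UsageInfo` is assumed to accept the standard `prompt_tokens`/`total_tokens` fields:

request = ClassificationRequest(
    model="my-classifier",                   # placeholder name
    input=["great movie", "terrible movie"],
)
pooling_params = request.to_pooling_params()

response = ClassificationResponse(
    model="my-classifier",
    data=[
        ClassificationData(index=0, label="positive",
                           probs=[0.9, 0.1], num_classes=2),
        ClassificationData(index=1, label="negative",
                           probs=[0.2, 0.8], num_classes=2),
    ],
    usage=UsageInfo(prompt_tokens=8, total_tokens=8),
)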

vllm/entrypoints/openai/serving_completion.py

Lines changed: 13 additions & 0 deletions
@@ -301,6 +301,9 @@ async def completion_stream_generator(
         has_echoed = [False] * num_choices * num_prompts
         num_prompt_tokens = [0] * num_prompts
         num_cached_tokens = [0] * num_prompts
+        accumulated_text = [""] * num_choices * num_prompts
+        accumulated_tokens = [[] for _ in range(num_choices * num_prompts)]
+        accumulated_logprobs = [[] for _ in range(num_choices * num_prompts)]

         stream_options = request.stream_options
         if stream_options:
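One subtlety in these initializers: `[""] * n` is fine because strings are immutable, but spelling the token buffers as `[[] * num_choices * num_prompts]` would be a bug, since multiplying an empty list still yields an empty list and the outer literal ends up with a single slot. The per-slot comprehension used above avoids that. A quick self-contained check:

num_choices, num_prompts = 2, 3

buggy = [[] * num_choices * num_prompts]  # [] * 6 is still [], so this is [[]]
fixed = [[] for _ in range(num_choices * num_prompts)]

assert buggy == [[]]    # one element only: buggy[1] raises IndexError
assert len(fixed) == 6  # one independent list per (prompt, choice) slot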
@@ -352,6 +355,16 @@ async def completion_stream_generator(
                         *(output.logprobs or []),
                     ]
                     has_echoed[i] = True
+                elif request.accumulate:
+                    i = output.index + prompt_idx * num_choices
+                    # return the accumulated response
+                    accumulated_text[i] += output.text
+                    accumulated_tokens[i].extend(output.token_ids)
+                    accumulated_logprobs[i].extend(output.logprobs or [])
+
+                    delta_text = accumulated_text[i]
+                    delta_token_ids = accumulated_tokens[i]
+                    out_logprobs = accumulated_logprobs[i]
                 else:
                     # return just the delta
                     delta_text = output.text
vllm/model_executor/models/transformers.py

Lines changed: 1 addition & 1 deletion
@@ -306,7 +306,7 @@ def create_attention_instances(self) -> dict[int, Attention]:
                 self.config.global_attention_layers, list):
             global_attention_layers = self.config.global_attention_layers
         else:
-            global_attention_layers = None
+            global_attention_layers = []

         for i in range(start, end):
             sliding_window = None
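The switch from None to an empty list reads like a sentinel fix: presumably the per-layer loop tests membership against `global_attention_layers`, and an empty list makes that test uniformly False instead of raising (the membership check is my assumption; it is not shown in the hunk):

# Membership against [] is always False; against None it raises.
global_attention_layers = []
assert (3 in global_attention_layers) is False

try:
    3 in None  # type: ignore[operator]
except TypeError as exc:
    print(exc)  # argument of type 'NoneType' is not iterable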
