Skip to content

Commit 221806e

Browse files
afeldman-nmnjhill
authored andcommitted
[Frontend] Expose custom args in OpenAI APIs (vllm-project#16862)
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com> Signed-off-by: Andrew Feldman <afeldman@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com> Signed-off-by: minpeter <kali2005611@gmail.com>
1 parent 5563679 commit 221806e

File tree

3 files changed

+44
-14
lines changed

3 files changed

+44
-14
lines changed

benchmarks/kernels/benchmark_moe_align_block_size.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import itertools
55

66
import torch
7-
import triton
87

98
from vllm import _custom_ops as ops
109
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
1110
moe_align_block_size_triton,
1211
)
12+
from vllm.triton_utils import triton
1313

1414

1515
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:

vllm/entrypoints/openai/protocol.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
326326
)
327327
chat_template_kwargs: Optional[dict[str, Any]] = Field(
328328
default=None,
329-
description=("Additional kwargs to pass to the template renderer. "
330-
"Will be accessible by the chat template."),
329+
description=(
330+
"Additional keyword args to pass to the template renderer. "
331+
"Will be accessible by the chat template."),
331332
)
332333
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
333334
default=None,
@@ -414,6 +415,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
414415
default=None,
415416
description="KVTransfer parameters used for disaggregated serving.")
416417

418+
vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
419+
default=None,
420+
description=("Additional request parameters with string or "
421+
"numeric values, used by custom extensions."),
422+
)
423+
417424
# --8<-- [end:chat-completion-extra-params]
418425

419426
# Default sampling parameters for chat completion requests
@@ -523,6 +530,10 @@ def to_sampling_params(
523530
structural_tag=self.structural_tag,
524531
)
525532

533+
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
534+
if self.kv_transfer_params:
535+
# Pass in kv_transfer_params via extra_args
536+
extra_args["kv_transfer_params"] = self.kv_transfer_params
526537
return SamplingParams.from_optional(
527538
n=self.n,
528539
best_of=self.best_of,
@@ -553,8 +564,8 @@ def to_sampling_params(
553564
logit_bias=self.logit_bias,
554565
bad_words= self.bad_words,
555566
allowed_token_ids=self.allowed_token_ids,
556-
extra_args=({"kv_transfer_params": self.kv_transfer_params}
557-
if self.kv_transfer_params else None))
567+
extra_args=extra_args or None,
568+
)
558569

559570
def _get_guided_json_from_tool(
560571
self) -> Optional[Union[str, dict, BaseModel]]:
@@ -871,6 +882,12 @@ class CompletionRequest(OpenAIBaseModel):
871882
default=None,
872883
description="KVTransfer parameters used for disaggregated serving.")
873884

885+
vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
886+
default=None,
887+
description=("Additional request parameters with string or "
888+
"numeric values, used by custom extensions."),
889+
)
890+
874891
# --8<-- [end:completion-extra-params]
875892

876893
# Default sampling parameters for completion requests
@@ -968,6 +985,10 @@ def to_sampling_params(
968985
whitespace_pattern=self.guided_whitespace_pattern,
969986
)
970987

988+
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
989+
if self.kv_transfer_params:
990+
# Pass in kv_transfer_params via extra_args
991+
extra_args["kv_transfer_params"] = self.kv_transfer_params
971992
return SamplingParams.from_optional(
972993
n=self.n,
973994
best_of=self.best_of,
@@ -997,8 +1018,8 @@ def to_sampling_params(
9971018
guided_decoding=guided_decoding,
9981019
logit_bias=self.logit_bias,
9991020
allowed_token_ids=self.allowed_token_ids,
1000-
extra_args=({"kv_transfer_params": self.kv_transfer_params}
1001-
if self.kv_transfer_params else None))
1021+
extra_args=extra_args or None,
1022+
)
10021023

10031024
@model_validator(mode="before")
10041025
@classmethod
@@ -1117,8 +1138,9 @@ class EmbeddingChatRequest(OpenAIBaseModel):
11171138
)
11181139
chat_template_kwargs: Optional[dict[str, Any]] = Field(
11191140
default=None,
1120-
description=("Additional kwargs to pass to the template renderer. "
1121-
"Will be accessible by the chat template."),
1141+
description=(
1142+
"Additional keyword args to pass to the template renderer. "
1143+
"Will be accessible by the chat template."),
11221144
)
11231145
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
11241146
default=None,
@@ -1623,8 +1645,9 @@ class TokenizeChatRequest(OpenAIBaseModel):
16231645
)
16241646
chat_template_kwargs: Optional[dict[str, Any]] = Field(
16251647
default=None,
1626-
description=("Additional kwargs to pass to the template renderer. "
1627-
"Will be accessible by the chat template."),
1648+
description=(
1649+
"Additional keyword args to pass to the template renderer. "
1650+
"Will be accessible by the chat template."),
16281651
)
16291652
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
16301653
default=None,
@@ -1736,6 +1759,12 @@ class TranscriptionRequest(OpenAIBaseModel):
17361759
# Flattened stream option to simplify form data.
17371760
stream_include_usage: Optional[bool] = False
17381761
stream_continuous_usage_stats: Optional[bool] = False
1762+
1763+
vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
1764+
default=None,
1765+
description=("Additional request parameters with string or "
1766+
"numeric values, used by custom extensions."),
1767+
)
17391768
# --8<-- [end:transcription-extra-params]
17401769

17411770
# --8<-- [start:transcription-sampling-params]
@@ -1823,7 +1852,8 @@ def to_sampling_params(
18231852
presence_penalty=self.presence_penalty,
18241853
output_kind=RequestOutputKind.DELTA
18251854
if self.stream \
1826-
else RequestOutputKind.FINAL_ONLY)
1855+
else RequestOutputKind.FINAL_ONLY,
1856+
extra_args=self.vllm_xargs)
18271857

18281858
@model_validator(mode="before")
18291859
@classmethod

vllm/sampling_params.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,8 @@ class SamplingParams(
198198
processor which only retains scores for the given token ids.
199199
Defaults to None.
200200
extra_args: Arbitrary additional args, that can be used by custom
201-
sampling implementations. Not used by any in-tree sampling
202-
implementations.
201+
sampling implementations, plugins, etc. Not used by any in-tree
202+
sampling implementations.
203203
"""
204204

205205
n: int = 1

0 commit comments

Comments
 (0)