
Commit c7efd01 (1 parent: 76ddeff)

[Frontend] OpenAI Responses API supports Tool/Function calling

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>

File tree: 4 files changed, +212 -29 lines
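What the change enables, as a rough client-side sketch: tool/function calling through vLLM's OpenAI-compatible Responses API. The base URL, model name, and the get_weather tool below are illustrative assumptions, not part of this commit; the sketch assumes a vLLM server launched with --enable-auto-tool-choice and a --tool-call-parser that matches the model.

```python
# Rough usage sketch for the feature added in this commit: tool calling via the
# OpenAI Responses API served by vLLM. The base_url, model name, and the
# get_weather tool are illustrative assumptions, not taken from the diff.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "name": "get_weather",  # hypothetical example tool
    "description": "Look up the current weather for a city.",
    "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}]

response = client.responses.create(
    model="Qwen/Qwen3-8B",  # any tool-call-capable model served by vLLM
    input="What is the weather in Paris?",
    tools=tools,
    tool_choice="auto",
)

# With this commit, output items may now be function_call entries
# (ResponseFunctionToolCall) in addition to messages and reasoning items.
for item in response.output:
    if item.type == "function_call":
        print(item.name, item.arguments)
```

Besides "auto", the diff below also handles tool_choice values of "required", "none", and a named function ({"type": "function", "name": "get_weather"}), mapping them onto guided decoding and output parsing.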

vllm/entrypoints/openai/protocol.py

Lines changed: 99 additions & 14 deletions

@@ -17,9 +17,11 @@
 from openai.types.chat.chat_completion_message import (
     Annotation as OpenAIAnnotation)
 # yapf: enable
-from openai.types.responses import (ResponseInputParam, ResponseOutputItem,
+from openai.types.responses import (ResponseFunctionToolCall,
+                                    ResponseInputParam, ResponseOutputItem,
                                     ResponseOutputMessage, ResponsePrompt,
-                                    ResponseStatus, ResponseTextConfig)
+                                    ResponseStatus, ResponseTextConfig,
+                                    ToolChoiceFunction)
 from openai.types.responses.response import ToolChoice
 from openai.types.responses.tool import Tool
 from openai.types.shared import Metadata, Reasoning
@@ -324,16 +326,7 @@ def to_sampling_params(
         top_p = default_sampling_params.get(
             "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
 
-        # Structured output
-        guided_decoding = None
-        if self.text is not None and self.text.format is not None:
-            response_format = self.text.format
-            if response_format.type == "json_schema":
-                guided_decoding = GuidedDecodingParams.from_optional(
-                    json=response_format.schema_)
-            elif response_format.type == "json_object":
-                raise NotImplementedError("json_object is not supported")
-
+        guided_decoding = self._get_guided_decoding()
         # TODO: add more parameters
         return SamplingParams.from_optional(
             temperature=temperature,
@@ -360,7 +353,7 @@ def validate_prompt(cls, data):
             raise ValueError("prompt template is not supported")
         return data
 
-    @model_validator(mode="before")
+    @model_validator(mode="before")
     def check_cache_salt_support(cls, data):
         if data.get("cache_salt") is not None:
             if not envs.VLLM_USE_V1:
@@ -373,6 +366,97 @@ def check_cache_salt_support(cls, data):
                     "non-empty string if provided.")
         return data
 
+    def _get_guided_json_from_tool(
+            self) -> Optional[Union[str, dict, BaseModel]]:
+        print(
+            f"Tool choice: {self.tool_choice}, type: {type(self.tool_choice)}")
+        # user has chosen to use a named tool
+        if type(self.tool_choice) is ToolChoiceFunction:
+            tool_name = self.tool_choice.name
+            tools = {tool.name: tool for tool in \
+                     self.tools if tool.type == "function"}
+            if tool_name not in tools:
+                raise ValueError(
+                    f"Tool '{tool_name}' has not been passed in `tools`.")
+            tool = tools[tool_name]
+            print(f"Using tool '{tool_name}' for guided json decoding.")
+            print(f"Tool parameters: {tool.parameters}")
+            return tool.parameters
+
+        if self.tool_choice == "required":
+            # Pydantic schema generation cannot be used since the JSON schema
+            # has to be constructed for a specific instantiation of a tool list
+            # so that parameters of a function are correctly generated
+            # based on the chosen function name
+            def get_tool_schema(tool: ToolChoiceFunction) -> dict:
+                return {
+                    "properties": {
+                        "name": {
+                            "type": "string",
+                            "enum": [tool.name]
+                        },
+                        # parameters are always generated as '{}' in the final
+                        # output if they are missing from the request
+                        # (i.e. are None or '{}') so the schema is
+                        # updated to produce an empty object in that case
+                        "parameters": tool.parameters if tool.parameters else {
+                            "type": "object",
+                            "properties": {}
+                        }
+                    },
+                    "required": ["name", "parameters"]
+                }
+
+            def get_tool_schema_defs(tools: list[ToolChoiceFunction]) -> dict:
+                all_defs = dict[str, dict[str, Any]]()
+                for tool in tools:
+                    if tool.parameters is None:
+                        continue
+                    defs = tool.parameters.pop("$defs", {})
+                    for def_name, def_schema in defs.items():
+                        if def_name in all_defs and all_defs[
+                                def_name] != def_schema:
+                            raise ValueError(
+                                f"Tool definition '{def_name}' has "
+                                "multiple schemas, which is not "
+                                "supported.")
+                        else:
+                            all_defs[def_name] = def_schema
+                return all_defs
+
+            json_schema = {
+                "type": "array",
+                "minItems": 1,
+                "items": {
+                    "type": "object",
+                    "anyOf": [get_tool_schema(tool) for tool in self.tools]
+                }
+            }
+            json_schema_defs = get_tool_schema_defs(self.tools)
+            if json_schema_defs:
+                json_schema["$defs"] = json_schema_defs
+            print("Using tool choice 'required' for guided json decoding.")
+            print(f"JSON schema: {json_schema}")
+            return json_schema
+
+        return None
+
+    def _get_guided_decoding(self) -> Optional[GuidedDecodingParams]:
+        # Structured output
+        guided_decoding = None
+        if self.text is not None and self.text.format is not None:
+            response_format = self.text.format
+            if response_format.type == "json_schema":
+                guided_decoding = GuidedDecodingParams.from_optional(
+                    json=response_format.schema_)
+            elif response_format.type == "json_object":
+                raise NotImplementedError("json_object is not supported")
+        # Function call
+        elif self.tool_choice != "none" or self.tools is not None:
+            guided_decoding = GuidedDecodingParams.from_optional(
+                json=self._get_guided_json_from_tool())
+        return guided_decoding
+
 
 class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
@@ -1719,7 +1803,8 @@ class ResponsesResponse(OpenAIBaseModel):
     metadata: Optional[Metadata] = None
     model: str
    object: Literal["response"] = "response"
-    output: list[Union[ResponseOutputMessage, ResponseReasoningItem]]
+    output: list[Union[ResponseOutputMessage, ResponseReasoningItem,
+                       ResponseFunctionToolCall]]
     parallel_tool_calls: bool
     temperature: float
     tool_choice: ToolChoice
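To make the "required" branch of _get_guided_json_from_tool above concrete: for a single hypothetical get_weather tool (same illustrative tool as in the usage sketch), the guided-decoding schema it builds constrains the model to emit a non-empty JSON array of {name, parameters} objects, roughly as follows.

```python
# Illustrative guided-decoding schema that _get_guided_json_from_tool() would
# build for tool_choice="required" with one hypothetical "get_weather" tool;
# the inner "parameters" block is copied from the tool definition in the request.
json_schema = {
    "type": "array",
    "minItems": 1,
    "items": {
        "type": "object",
        "anyOf": [{
            "properties": {
                "name": {"type": "string", "enum": ["get_weather"]},
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
            "required": ["name", "parameters"],
        }],
    },
}
```

For a named tool_choice (ToolChoiceFunction), the method instead returns that tool's parameters schema directly, so the entire completion is constrained to be the argument object for that one function.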

vllm/entrypoints/openai/serving_engine.py

Lines changed: 5 additions & 2 deletions

@@ -908,8 +908,11 @@ async def _preprocess_chat(
                 request, "tool_choice") and request.tool_choice != "none")
 
         if should_parse_tools:
-            if not isinstance(request, ChatCompletionRequest):
-                msg = "Tool usage is only supported for Chat Completions API"
+            if not isinstance(request,
+                              ChatCompletionRequest) and not isinstance(
+                                  request, ResponsesRequest):
+                msg = "Tool usage is only supported for Chat Completions API " \
+                      "and Responses API requests."
                 raise NotImplementedError(msg)
 
             request = tool_parser(tokenizer).adjust_request(  # type: ignore

vllm/entrypoints/openai/serving_responses.py

Lines changed: 103 additions & 13 deletions

@@ -2,14 +2,18 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import asyncio
+import json
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
 from http import HTTPStatus
 from typing import Callable, Final, Optional, Union
 
 import jinja2
 from fastapi import Request
-from openai.types.responses import ResponseOutputMessage, ResponseOutputText
+from openai.types.responses import (ResponseFunctionToolCall,
+                                    ResponseOutputMessage, ResponseOutputText,
+                                    ToolChoiceFunction)
+from pydantic import TypeAdapter
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
@@ -18,7 +22,8 @@
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.entrypoints.openai.protocol import (ErrorResponse,
+from vllm.entrypoints.openai.protocol import (ErrorResponse, FunctionCall,
+                                              FunctionDefinition,
                                               PromptTokenUsageInfo,
                                               RequestResponseMetadata,
                                               ResponseReasoningItem,
@@ -27,12 +32,13 @@
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import random_uuid
+from vllm.utils import random_fc_uuid, random_uuid
 
 logger = init_logger(__name__)
 
@@ -63,7 +69,18 @@ def __init__(
             return_tokens_as_token_ids=return_tokens_as_token_ids,
             enable_force_include_usage=enable_force_include_usage,
         )
-
+        self.enable_auto_tools = enable_auto_tools
+        self.expand_tools_even_if_tool_choice_none = (
+            expand_tools_even_if_tool_choice_none)
+        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
+        if self.enable_auto_tools:
+            try:
+                self.tool_parser = ToolParserManager.get_tool_parser(
+                    tool_parser)
+            except Exception as e:
+                raise TypeError("Error: --enable-auto-tool-choice requires "
                                f"tool_parser:'{tool_parser}' which has not "
+                                "been registered") from e
         self.chat_template = chat_template
         self.chat_template_content_format: Final = chat_template_content_format
 
@@ -139,11 +156,30 @@ async def create_responses(
             ) = self._maybe_get_adapters(request)
             model_name = self._get_model_name(request.model, lora_request)
             tokenizer = await self.engine_client.get_tokenizer(lora_request)
-
+            if request.tools is None:
+                tool_dicts = None
+            elif (request.tool_choice == "none"
+                  and not self.expand_tools_even_if_tool_choice_none):
+                if len(request.tools) > 0:
+                    logger.warning_once(
+                        "Tools are specified but tool_choice is set to 'none' "
+                        "and --expand-tools-even-if-tool-choice-none is not "
+                        "enabled. Tool definitions will be excluded from the "
+                        "prompt. This behavior will change in vLLM v0.10 where "
+                        "tool definitions will be included by default even "
+                        "with tool_choice='none'. To adopt the new behavior "
+                        "now, use --expand-tools-even-if-tool-choice-none. "
+                        "To suppress this warning, either remove tools from "
+                        "the request or set tool_choice to a different value.")
+                tool_dicts = None
+            else:
+                tool_dicts = [tool.model_dump() for tool in request.tools]
             _, request_prompts, engine_prompts = await self._preprocess_chat(
                 request,
                 tokenizer,
                 messages,
+                tool_dicts=tool_dicts,
+                tool_parser=self.tool_parser,
                 chat_template=self.chat_template,
                 chat_template_content_format=self.chat_template_content_format,
             )
@@ -287,28 +323,82 @@ async def responses_full_generator(
             reasoning_content = None
             content = final_output.text
 
-        output = []
-        if reasoning_content:
-            reasoning_item = ResponseReasoningItem(
+        outputs = []
+        output = None
+        if self.tool_parser:
+            function_calls: list[FunctionCall] = []
+            if request.tool_choice and \
+                isinstance(request.tool_choice,
+                           ToolChoiceFunction):
+                # Forced Function Call
+                function_calls.append(
+                    FunctionCall(name=request.tool_choice.name,
+                                 arguments=content))
+            elif request.tool_choice is None or request.tool_choice == "none":
+                pass
+            elif request.tool_choice == "required":
+                tool_calls = TypeAdapter(
+                    list[FunctionDefinition]).validate_json(content)
+                function_calls.extend([
+                    FunctionCall(name=tool_call.name,
+                                 arguments=json.dumps(tool_call.parameters,
+                                                      ensure_ascii=False))
+                    for tool_call in tool_calls
+                ])
+            elif request.tool_choice == "auto":
+                try:
+                    tool_parser = self.tool_parser(tokenizer)
+                except RuntimeError as e:
+                    logger.exception("Error in tool parser creation.")
+                    return self.create_error_response(str(e))
+                tool_call_info = tool_parser.extract_tool_calls(
+                    content if content is not None else "", request=request)
+                if tool_call_info is not None and tool_call_info.tools_called:
+                    function_calls.extend(
+                        FunctionCall(
+                            name=tool_call.function.name,
+                            arguments=tool_call.function.arguments,
+                        ) for tool_call in tool_call_info.tool_calls)
+            else:
+                logger.warning(
+                    "Unknown tool choice: %s. "
+                    "Using 'none' as the default tool choice.",
+                    request.tool_choice)
+            output = [
+                ResponseFunctionToolCall(
+                    id=f"fc_{random_fc_uuid()}",
+                    call_id=f"call_{random_uuid()}",
+                    type="function_call",
+                    status="completed",
+                    name=tool_call.name,
+                    arguments=tool_call.arguments,
+                ) for tool_call in function_calls
+            ]
+        # If no tool call is generated, we still need to return an output.
+        if reasoning_content and output is None:
+            output = ResponseReasoningItem(
                 text=reasoning_content,
                 status=None,  # NOTE: Only the last output item has status.
             )
-            output.append(reasoning_item)
-        if content:
+        # If no tool call is generated, we still need to return an output.
+        if content and output is None:
             output_text = ResponseOutputText(
                 text=content,
                 annotations=[],  # TODO
                 type="output_text",
                 logprobs=None,  # TODO
             )
-            message = ResponseOutputMessage(
+            output = ResponseOutputMessage(
                 id=f"msg_{random_uuid()}",
                 content=[output_text],
                 role="assistant",
                 status="completed",
                 type="message",
             )
-            output.append(message)
+        if isinstance(output, list):
+            outputs.extend(output)
+        else:
+            outputs.append(output)
 
         # Calculate usage.
         assert final_res.prompt_token_ids is not None
@@ -329,7 +419,7 @@ async def responses_full_generator(
             sampling_params,
             model_name=model_name,
             created_time=created_time,
-            output=output,
+            output=outputs,
             status="completed",
             usage=usage,
        )
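The net effect on the wire: when a tool call is produced, responses_full_generator now emits a function_call output item instead of a message. A sketch of its shape, with made-up values, mirroring the ResponseFunctionToolCall construction above:

```python
# Illustrative shape of the new output item (all values are made up); it mirrors
# openai.types.responses.ResponseFunctionToolCall as constructed in the diff.
example_function_call_item = {
    "id": "fc_" + "ab12" * 12,        # f"fc_{random_fc_uuid()}" -> 48 hex chars
    "call_id": "call_" + "cd34" * 8,  # f"call_{random_uuid()}" -> 32 hex chars
    "type": "function_call",
    "status": "completed",
    "name": "get_weather",            # hypothetical tool from the request
    "arguments": '{"city": "Paris"}',
}
```

For tool_choice="required", the arguments come from the schema-constrained JSON array parsed with TypeAdapter(list[FunctionDefinition]); for "auto", from the configured tool call parser; for a forced ToolChoiceFunction, the whole completion text is used as the arguments string.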

vllm/utils/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -510,6 +510,11 @@ def random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
 
+def random_fc_uuid() -> str:
+    """Generates a random UUID for function call tool outputs."""
+    return str(os.urandom(24).hex())
+
+
 class AsyncMicrobatchTokenizer:
     """Asynchronous tokenizer with micro-batching.
 