
Commit 6b7eb35

chaunceyjiang committed

[Frontend] OpenAI Responses API supports Tool/Function calling

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>

1 parent 8f2720d

4 files changed: +211 −28 lines changed

vllm/entrypoints/openai/protocol.py

Lines changed: 98 additions & 13 deletions

@@ -11,9 +11,11 @@
 import regex as re
 import torch
 from fastapi import HTTPException, UploadFile
-from openai.types.responses import (ResponseInputParam, ResponseOutputItem,
+from openai.types.responses import (ResponseFunctionToolCall,
+                                    ResponseInputParam, ResponseOutputItem,
                                     ResponseOutputMessage, ResponsePrompt,
-                                    ResponseStatus, ResponseTextConfig)
+                                    ResponseStatus, ResponseTextConfig,
+                                    ToolChoiceFunction)
 from openai.types.responses.response import ToolChoice
 from openai.types.responses.tool import Tool
 from openai.types.shared import Metadata, Reasoning
@@ -307,16 +309,7 @@ def to_sampling_params(
         top_p = default_sampling_params.get(
             "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])

-        # Structured output
-        guided_decoding = None
-        if self.text is not None and self.text.format is not None:
-            response_format = self.text.format
-            if response_format.type == "json_schema":
-                guided_decoding = GuidedDecodingParams.from_optional(
-                    json=response_format.schema_)
-            elif response_format.type == "json_object":
-                raise NotImplementedError("json_object is not supported")
-
+        guided_decoding = self._get_guided_decoding()
         # TODO: add more parameters
         return SamplingParams.from_optional(
             temperature=temperature,
@@ -343,6 +336,97 @@ def validate_prompt(cls, data):
             raise ValueError("prompt template is not supported")
         return data

+    def _get_guided_json_from_tool(
+            self) -> Optional[Union[str, dict, BaseModel]]:
+        print(
+            f"Tool choice: {self.tool_choice}, type: {type(self.tool_choice)}")
+        # user has chosen to use a named tool
+        if type(self.tool_choice) is ToolChoiceFunction:
+            tool_name = self.tool_choice.name
+            tools = {tool.name: tool for tool in \
+                     self.tools if tool.type == "function"}
+            if tool_name not in tools:
+                raise ValueError(
+                    f"Tool '{tool_name}' has not been passed in `tools`.")
+            tool = tools[tool_name]
+            print(f"Using tool '{tool_name}' for guided json decoding.")
+            print(f"Tool parameters: {tool.parameters}")
+            return tool.parameters
+
+        if self.tool_choice == "required":
+            # Pydantic schema generation cannot be used since the JSON schema
+            # has to be constructed for a specific instantiation of a tool list
+            # so that parameters of a function are correctly generated
+            # based on the chosen function name
+            def get_tool_schema(tool: ToolChoiceFunction) -> dict:
+                return {
+                    "properties": {
+                        "name": {
+                            "type": "string",
+                            "enum": [tool.name]
+                        },
+                        # parameters are always generated as '{}' in the final
+                        # output if they are missing from the request
+                        # (i.e. are None or '{}') so the schema is
+                        # updated to produce an empty object in that case
+                        "parameters": tool.parameters if tool.parameters else {
+                            "type": "object",
+                            "properties": {}
+                        }
+                    },
+                    "required": ["name", "parameters"]
+                }
+
+            def get_tool_schema_defs(tools: list[ToolChoiceFunction]) -> dict:
+                all_defs = dict[str, dict[str, Any]]()
+                for tool in tools:
+                    if tool.parameters is None:
+                        continue
+                    defs = tool.parameters.pop("$defs", {})
+                    for def_name, def_schema in defs.items():
+                        if def_name in all_defs and all_defs[
+                                def_name] != def_schema:
+                            raise ValueError(
+                                f"Tool definition '{def_name}' has "
+                                "multiple schemas, which is not "
+                                "supported.")
+                        else:
+                            all_defs[def_name] = def_schema
+                return all_defs

+            json_schema = {
+                "type": "array",
+                "minItems": 1,
+                "items": {
+                    "type": "object",
+                    "anyOf": [get_tool_schema(tool) for tool in self.tools]
+                }
+            }
+            json_schema_defs = get_tool_schema_defs(self.tools)
+            if json_schema_defs:
+                json_schema["$defs"] = json_schema_defs
+            print("Using tool choice 'required' for guided json decoding.")
+            print(f"JSON schema: {json_schema}")
+            return json_schema
+
+        return None
+
+    def _get_guided_decoding(self) -> Optional[GuidedDecodingParams]:
+        # Structured output
+        guided_decoding = None
+        if self.text is not None and self.text.format is not None:
+            response_format = self.text.format
+            if response_format.type == "json_schema":
+                guided_decoding = GuidedDecodingParams.from_optional(
+                    json=response_format.schema_)
+            elif response_format.type == "json_object":
+                raise NotImplementedError("json_object is not supported")
+        # Function call
+        elif self.tool_choice != "none" or self.tools is not None:
+            guided_decoding = GuidedDecodingParams.from_optional(
+                json=self._get_guided_json_from_tool())
+        return guided_decoding

 class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
@@ -1636,7 +1720,8 @@ class ResponsesResponse(OpenAIBaseModel):
     metadata: Optional[Metadata] = None
     model: str
     object: Literal["response"] = "response"
-    output: list[Union[ResponseOutputMessage, ResponseReasoningItem]]
+    output: list[Union[ResponseOutputMessage, ResponseReasoningItem,
+                       ResponseFunctionToolCall]]
     parallel_tool_calls: bool
     temperature: float
     tool_choice: ToolChoice
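To make the guided-decoding contract concrete, here is a minimal standalone sketch of the schema that `_get_guided_json_from_tool` builds for `tool_choice="required"`. The `get_weather` tool is a made-up example, and the helper is re-implemented outside the request class purely for illustration:

import json

# Hypothetical tool definition (flat Responses API style); not from the commit.
weather_tool = {
    "name": "get_weather",
    "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}

def required_choice_schema(tools: list[dict]) -> dict:
    """Mirror of the 'required' schema: the model must emit a non-empty JSON
    array of {name, parameters} objects, with each name pinned to one of the
    declared tools via an enum."""
    return {
        "type": "array",
        "minItems": 1,
        "items": {
            "type": "object",
            "anyOf": [{
                "properties": {
                    "name": {"type": "string", "enum": [t["name"]]},
                    # Missing parameters fall back to an empty object schema.
                    "parameters": t.get("parameters")
                                  or {"type": "object", "properties": {}},
                },
                "required": ["name", "parameters"],
            } for t in tools],
        },
    }

print(json.dumps(required_choice_schema([weather_tool]), indent=2))
# A conforming completion would look like:
# [{"name": "get_weather", "parameters": {"city": "Berlin"}}]

Because the top-level schema is an array with minItems 1, the model may emit several calls in one completion; `_get_guided_decoding` then hands this schema to `GuidedDecodingParams.from_optional(json=...)` so sampling is constrained to valid tool invocations.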

vllm/entrypoints/openai/serving_engine.py

Lines changed: 5 additions & 2 deletions

@@ -850,8 +850,11 @@ async def _preprocess_chat(
             request, "tool_choice") and request.tool_choice != "none")

         if should_parse_tools:
-            if not isinstance(request, ChatCompletionRequest):
-                msg = "Tool usage is only supported for Chat Completions API"
+            if not isinstance(request,
+                              ChatCompletionRequest) and not isinstance(
+                                  request, ResponsesRequest):
+                msg = "Tool usage is only supported for Chat Completions API " \
+                      "and Responses API requests."
                 raise NotImplementedError(msg)

             request = tool_parser(tokenizer).adjust_request(  # type: ignore
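In effect, the guard now admits Responses API requests into the same tool-parsing path as chat completions. A condensed sketch of the check, with stand-in classes instead of the real protocol types:

class ChatCompletionRequest:  # stand-in for the vLLM protocol class
    pass

class ResponsesRequest:  # stand-in for the vLLM protocol class
    pass

def ensure_tools_supported(request) -> None:
    # Both request flavors may now carry tools; anything else is rejected.
    if not isinstance(request, (ChatCompletionRequest, ResponsesRequest)):
        raise NotImplementedError(
            "Tool usage is only supported for Chat Completions API "
            "and Responses API requests.")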

vllm/entrypoints/openai/serving_responses.py

Lines changed: 103 additions & 13 deletions

@@ -2,14 +2,18 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import asyncio
+import json
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
 from http import HTTPStatus
 from typing import Callable, Final, Optional, Union

 import jinja2
 from fastapi import Request
-from openai.types.responses import ResponseOutputMessage, ResponseOutputText
+from openai.types.responses import (ResponseFunctionToolCall,
+                                    ResponseOutputMessage, ResponseOutputText,
+                                    ToolChoiceFunction)
+from pydantic import TypeAdapter

 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
@@ -18,7 +22,8 @@
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.entrypoints.openai.protocol import (ErrorResponse,
+from vllm.entrypoints.openai.protocol import (ErrorResponse, FunctionCall,
+                                              FunctionDefinition,
                                               PromptTokenUsageInfo,
                                               RequestResponseMetadata,
                                               ResponseReasoningItem,
@@ -27,12 +32,13 @@
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import random_uuid
+from vllm.utils import random_fc_uuid, random_uuid

 logger = init_logger(__name__)

@@ -64,7 +70,18 @@ def __init__(
             return_tokens_as_token_ids=return_tokens_as_token_ids,
             enable_force_include_usage=enable_force_include_usage,
         )
-
+        self.enable_auto_tools = enable_auto_tools
+        self.expand_tools_even_if_tool_choice_none = (
+            expand_tools_even_if_tool_choice_none)
+        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
+        if self.enable_auto_tools:
+            try:
+                self.tool_parser = ToolParserManager.get_tool_parser(
+                    tool_parser)
+            except Exception as e:
+                raise TypeError("Error: --enable-auto-tool-choice requires "
                                f"tool_parser:'{tool_parser}' which has not "
+                                "been registered") from e
         self.chat_template = chat_template
         self.chat_template_content_format: Final = chat_template_content_format

@@ -140,11 +157,30 @@ async def create_responses(
         ) = self._maybe_get_adapters(request)
         model_name = self._get_model_name(request.model, lora_request)
         tokenizer = await self.engine_client.get_tokenizer(lora_request)
-
+        if request.tools is None:
+            tool_dicts = None
+        elif (request.tool_choice == "none"
+              and not self.expand_tools_even_if_tool_choice_none):
+            if len(request.tools) > 0:
+                logger.warning_once(
+                    "Tools are specified but tool_choice is set to 'none' "
+                    "and --expand-tools-even-if-tool-choice-none is not "
+                    "enabled. Tool definitions will be excluded from the "
+                    "prompt. This behavior will change in vLLM v0.10 where "
+                    "tool definitions will be included by default even "
+                    "with tool_choice='none'. To adopt the new behavior "
+                    "now, use --expand-tools-even-if-tool-choice-none. "
+                    "To suppress this warning, either remove tools from "
+                    "the request or set tool_choice to a different value.")
+            tool_dicts = None
+        else:
+            tool_dicts = [tool.model_dump() for tool in request.tools]
         _, request_prompts, engine_prompts = await self._preprocess_chat(
             request,
             tokenizer,
             messages,
+            tool_dicts=tool_dicts,
+            tool_parser=self.tool_parser,
             chat_template=self.chat_template,
             chat_template_content_format=self.chat_template_content_format,
         )
@@ -288,28 +324,82 @@ async def responses_full_generator(
             reasoning_content = None
             content = final_output.text

-        output = []
-        if reasoning_content:
-            reasoning_item = ResponseReasoningItem(
+        outputs = []
+        output = None
+        if self.tool_parser:
+            function_calls: list[FunctionCall] = []
+            if request.tool_choice and \
+                isinstance(request.tool_choice,
+                           ToolChoiceFunction):
+                # Forced Function Call
+                function_calls.append(
+                    FunctionCall(name=request.tool_choice.name,
+                                 arguments=content))
+            elif request.tool_choice is None or request.tool_choice == "none":
+                pass
+            elif request.tool_choice == "required":
+                tool_calls = TypeAdapter(
+                    list[FunctionDefinition]).validate_json(content)
+                function_calls.extend([
+                    FunctionCall(name=tool_call.name,
+                                 arguments=json.dumps(tool_call.parameters,
+                                                      ensure_ascii=False))
+                    for tool_call in tool_calls
+                ])
+            elif request.tool_choice == "auto":
+                try:
+                    tool_parser = self.tool_parser(tokenizer)
+                except RuntimeError as e:
+                    logger.exception("Error in tool parser creation.")
+                    return self.create_error_response(str(e))
+                tool_call_info = tool_parser.extract_tool_calls(
+                    content if content is not None else "", request=request)
+                if tool_call_info is not None and tool_call_info.tools_called:
+                    function_calls.extend(
+                        FunctionCall(
+                            name=tool_call.function.name,
+                            arguments=tool_call.function.arguments,
+                        ) for tool_call in tool_call_info.tool_calls)
+            else:
+                logger.warning(
+                    "Unknown tool choice: %s. "
+                    "Using 'none' as the default tool choice.",
+                    request.tool_choice)
+            output = [
+                ResponseFunctionToolCall(
+                    id=f"fc_{random_fc_uuid()}",
+                    call_id=f"call_{random_uuid()}",
+                    type="function_call",
+                    status="completed",
+                    name=tool_call.name,
+                    arguments=tool_call.arguments,
+                ) for tool_call in function_calls
+            ]
+        # If no tool call is generated, we still need to return an output.
+        if reasoning_content and output is None:
+            output = ResponseReasoningItem(
                 text=reasoning_content,
                 status=None,  # NOTE: Only the last output item has status.
             )
-            output.append(reasoning_item)
-        if content:
+        # If no tool call is generated, we still need to return an output.
+        if content and output is None:
             output_text = ResponseOutputText(
                 text=content,
                 annotations=[],  # TODO
                 type="output_text",
                 logprobs=None,  # TODO
             )
-            message = ResponseOutputMessage(
+            output = ResponseOutputMessage(
                 id=f"msg_{random_uuid()}",
                 content=[output_text],
                 role="assistant",
                 status="completed",
                 type="message",
             )
-            output.append(message)
+        if isinstance(output, list):
+            outputs.extend(output)
+        else:
+            outputs.append(output)

         # Calculate usage.
         assert final_res.prompt_token_ids is not None
@@ -330,7 +420,7 @@ async def responses_full_generator(
             sampling_params,
             model_name=model_name,
             created_time=created_time,
-            output=output,
+            output=outputs,
             status="completed",
             usage=usage,
         )
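Taken together, the handler converts each extracted call into a `ResponseFunctionToolCall` output item. A client-side sketch of how this surfaces through the OpenAI SDK, assuming a vLLM server launched with --enable-auto-tool-choice and a --tool-call-parser matching the model; the base URL, model name, and tool are illustrative, not prescribed by this commit:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.responses.create(
    model="Qwen/Qwen2.5-7B-Instruct",  # any tool-capable model served by vLLM
    input="What is the weather in Berlin?",
    tools=[{
        "type": "function",
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }],
    tool_choice="auto",  # "required" and a named function choice also work
)

for item in response.output:
    # With this commit, output may contain function_call items
    # (id "fc_...", call_id "call_...") alongside messages and reasoning.
    if item.type == "function_call":
        print(item.name, item.arguments)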

vllm/utils/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -509,6 +509,11 @@ def random_uuid() -> str:
     return str(uuid.uuid4().hex)


+def random_fc_uuid() -> str:
+    """Generates a random UUID for function call tool outputs."""
+    return str(os.urandom(24).hex())
+
+
 class AsyncMicrobatchTokenizer:
     """Asynchronous tokenizer with micro-batching.