Skip to content

Commit 8f2720d

Browse files
[Frontend] Support Tool Calling with both tool_choice='required' and $defs. (#20629)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
1 parent ad6c2e1 commit 8f2720d

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

tests/entrypoints/openai/test_completion_with_function_calling.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,43 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
7272
"The unit to fetch the temperature in",
7373
"enum": ["celsius", "fahrenheit"],
7474
},
75+
"options": {
76+
"$ref": "#/$defs/WeatherOptions",
77+
"description":
78+
"Optional parameters for weather query",
79+
},
7580
},
7681
"required": ["country", "unit"],
82+
"$defs": {
83+
"WeatherOptions": {
84+
"title": "WeatherOptions",
85+
"type": "object",
86+
"additionalProperties": False,
87+
"properties": {
88+
"unit": {
89+
"type": "string",
90+
"enum": ["celsius", "fahrenheit"],
91+
"default": "celsius",
92+
"description": "Temperature unit",
93+
"title": "Temperature Unit",
94+
},
95+
"include_forecast": {
96+
"type": "boolean",
97+
"default": False,
98+
"description":
99+
"Whether to include a 24-hour forecast",
100+
"title": "Include Forecast",
101+
},
102+
"language": {
103+
"type": "string",
104+
"default": "zh-CN",
105+
"description": "Language of the response",
106+
"title": "Language",
107+
"enum": ["zh-CN", "en-US", "ja-JP"],
108+
},
109+
},
110+
},
111+
},
77112
},
78113
},
79114
},

vllm/entrypoints/openai/protocol.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,24 @@ def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
707707
"required": ["name", "parameters"]
708708
}
709709

710+
def get_tool_schema_defs(
711+
tools: list[ChatCompletionToolsParam]) -> dict:
712+
all_defs = dict[str, dict[str, Any]]()
713+
for tool in tools:
714+
if tool.function.parameters is None:
715+
continue
716+
defs = tool.function.parameters.pop("$defs", {})
717+
for def_name, def_schema in defs.items():
718+
if def_name in all_defs and all_defs[
719+
def_name] != def_schema:
720+
raise ValueError(
721+
f"Tool definition '{def_name}' has "
722+
"multiple schemas, which is not "
723+
"supported.")
724+
else:
725+
all_defs[def_name] = def_schema
726+
return all_defs
727+
710728
json_schema = {
711729
"type": "array",
712730
"minItems": 1,
@@ -715,6 +733,9 @@ def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
715733
"anyOf": [get_tool_schema(tool) for tool in self.tools]
716734
}
717735
}
736+
json_schema_defs = get_tool_schema_defs(self.tools)
737+
if json_schema_defs:
738+
json_schema["$defs"] = json_schema_defs
718739
return json_schema
719740

720741
return None

0 commit comments

Comments
 (0)