Skip to content

Commit 4f502de

Browse files
zuxin666 authored and minpeter committed
Add xLAM tool parser support (vllm-project#17148)
Signed-off-by: minpeter <kali2005611@gmail.com>
1 parent feab23b commit 4f502de

File tree

8 files changed

+1389
-1
lines changed

8 files changed

+1389
-1
lines changed

docs/features/tool_calling.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,25 @@ AI21's Jamba-1.5 models are supported.
226226

227227
Flags: `--tool-call-parser jamba`
228228

229+
### xLAM Models (`xlam`)
230+
231+
The xLAM tool parser is designed to support models that generate tool calls in various JSON formats. It detects function calls in several different output styles:
232+
233+
1. Direct JSON arrays: Output strings that are JSON arrays starting with `[` and ending with `]`
234+
2. Thinking tags: Using `<think>...</think>` tags containing JSON arrays
235+
3. Code blocks: JSON in code blocks (```json ...```)
236+
4. Tool calls tags: Using `[TOOL_CALLS]` or `<tool_call>...</tool_call>` tags
237+
238+
Parallel function calls are supported, and the parser can effectively separate text content from tool calls.
239+
240+
Supported models:
241+
* Salesforce Llama-xLAM models: `Salesforce/Llama-xLAM-2-8B-fc-r`, `Salesforce/Llama-xLAM-2-70B-fc-r`
242+
* Qwen-xLAM models: `Salesforce/xLAM-1B-fc-r`, `Salesforce/xLAM-3B-fc-r`, `Salesforce/Qwen-xLAM-32B-fc-r`
243+
244+
Flags:
245+
* For Llama-based xLAM models: `--tool-call-parser xlam --chat-template examples/tool_chat_template_xlam_llama.jinja`
246+
* For Qwen-based xLAM models: `--tool-call-parser xlam --chat-template examples/tool_chat_template_xlam_qwen.jinja`
247+
229248
### Qwen Models
230249

231250
For Qwen2.5, the chat template in tokenizer_config.json has already included support for the Hermes-style tool use. Therefore, you can use the `hermes` parser to enable tool calls for Qwen models. For more detailed information, please refer to the official [Qwen documentation](https://qwen.readthedocs.io/en/latest/framework/function_call.html#vllm)
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# ruff: noqa: E501
3+
"""
4+
Set up this example by starting a vLLM OpenAI-compatible server with tool call
5+
options enabled for xLAM-2 models:
6+
7+
vllm serve --model Salesforce/Llama-xLAM-2-8b-fc-r --enable-auto-tool-choice --tool-call-parser xlam
8+
9+
OR
10+
11+
vllm serve --model Salesforce/xLAM-2-3b-fc-r --enable-auto-tool-choice --tool-call-parser xlam
12+
"""
13+
14+
import json
15+
import time
16+
17+
from openai import OpenAI
18+
19+
# Point the OpenAI SDK at the local vLLM server.
# vLLM ignores the API key, but the OpenAI client requires a non-empty value.
openai_api_key = "empty"
openai_api_base = "http://localhost:8000/v1"
22+
23+
24+
# Define tool functions
def get_weather(location: str, unit: str):
    """Return a canned weather report for *location* in *unit* (demo stub)."""
    return f"Weather in {location} is 22 degrees {unit}."
27+
28+
29+
def calculate_expression(expression: str):
    """Evaluate a mathematical expression string and describe the result.

    Returns ``"The result of <expr> is <value>"`` on success, or a
    ``"Could not calculate ..."`` message on any failure (best-effort tool;
    the demo should keep running).

    NOTE(security): ``eval`` executes arbitrary Python. Builtins and globals
    are stripped below to limit injection from model-generated input, but
    this is still not a sandbox — do not reuse this pattern on untrusted
    input in production.
    """
    try:
        # Restrict the evaluation environment: no builtins, no module globals.
        result = eval(expression, {"__builtins__": {}}, {})
        return f"The result of {expression} is {result}"
    except Exception as e:
        # Deliberate broad catch: report the failure instead of crashing the demo.
        return f"Could not calculate {expression}: {e}"
35+
36+
37+
def translate_text(text: str, target_language: str):
    """Return a placeholder translation message (demo stub, no real translation)."""
    return f"Translation of '{text}' to {target_language}: [translated content]"
39+
40+
41+
# Define tools
def _tool(name, description, properties, required):
    """Build one OpenAI-format function-tool schema from its variable parts."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


# Tool schemas advertised to the model (OpenAI function-calling format).
tools = [
    _tool(
        "get_weather",
        "Get the current weather in a given location",
        {
            "location": {
                "type": "string",
                "description": "City and state, e.g., 'San Francisco, CA'",
            },
            "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
        },
        ["location", "unit"],
    ),
    _tool(
        "calculate_expression",
        "Calculate a mathematical expression",
        {
            "expression": {
                "type": "string",
                "description": "Mathematical expression to evaluate, needs to be a valid python expression",
            }
        },
        ["expression"],
    ),
    _tool(
        "translate_text",
        "Translate text to another language",
        {
            "text": {"type": "string", "description": "Text to translate"},
            "target_language": {
                "type": "string",
                "description": "Target language for translation",
            },
        },
        ["text", "target_language"],
    ),
]
97+
98+
# Dispatch table: tool name (as sent to the model) -> local implementation.
tool_functions = {
    fn.__name__: fn
    for fn in (get_weather, calculate_expression, translate_text)
}
104+
105+
106+
def process_response(response, tool_functions, original_query):
    """Process a non-streaming response with possible tool calls.

    Prints any text content, executes each tool call locally via
    *tool_functions*, then sends one follow-up request containing the
    assistant turn plus all tool results. Relies on the module-global
    ``client`` being initialized.
    """
    print("\n--- Response Output ---")

    message = response.choices[0].message

    # Echo plain-text content, if the model produced any.
    if message.content:
        print(f"Content: {message.content}")

    if message.tool_calls:
        print("--------------------------------")
        print(f"Tool calls: {message.tool_calls}")
        print("--------------------------------")

        # Collect every tool call and its result before the follow-up request.
        tool_results = []
        assistant_message = {"role": "assistant"}

        if message.content:
            assistant_message["content"] = message.content

        assistant_tool_calls = []

        for tool_call in message.tool_calls:
            function_name = tool_call.function.name
            function_args = tool_call.function.arguments
            function_id = tool_call.id

            print(f"Function called: {function_name}")
            print(f"Arguments: {function_args}")
            print(f"Function ID: {function_id}")

            try:
                # Decode the JSON arguments and run the local implementation.
                args = json.loads(function_args)
                function_result = tool_functions[function_name](**args)
                print(f"\n--- Function Result ---\n{function_result}\n")

                # Record the call on the assistant turn...
                assistant_tool_calls.append(
                    {
                        "id": function_id,
                        "type": "function",
                        "function": {"name": function_name, "arguments": function_args},
                    }
                )
                # ...and its output as the matching "tool" message.
                tool_results.append(
                    {
                        "role": "tool",
                        "tool_call_id": function_id,
                        "content": function_result,
                    }
                )
            except Exception as e:
                # Best-effort demo: report and continue with remaining calls.
                print(f"Error executing function: {e}")

        assistant_message["tool_calls"] = assistant_tool_calls

        # Replay the conversation plus all tool results in a single follow-up.
        follow_up_messages = [
            {"role": "user", "content": original_query},
            assistant_message,
        ]
        follow_up_messages.extend(tool_results)

        follow_up_response = client.chat.completions.create(
            model=client.models.list().data[0].id,
            messages=follow_up_messages,
            stream=False,
        )

        print("\n--- Follow-up Response ---")
        print(follow_up_response.choices[0].message.content)
        print("--- End Follow-up ---\n")

    print("--- End Response ---\n")
194+
195+
196+
def run_test_case(query, test_name):
    """Send *query* to the server, handle any tool calls, and time the round trip."""
    print(f"\n{'=' * 50}\nTEST CASE: {test_name}\n{'=' * 50}")
    print(f"Query: '{query}'")

    started = time.time()

    # Non-streaming request; the server decides whether to call a tool.
    response = client.chat.completions.create(
        model=client.models.list().data[0].id,
        messages=[{"role": "user", "content": query}],
        tools=tools,
        tool_choice="auto",
        stream=False,
    )

    # Execute tool calls (if any) and print the final answer.
    process_response(response, tool_functions, query)

    elapsed = time.time() - started
    print(f"Test completed in {elapsed:.2f} seconds")
217+
218+
219+
def main():
    """Run every demo query against a locally served xLAM model."""
    # The client is a module global so the helper functions above share it.
    global client
    client = OpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    # (query, human-readable test name) pairs exercising each tool.
    test_cases = [
        ("I want to know the weather in San Francisco", "Weather Information"),
        ("Calculate 25 * 17 + 31", "Math Calculation"),
        ("Translate 'Hello world' to Spanish", "Text Translation"),
        ("What is the weather in Tokyo and New York in celsius", "Multiple Tool Usage"),
    ]

    for query, test_name in test_cases:
        run_test_case(query, test_name)
        time.sleep(1)  # Small delay between tests

    print("\nAll tests completed.")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)