Commit c3aac4e

tool-call: Phi-4 support
- Add a system message if needed (per template requirement)
- Add tools to the system message (required by the template)
- Parse output:
  - add tool calls to the response when there is valid JSON between <|tool_call|> and </|tool_call|>
  - content outside of tool_call tags is added to the text portion of the response
  - if there is no valid JSON, the entire content is added to the text portion of the response
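
For example, a raw completion of the following form (the function name and arguments are illustrative, not from this commit) is split so that text outside the tags becomes the message content and the JSON between the tags becomes a parsed tool call:

    I'll check the weather for you.<|tool_call|>{"name": "get_weather", "arguments": {"city": "Paris"}}</|tool_call|>

This parses to content "I'll check the weather for you." plus one tool call named "get_weather" with arguments {"city": "Paris"}.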
1 parent 92a3913 commit c3aac4e

File tree

common/chat.cpp
common/chat.h
models/templates/README.md
models/templates/microsoft-Phi-4-mini-instruct.jinja
tests/test-chat.cpp

5 files changed: +220 -1 lines changed

common/chat.cpp

Lines changed: 186 additions & 0 deletions
@@ -448,6 +448,7 @@ std::string common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: return "Hermes 2 Pro (extract reasoning)";
         case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
         case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)";
+        case COMMON_CHAT_FORMAT_PHI_4: return "Phi-4";
         default:
             throw std::runtime_error("Unknown chat format");
     }
@@ -1356,6 +1357,184 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
     return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
 }
 
+static common_chat_params common_chat_params_init_phi_4(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    // Phi-4 has a unique format that expects tools in the system message with <|tool|> tags
+    // and returns function calls as a JSON object after <|tool_call|> tag
+    common_chat_params data;
+
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+        std::vector<std::string> tool_rules;
+        std::vector<std::string> tool_call_alts;
+        foreach_function(inputs.tools, [&](const json & tool) {
+            const auto & function = tool.at("function");
+            std::string name = function.at("name");
+            auto parameters = function.at("parameters");
+            builder.resolve_refs(parameters);
+            tool_rules.push_back(builder.add_schema(name + "-call", {
+                {"type", "object"},
+                {"properties", {
+                    {"name", {{"const", name}}},
+                    {"arguments", parameters},
+                }},
+                {"required", json::array({"name", "arguments"})},
+            }));
+        });
+        auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space");
+        std::vector<std::string> alt_tags {
+            any_tool_call,
+        };
+        tool_call_alts.push_back(any_tool_call);
+        auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | "));
+        builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
+        data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_call|>"});
+        data.preserved_tokens = {
+            "<|tool_call|>",
+            "</|tool_call|>",
+        };
+    });
+
+    // For Phi-4, we need to inject tools into the system message
+    // because the template expects tools in the system message with <|tool|> tags
+    if (inputs.tools.empty()) {
+        // No tools, use normal approach
+        data.prompt = apply(tmpl, inputs.messages, json::array(), inputs.add_generation_prompt);
+    } else {
+        // Make a copy of messages that we can modify
+        json adjusted_messages = inputs.messages;
+
+        // Extract just the function part of the OpenAI-formatted tools
+        json phi4_tools = json::array();
+        foreach_function(inputs.tools, [&](const json & tool) {
+            phi4_tools.push_back(tool.at("function"));
+        });
+
+        // Phi-4 template expects tools in the system message with <|tool|> tags.
+        // Find the system message, or add one if it doesn't exist
+        bool found_system_msg = false;
+        for (auto & message : adjusted_messages) {
+            if (message.contains("role") && message["role"] == "system") {
+                // Add tools to the existing system message and update content to mention tools
+                message["tools"] = phi4_tools;
+
+                // If the system message doesn't mention tools, append that information
+                std::string content = message["content"];
+                if (content.find("tool") == std::string::npos &&
+                    content.find("function") == std::string::npos) {
+                    message["content"] = content + " You have access to some tools.";
+                }
+
+                found_system_msg = true;
+                break;
+            }
+        }
+
+        // If no system message, add one with tools
+        if (!found_system_msg && !adjusted_messages.empty()) {
+            json system_msg = {
+                {"role", "system"},
+                {"content", "You are a helpful assistant with access to tools.\nTo use a tool, respond in this format: <|tool_call|>{\"name\": \"foo\", \"arguments\": {\"a\": 1}}<|/tool_call|>"},
+                {"tools", phi4_tools}
+            };
+            // Insert system message at the beginning
+            adjusted_messages.insert(adjusted_messages.begin(), system_msg);
+        }
+
+        // Apply template with tools embedded in system message, passing empty tools separately
+        data.prompt = apply(tmpl, adjusted_messages, json(), inputs.add_generation_prompt);
+    }
+
+    data.format = COMMON_CHAT_FORMAT_PHI_4;
+    return data;
+}
+
+static common_chat_msg common_chat_parse_phi_4(const std::string & input) {
+    common_chat_msg result;
+    result.role = "assistant";
+
+    std::string final_content = "";
+
+    const std::string opening_tag = "<|tool_call|>";
+    const std::string closing_tag = "</|tool_call|>";
+
+    size_t start_pos = 0;
+    while (true) {
+        // Find next tool call
+        size_t tool_start = input.find(opening_tag, start_pos);
+        if (tool_start == std::string::npos) {
+            // No more tool calls.
+
+            // Is start_pos within string bounds?
+            if (start_pos < input.length()) {
+                // Add the rest of the string to final_content
+                final_content += input.substr(start_pos);
+            }
+            break;
+        }
+
+        // Add content before the tool call to final_content
+        final_content += input.substr(start_pos, tool_start - start_pos);
+
+        // Find closing tag
+        size_t content_start = tool_start + opening_tag.length();
+        size_t tool_end = input.find(closing_tag, content_start);
+
+        if (tool_end == std::string::npos) {
+            // No closing tag found, so just include the rest of the string as tool.
+            tool_end = input.length();
+        }
+
+        // Extract tool call content
+        std::string tool_content = input.substr(
+            content_start,
+            tool_end - content_start
+        );
+
+        // Try to parse the tool call
+        try {
+            auto tool_call = json::parse(tool_content);
+
+            // Verify the required fields exist
+            if (!tool_call.contains("name")) {
+                throw std::runtime_error("Missing 'name' field in tool call");
+            }
+
+            if (!tool_call.contains("arguments")) {
+                throw std::runtime_error("Missing 'arguments' field in tool call");
+            }
+
+            std::string name = tool_call["name"].get<std::string>();
+
+            std::string arguments;
+            try {
+                arguments = tool_call["arguments"].dump();
+            } catch (const std::exception & e) {
+                LOG_ERR("Failed to serialize arguments: %s\n", e.what());
+                arguments = "{}";
+            }
+
+            result.tool_calls.push_back({
+                name,
+                arguments,
+                /* id= */ "",
+            });
+        } catch (const std::exception & e) {
+            // If parsing fails, include the entire tool call in the content
+            final_content += input.substr(
+                tool_start,
+                tool_end + closing_tag.length() - tool_start
+            );
+        }
+
+        // Move past this tool call for next iteration
+        start_pos = tool_end + closing_tag.length();
+    }
+
+    result.content = final_content;
+    return result;
+}
+
 static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
@@ -1642,6 +1821,11 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_firefunction_v2(tmpl, params);
     }
 
+    // Phi-4 mini.
+    if (src.find("<|tool|>") != std::string::npos) {
+        return common_chat_params_init_phi_4(tmpl, params);
+    }
+
     // Plain handler (no tools)
     if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
         return common_chat_params_init_without_tools(tmpl, params);
@@ -1773,6 +1957,8 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
             return common_chat_parse_command_r7b(input, /* extract_reasoning= */ false);
         case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING:
             return common_chat_parse_command_r7b(input, /* extract_reasoning= */ true);
+        case COMMON_CHAT_FORMAT_PHI_4:
+            return common_chat_parse_phi_4(input);
         default:
             throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
     }

common/chat.h

Lines changed: 2 additions & 1 deletion
@@ -56,7 +56,8 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING,
     COMMON_CHAT_FORMAT_COMMAND_R7B,
     COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
-
+    COMMON_CHAT_FORMAT_PHI_4,
+
     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };
 

models/templates/README.md

Lines changed: 1 addition & 0 deletions
@@ -19,4 +19,5 @@ These templates can be updated with the following commands:
 ./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja
 ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
 ./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
+./scripts/get_chat_template.py microsoft/Phi-4-mini-instruct > models/templates/microsoft-Phi-4-mini-instruct.jinja
 ```
models/templates/microsoft-Phi-4-mini-instruct.jinja

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+{% for message in messages %}{% if message['role'] == 'system' and 'tools' in message and message['tools'] is not none %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|tool|>' + message['tools'] + '<|/tool|>' + '<|end|>' }}{% else %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|end|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>' }}{% else %}{{ eos_token }}{% endif %}
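
For reference, this template renders each message as <|role|>content<|end|>, wraps the serialized tools array of a system message in <|tool|>...<|/tool|>, and appends <|assistant|> when add_generation_prompt is set. A rough sketch of a rendered prompt, assuming a hypothetical get_weather function and that the Jinja engine serializes the tools array as JSON:

    <|system|>You are a helpful assistant with access to tools.<|tool|>[{"name": "get_weather", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}]<|/tool|><|end|><|user|>What is the weather in Paris?<|end|><|assistant|>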

tests/test-chat.cpp

Lines changed: 30 additions & 0 deletions
@@ -820,6 +820,36 @@ static void test_template_output_parsers() {
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
     }
+    {
+        auto tmpls = read_templates("models/templates/microsoft-Phi-4-mini-instruct.jinja");
+        std::vector<std::string> end_tokens{ "<|end|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_PHI_4, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+
+        // Test normal message without tools
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+
+        // Test with content before tool call
+        assert_msg_equals(
+            common_chat_msg{"assistant", "I'll help with that.", {}, tool_calls, "", "", ""},
+            common_chat_parse(
+                "I'll help with that.<|tool_call|>{\"name\":\"special_function\",\"arguments\":{\"arg1\":1}}</|tool_call|>",
+                COMMON_CHAT_FORMAT_PHI_4));
+
+        // Test with content after tool call
+        assert_msg_equals(
+            common_chat_msg{"assistant", "I'll help with that.", {}, tool_calls, "", "", ""},
+            common_chat_parse(
+                "<|tool_call|>{\"name\":\"special_function\",\"arguments\":{\"arg1\":1}}</|tool_call|>I'll help with that.",
+                COMMON_CHAT_FORMAT_PHI_4));
+
+        // Test with newlines.
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<|tool_call|>\n"
+            "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "</|tool_call|>",
+            COMMON_CHAT_FORMAT_PHI_4));
+    }
     {
         auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja");
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
