Commit 5f60339

ngxson authored and cebtenzzre committed
server : add llama2 chat template (ggml-org#5425)
* server: add mistral chat template
* server: fix typo
* server: rename template mistral to llama2
* server: format_llama2: remove BOS
* server: validate "--chat-template" argument
* server: clean up using_chatml variable

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
1 parent 2b3c77d commit 5f60339

File tree

3 files changed: +56 −4 lines

examples/server/oai.hpp
examples/server/server.cpp
examples/server/utils.hpp


examples/server/oai.hpp

Lines changed: 6 additions & 2 deletions
@@ -15,9 +15,13 @@
 using json = nlohmann::json;

 inline static json oaicompat_completion_params_parse(
-    const json &body /* openai api json semantics */)
+    const json &body, /* openai api json semantics */
+    const std::string &chat_template)
 {
     json llama_params;
+    std::string formatted_prompt = chat_template == "chatml"
+        ? format_chatml(body["messages"]) // OpenAI 'messages' to chatml (with <|im_start|>,...)
+        : format_llama2(body["messages"]); // OpenAI 'messages' to llama2 (with [INST],...)

     llama_params["__oaicompat"] = true;

@@ -30,7 +34,7 @@ inline static json oaicompat_completion_params_parse(
     // https://platform.openai.com/docs/api-reference/chat/create
     llama_sampling_params default_sparams;
     llama_params["model"]        = json_value(body, "model", std::string("unknown"));
-    llama_params["prompt"]       = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
+    llama_params["prompt"]       = formatted_prompt;
     llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
     llama_params["temperature"]  = json_value(body, "temperature", 0.0);
     llama_params["top_k"]        = json_value(body, "top_k", default_sparams.top_k);

examples/server/server.cpp

Lines changed: 20 additions & 2 deletions
@@ -36,6 +36,7 @@ struct server_params
     std::string hostname = "127.0.0.1";
     std::vector<std::string> api_keys;
     std::string public_path = "examples/server/public";
+    std::string chat_template = "chatml";
     int32_t port = 8080;
     int32_t read_timeout = 600;
     int32_t write_timeout = 600;
@@ -1859,6 +1860,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
     printf("  -gan N, --grp-attn-n N    set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
     printf("  -gaw N, --grp-attn-w N    set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
+    printf("  --chat-template FORMAT_NAME");
+    printf("                        set chat template, possible values are: llama2, chatml (default %s)", sparams.chat_template.c_str());
     printf("\n");
 }
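
With the help text in place, the flag is passed at startup like any other server option; for example (model path illustrative):

    ./server -m models/7B/ggml-model.gguf --chat-template llama2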

@@ -2290,6 +2293,21 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             log_set_target(stdout);
             LOG_INFO("logging to file is disabled.", {});
         }
+        else if (arg == "--chat-template")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            std::string value(argv[i]);
+            if (value != "chatml" && value != "llama2") {
+                fprintf(stderr, "error: chat template can be \"llama2\" or \"chatml\", but got: %s\n", value.c_str());
+                invalid_param = true;
+                break;
+            }
+            sparams.chat_template = value;
+        }
         else if (arg == "--override-kv")
         {
             if (++i >= argc) {
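
A value outside the two supported templates is rejected during argument parsing, so a misspelled template fails fast instead of producing silently malformed prompts:

    ./server -m models/7B/ggml-model.gguf --chat-template alpaca
    # error: chat template can be "llama2" or "chatml", but got: alpaca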
@@ -2743,13 +2761,13 @@ int main(int argc, char **argv)


     // TODO: add mount point without "/v1" prefix -- how?
-    svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
+    svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
             {
                 res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 if (!validate_api_key(req, res)) {
                     return;
                 }
-                json data = oaicompat_completion_params_parse(json::parse(req.body));
+                json data = oaicompat_completion_params_parse(json::parse(req.body), sparams.chat_template);

                 const int task_id = llama.queue_tasks.get_new_id();
                 llama.queue_results.add_waiting_task_id(task_id);
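
The endpoint's wire format is unchanged; the lambda now also captures sparams so the configured template reaches the parser. A minimal OpenAI-style request against the default host and port defined above (a hypothetical invocation):

    curl http://127.0.0.1:8080/v1/chat/completions \
        -H "Content-Type: application/json" \
        -d '{"messages": [{"role": "user", "content": "Hello"}]}'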

examples/server/utils.hpp

Lines changed: 30 additions & 0 deletions
@@ -167,6 +167,34 @@ static T json_value(const json &body, const std::string &key, const T &default_v
         : default_value;
 }

+inline std::string format_llama2(std::vector<json> messages)
+{
+    std::ostringstream output;
+    bool is_inside_turn = false;
+
+    for (auto it = messages.begin(); it != messages.end(); ++it) {
+        if (!is_inside_turn) {
+            output << "[INST] ";
+        }
+        std::string role    = json_value(*it, "role", std::string("user"));
+        std::string content = json_value(*it, "content", std::string(""));
+        if (role == "system") {
+            output << "<<SYS>>\n" << content << "\n<<SYS>>\n\n";
+            is_inside_turn = true;
+        } else if (role == "user") {
+            output << content << " [/INST]";
+            is_inside_turn = true;
+        } else {
+            output << " " << content << " </s>";
+            is_inside_turn = false;
+        }
+    }
+
+    LOG_VERBOSE("format_llama2", {{"text", output.str()}});
+
+    return output.str();
+}
+
 inline std::string format_chatml(std::vector<json> messages)
 {
     std::ostringstream chatml_msgs;
@@ -180,6 +208,8 @@ inline std::string format_chatml(std::vector<json> messages)

     chatml_msgs << "<|im_start|>assistant" << '\n';

+    LOG_VERBOSE("format_chatml", {{"text", chatml_msgs.str()}});
+
     return chatml_msgs.str();
 }
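
To make the new template concrete, here is a traced example of what format_llama2 produces for a small invented conversation (the messages are illustrative; the output follows the code above literally, including the repeated <<SYS>> marker):

    // Illustration: tracing format_llama2 on a sample conversation.
    std::vector<json> messages = {
        {{"role", "system"},    {"content", "You are a helpful assistant."}},
        {{"role", "user"},      {"content", "Hello"}},
        {{"role", "assistant"}, {"content", "Hi there"}},
        {{"role", "user"},      {"content", "How are you?"}},
    };
    // format_llama2(messages) returns:
    // "[INST] <<SYS>>\nYou are a helpful assistant.\n<<SYS>>\n\nHello [/INST] Hi there </s>[INST] How are you? [/INST]"

For the same input, format_chatml wraps each message in <|im_start|>role ... <|im_end|> markers (the ChatML convention) and, as the second hunk shows, appends an open <|im_start|>assistant turn for the model to complete.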
