Skip to content

Commit dc7836c

Browse files
ngxsonarthw
authored and committed
llama : fix llama_chat_format_single for mistral (ggml-org#8657)
* fix `llama_chat_format_single` for mistral * fix typo * use printf
1 parent 146da8b commit dc7836c

File tree

3 files changed

+26
-7
lines changed

3 files changed

+26
-7
lines changed

common/common.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2723,7 +2723,7 @@ std::string llama_chat_format_single(const struct llama_model * model,
27232723
const llama_chat_msg & new_msg,
27242724
bool add_ass) {
27252725
std::ostringstream ss;
2726-
auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false);
2726+
auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false);
27272727
std::vector<llama_chat_msg> chat_new(past_msg);
27282728
// if the past_msg ends with a newline, we must preserve it in the formatted version
27292729
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {

examples/main/main.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ static std::string chat_add_and_format(struct llama_model * model, std::vector<l
124124
auto formatted = llama_chat_format_single(
125125
model, g_params->chat_template, chat_msgs, new_msg, role == "user");
126126
chat_msgs.push_back({role, content});
127+
LOG("formatted: %s\n", formatted.c_str());
127128
return formatted;
128129
}
129130

tests/test-chat-template.cpp

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#include <iostream>
21
#include <string>
32
#include <vector>
43
#include <sstream>
@@ -133,26 +132,45 @@ int main(void) {
133132
);
134133
formatted_chat.resize(res);
135134
std::string output(formatted_chat.data(), formatted_chat.size());
136-
std::cout << output << "\n-------------------------\n";
135+
printf("%s\n", output.c_str());
136+
printf("-------------------------\n");
137137
assert(output == expected);
138138
}
139139

140-
// test llama_chat_format_single
141-
std::cout << "\n\n=== llama_chat_format_single ===\n\n";
140+
141+
// test llama_chat_format_single for system message
142+
printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
142143
std::vector<llama_chat_msg> chat2;
144+
llama_chat_msg sys_msg{"system", "You are a helpful assistant"};
145+
146+
auto fmt_sys = [&](std::string tmpl) {
147+
auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false);
148+
printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str());
149+
printf("-------------------------\n", output.c_str());
150+
return output;
151+
};
152+
assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n");
153+
assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n");
154+
assert(fmt_sys("gemma") == ""); // for gemma, system message is merged with user message
155+
assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>");
156+
157+
158+
// test llama_chat_format_single for user message
159+
printf("\n\n=== llama_chat_format_single (user message) ===\n\n");
143160
chat2.push_back({"system", "You are a helpful assistant"});
144161
chat2.push_back({"user", "Hello"});
145162
chat2.push_back({"assistant", "I am assistant"});
146163
llama_chat_msg new_msg{"user", "How are you"};
147164

148165
auto fmt_single = [&](std::string tmpl) {
149166
auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
150-
std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n";
167+
printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str());
168+
printf("-------------------------\n", output.c_str());
151169
return output;
152170
};
153171
assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
154172
assert(fmt_single("llama2") == "[INST] How are you [/INST]");
155-
assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
173+
assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
156174
assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
157175

158176
return 0;

0 commit comments

Comments
 (0)