#include <cassert>
#include <cstdio>
#include <sstream>
#include <string>
#include <vector>
@@ -133,26 +132,45 @@ int main(void) {
|
133 | 132 | );
|
134 | 133 | formatted_chat.resize(res);
|
135 | 134 | std::string output(formatted_chat.data(), formatted_chat.size());
|
136 |
| - std::cout << output << "\n-------------------------\n"; |
| 135 | + printf("%s\n", output.c_str()); |
| 136 | + printf("-------------------------\n"); |
137 | 137 | assert(output == expected);
|
138 | 138 | }
|
139 | 139 |
|
140 |
| - // test llama_chat_format_single |
141 |
| - std::cout << "\n\n=== llama_chat_format_single ===\n\n"; |
| 140 | + |
| 141 | + // test llama_chat_format_single for system message |
| 142 | + printf("\n\n=== llama_chat_format_single (system message) ===\n\n"); |
142 | 143 | std::vector<llama_chat_msg> chat2;
|
| 144 | + llama_chat_msg sys_msg{"system", "You are a helpful assistant"}; |
| 145 | + |
| 146 | + auto fmt_sys = [&](std::string tmpl) { |
| 147 | + auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false); |
| 148 | + printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str()); |
| 149 | + printf("-------------------------\n", output.c_str()); |
| 150 | + return output; |
| 151 | + }; |
| 152 | + assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n"); |
| 153 | + assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n"); |
| 154 | + assert(fmt_sys("gemma") == ""); // for gemma, system message is merged with user message |
| 155 | + assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>"); |
| 156 | + |
| 157 | + |
| 158 | + // test llama_chat_format_single for user message |
| 159 | + printf("\n\n=== llama_chat_format_single (user message) ===\n\n"); |
143 | 160 | chat2.push_back({"system", "You are a helpful assistant"});
|
144 | 161 | chat2.push_back({"user", "Hello"});
|
145 | 162 | chat2.push_back({"assistant", "I am assistant"});
|
146 | 163 | llama_chat_msg new_msg{"user", "How are you"};
|
147 | 164 |
|
148 | 165 | auto fmt_single = [&](std::string tmpl) {
|
149 | 166 | auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
|
150 |
| - std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n"; |
| 167 | + printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str()); |
| 168 | + printf("-------------------------\n", output.c_str()); |
151 | 169 | return output;
|
152 | 170 | };
|
153 | 171 | assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
|
154 | 172 | assert(fmt_single("llama2") == "[INST] How are you [/INST]");
|
155 |
| - assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n"); |
| 173 | + assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n"); |
156 | 174 | assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
|
157 | 175 |
|
158 | 176 | return 0;
|
|
0 commit comments