diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt
index 0ec8d6d8d03b5..fe865abab708d 100644
--- a/common/CMakeLists.txt
+++ b/common/CMakeLists.txt
@@ -65,6 +65,7 @@ add_library(${TARGET} STATIC
     train.cpp
     ngram-cache.h
     ngram-cache.cpp
+    chaton.hpp
     )

 if (BUILD_SHARED_LIBS)
diff --git a/common/chaton.hpp b/common/chaton.hpp
new file mode 100644
index 0000000000000..9616dea407c69
--- /dev/null
+++ b/common/chaton.hpp
@@ -0,0 +1,69 @@
+#pragma once
+
+/**
+ *
+ * Provides simple and dumb helpers for chatting with LLM chat/instruct models
+ * using the chat template expected by them.
+ *
+ * Normally used to tag the system prompt and user messages.
+ * Currently used by the examples/main program.
+ *
+ * This builds on llama_chat_apply_template. When adding support for new chat templates,
+ * remember to update llama_chat_apply_template_internal as well as llama_chat_reverse_prompt.
+ *
+ * The examples/main program uses this when --chaton TEMPLATE_ID is passed to it along with -i.
+ * Sample TEMPLATE_IDs include chatml, llama2, llama3, ...
+ *
+ */
+
+#include <string>
+#include <vector>
+
+#include "llama.h"
+#include "log.h"
+
+// Tag the passed message as expected by the specified chat handshake template
+// and the given role. If the specified template is not supported, return false.
+inline bool llama_chat_apply_template_simple(
+        const std::string &tmpl,
+        const std::string &role,
+        const std::string &content,
+        std::string &dst,
+        bool add_ass) {
+    llama_chat_message msg = { role.c_str(), content.c_str() };
+    std::vector<char> buf(content.size() * 2); // This may under-allocate for small messages and over-allocate for large messages
+
+    int32_t slen = llama_chat_apply_template(nullptr, tmpl.c_str(), &msg, 1, add_ass, buf.data(), buf.size());
+    if (slen == -1) {
+        LOG_TEELN("WARN:%s:Unknown template [%s] requested", __func__, tmpl.c_str());
+        dst = "";
+        return false;
+    }
+    if ((size_t) slen > buf.size()) {
+        LOGLN("INFO:%s:%s:LengthNeeded:%d:BufSizeWas:%zu", __func__, role.c_str(), slen, buf.size());
+        buf.resize(slen);
+        slen = llama_chat_apply_template(nullptr, tmpl.c_str(), &msg, 1, add_ass, buf.data(), buf.size());
+    }
+
+    const std::string tagged_msg(buf.data(), slen);
+    LOGLN("INFO:%s:%s:%s", __func__, role.c_str(), tagged_msg.c_str());
+    dst = tagged_msg;
+    return true;
+}
+
+// Return what should be the reverse prompt for the given template id,
+// i.e. the possible end-of-text tag(s) of the specified model type's chat query response.
+// Note that it appends these reverse prompts to any that may already exist in the passed vector.
+inline bool llama_chat_reverse_prompt(std::string &template_id, std::vector<std::string> &rprompts) {
+    if (template_id == "chatml") {
+        rprompts.push_back("<|im_start|>user\n");
+    } else if (template_id == "llama2") {
+        rprompts.push_back("</s>");
+    } else if (template_id == "llama3") {
+        rprompts.push_back("<|eot_id|>");
+    } else {
+        LOG_TEELN("WARN:%s:Unknown template [%s] requested", __func__, template_id.c_str());
+        return false;
+    }
+    return true;
+}
diff --git a/common/common.cpp b/common/common.cpp
index cf69535e2d1f5..b704d6f1f986a 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -868,6 +868,15 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.chatml = true;
         return true;
     }
+    if (arg == "--chaton") {
+        params.chaton = true;
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.chaton_template_id = argv[i];
+        return true;
+    }
     if (arg == "--infill") {
         params.infill = true;
         return true;
@@ -1378,6 +1387,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --version             show version and build info\n");
     printf("  -i, --interactive     run in interactive mode\n");
     printf("  --interactive-first   run in interactive mode and wait for input right away\n");
+    printf("  --chaton TEMPLATE_ID  allow the interactive mode to apply the specified chat template before sending user input to the model (you also need to specify -i)\n");
+    printf("                        TEMPLATE_ID could be chatml, llama3, ...\n");
     printf("  -ins, --instruct      run in instruction mode (use with Alpaca models)\n");
     printf("  -cml, --chatml        run in chatml mode (use with ChatML-compatible models)\n");
     printf("  --multiline-input     allows you to write or paste multiple lines without ending each in '\\'\n");
diff --git a/common/common.h b/common/common.h
index cca44268e6df5..931317c832153 100644
--- a/common/common.h
+++ b/common/common.h
@@ -139,6 +139,8 @@ struct gpt_params {
     bool use_color         = false; // use color to distinguish generations and inputs
     bool interactive       = false; // interactive mode
     bool chatml            = false; // chatml mode (used for models trained on chatml syntax)
+    bool chaton            = false; // chaton mode (used to chat with models which have been trained for chat and/or instruct operation)
+    std::string chaton_template_id = ""; // the internal chat template to use
     bool prompt_cache_all  = false; // save user input and generations to prompt cache
     bool prompt_cache_ro   = false; // open the prompt cache read-only and do not update it
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 249fc2bb605b3..32bcee9c43199 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -1,4 +1,5 @@
 #include "common.h"
+#include "chaton.hpp"
 #include "console.h"
 #include "llama.h"

@@ -251,11 +252,17 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd_inp;

-    if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) {
-        LOG("tokenize the prompt\n");
+    if (params.interactive_first || params.instruct || params.chatml || params.chaton || !params.prompt.empty() || session_tokens.empty()) {
+        LOG("tokenize the prompt: %s\n", params.prompt.c_str());
         if (params.chatml) {
             params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
         }
+        if (params.chaton) {
+            if (!llama_chat_apply_template_simple(params.chaton_template_id, "system", params.prompt, params.prompt, false)) {
+                LOG_TEELN("ERRR:%s:Wrt:%s:%s:%s", __func__, params.chaton_template_id.c_str(), "system", params.prompt.c_str());
+                exit(2);
+            }
+        }
         embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
     } else {
         LOG("use session tokens\n");
@@ -333,7 +340,7 @@ int main(int argc, char ** argv) {
     }

     // number of tokens to keep when resetting context
-    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) {
+    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml || params.chaton) {
         params.n_keep = (int)embd_inp.size();
     } else {
         params.n_keep += add_bos; // always keep the BOS token
@@ -363,6 +370,14 @@ int main(int argc, char ** argv) {
         params.interactive_first = true;
         params.antiprompt.emplace_back("<|im_start|>user\n");
     }
+    // handle chaton mode: it adds on to any reverse prompt specified explicitly by the user
+    if (params.chaton) {
+        params.interactive_first = true;
+        if (!llama_chat_reverse_prompt(params.chaton_template_id, params.antiprompt)) {
+            LOG_TEELN("ERRR:%s:ChatOn:Unsupported ChatTemplateType:%s", __func__, params.chaton_template_id.c_str());
+            exit(1);
+        }
+    }

     // enable interactive mode if interactive start is specified
     if (params.interactive_first) {
@@ -817,7 +832,7 @@ int main(int argc, char ** argv) {
         if (n_past > 0 && is_interacting) {
             LOG("waiting for user input\n");

-            if (params.instruct || params.chatml) {
+            if (params.instruct || params.chatml || params.chaton) {
                 printf("\n> ");
             }

@@ -876,15 +891,27 @@ int main(int argc, char ** argv) {
                     process_escapes(buffer);
                 }

-                const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
-                const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
-                const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
-
-                LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
-
-                embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
-                embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
-                embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
+                std::vector<llama_token> line_inp;
+                if (params.chaton) {
+                    std::string f_chat;
+                    if (!llama_chat_apply_template_simple(params.chaton_template_id, "user", buffer.c_str(), f_chat, true)) {
+                        LOG_TEELN("ERRR:%s:Wrt:%s:%s:%s", __func__, params.chaton_template_id.c_str(), "user", buffer.c_str());
+                        exit(2);
+                    }
+                    line_inp = ::llama_tokenize(ctx, f_chat, false, true);
+                    LOG("formatted input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
+                    embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
+                } else {
+                    const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
+                    line_inp = ::llama_tokenize(ctx, buffer, false, false);
+                    const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
+
+                    LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
+
+                    embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
+                    embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
+                    embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
+                }

                 // instruct mode: insert response suffix
                 if (params.instruct) {
@@ -921,6 +948,7 @@ int main(int argc, char ** argv) {
         }

         // end of text token
+        // chaton is expected to be used along with the interactive argument, so not checking for chaton separately
         if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive || params.chatml)) {
             LOG_TEE(" [end of text]\n");
             break;
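
Note: the sketch below is a rough, standalone illustration of how the two helpers in common/chaton.hpp compose; it is not part of the patch. It tags a system prompt and one user message with the chatml template and collects the reverse prompts that interactive mode watches for. The prompt strings, the bare main() and the return codes are invented for illustration, and it assumes the file is built inside the llama.cpp tree with common/ on the include path and linked against libllama.

#include <string>
#include <vector>

#include "chaton.hpp"   // llama_chat_apply_template_simple, llama_chat_reverse_prompt

int main() {
    std::string tmpl = "chatml";    // any TEMPLATE_ID accepted by --chaton: chatml, llama2, llama3, ...

    // Tag the system prompt; add_ass=false, so no assistant prefix is appended yet.
    std::string tagged_system;
    if (!llama_chat_apply_template_simple(tmpl, "system", "You are a helpful assistant.", tagged_system, false)) {
        return 1;
    }

    // Tag a user message; add_ass=true appends the template's assistant prefix so the model starts replying.
    std::string tagged_user;
    if (!llama_chat_apply_template_simple(tmpl, "user", "Hello!", tagged_user, true)) {
        return 1;
    }

    // Collect the end-of-response marker(s) for this template; they are appended to the vector.
    std::vector<std::string> antiprompts;
    if (!llama_chat_reverse_prompt(tmpl, antiprompts)) {
        return 1;
    }

    // tagged_system and tagged_user would now be tokenized with ::llama_tokenize and fed to the model.
    return 0;
}

examples/main/main.cpp follows the same sequence when run with -i --chaton TEMPLATE_ID: the system prompt is tagged once at startup with add_ass=false, each interactive line is tagged as a "user" turn with add_ass=true before tokenization, and the template's reverse prompt is appended to params.antiprompt so generation pauses when the model finishes its turn.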