From 13885c747eadf3d7ff72396eae6bd6ae35c3a897 Mon Sep 17 00:00:00 2001 From: mare5x Date: Thu, 27 Jun 2024 16:08:24 +0200 Subject: [PATCH 01/10] main : add token healing --- common/common.cpp | 23 +++++++ common/sampling.cpp | 148 +++++++++++++++++++++++++++++++++++++++- common/sampling.h | 28 ++++++++ examples/main/README.md | 13 ++++ examples/main/main.cpp | 43 +++++++++++- 5 files changed, 249 insertions(+), 6 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 560e20d080d0f..e5eccc54e0b3b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1093,6 +1093,25 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa sparams.grammar = json_schema_to_grammar(json::parse(argv[i])); return true; } + if (arg == "-th" || arg == "--token-healing") { + CHECK_ARG + sparams.token_healing_enabled = true; + auto & th_type = sparams.token_healing_type; + auto & th_n_rollback = sparams.token_healing_n_rollback; + std::string value(argv[i]); + /**/ if (value == "0" ) { sparams.token_healing_enabled = false; } + else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; } + else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; } + else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; } + else if (value[0] == 'r' ) { + th_type = llama_token_healing_type::ROLLBACK_MULTI; + th_n_rollback = std::stoi(value.substr(1)); + if (th_n_rollback <= 0) { + sparams.token_healing_enabled = false; + } + } else { invalid_param = true; } + return true; + } if (arg == "--override-kv") { CHECK_ARG if (!string_parse_kv_override(argv[i], params.kv_overrides)) { @@ -1501,6 +1520,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "if suffix/prefix are specified, template will be disabled\n" "only commonly used templates are accepted:\n" "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); + + options.push_back({ "main", "-th, --token-healing {0,1,d1,d,r{N}}", + "Token healing type. (default: 0, disabled)\n" + "1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens" }); options.push_back({ "grammar" }); options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() }); options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" }); diff --git a/common/sampling.cpp b/common/sampling.cpp index 079e405168dff..02795a1827149 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -2,6 +2,112 @@ #include "sampling.h" #include +// +// Token healing (internal) +// + +static bool startswith(const std::string & str, const std::string & prefix) { + return str.rfind(prefix, 0) != std::string::npos; +} + +static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) { + const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + for (llama_token token_id = 0; token_id < n_vocab; ++token_id) { + if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) { + return true; + } + } + return false; +} + +static std::vector token_healing_find_prefix( + const llama_context * ctx_main, + const std::string & prefix, + const bool include_partial_prefix) { + // Example: prefix=" world" -> " world", " worldwide", ... + // If `include_partial_prefix`, include also: " w", " wo", ... 
+ std::vector candidates; + const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + for (llama_token token_id = 0; token_id < n_vocab; ++token_id) { + std::string token = llama_token_to_piece(ctx_main, token_id); + if (startswith(token, prefix) || + (include_partial_prefix && startswith(prefix, token))) { + candidates.push_back(token_id); + } + } + return candidates; +} + +// +// Token healing (external) +// + +std::string llama_token_healing_rollback( + const llama_context * ctx_main, + llama_token_healing_type th_type, + std::vector & tokens, + int max_to_remove, + int * n_removed) { + // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back. + // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt. + if (n_removed != nullptr) { + *n_removed = 0; + } + if (tokens.size() <= 1) { + return ""; + } + const llama_model * model = llama_get_model(ctx_main); + const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI; + const int n_ctx = tokens.size(); + max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove; + max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain + int removed = 0; + std::string prefix; + // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt + // and stop early if a special token is encountered. + // NB. This doesn't handle cases where a long token is split many times, + // e.g. if "abc" is tokenized into ["a", "b", "c"] but "bc" is not a token (hypothetically), + // then "abc" will not be returned even if "abcd" exists in the vocab. + while (removed < max_to_remove) { + const llama_token next_token_id = tokens[n_ctx - removed - 1]; + if (llama_token_is_control(model, next_token_id) || llama_token_is_eog(model, next_token_id)) { + break; // Don't roll back e.g. <|endoftext|> + } + std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix; + if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) { + break; + } + removed += 1; + prefix = new_prefix; + } + if (removed == 0) { // E.g. 
if the last token is a special token + return ""; + } + // If constrained decoding would give back the original prompt, there is no need to modify the context + const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI || + th_type == llama_token_healing_type::DYNAMIC_MULTI; + const std::vector candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step); + LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed); + if (removed == 1 && candidates.size() == 1) { + LOG("token_healing: nothing to heal\n"); + return ""; + } + // Finalize outputs + if (n_removed != nullptr) { + *n_removed = removed; + } + tokens.resize(n_ctx - removed); + return prefix; +} + +void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) { + ctx_sampling->token_healing_prefix = prefix; +} + +// +// Sampling +// + struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { struct llama_sampling_context * result = new llama_sampling_context(); @@ -72,6 +178,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) { ctx->grammar = grammar; } + ctx->token_healing_prefix.clear(); + std::fill(ctx->prev.begin(), ctx->prev.end(), 0); ctx->cur.clear(); ctx->n_valid = 0; @@ -130,7 +238,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) { } std::string llama_sampling_order_print(const llama_sampling_params & params) { - std::string result = "CFG -> Penalties "; + std::string result = "(Token healing) -> CFG -> Penalties "; if (params.mirostat == 0) { for (auto sampler_type : params.samplers_sequence) { const auto sampler_type_name = llama_sampling_type_to_str(sampler_type); @@ -393,8 +501,27 @@ static llama_token_data_array llama_sampling_prepare_impl( cur.resize(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + // Constrain tokens based on the remaining token healing prefix (if any) + const auto & th_type = params.token_healing_type; + const auto & th_prefix = ctx_sampling->token_healing_prefix; + if (params.token_healing_enabled && !th_prefix.empty()) { + const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI || + th_type == llama_token_healing_type::DYNAMIC_MULTI; + std::vector th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step); + + LOG("token_healing: prefix = '%s'\n", th_prefix.c_str()); + for (const llama_token token_id : th_candidates) { + LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str()); + } + + // N.B. 
We could also set token constraints by setting rejected tokens' logits to -inf + for (const llama_token token_id : th_candidates) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + } else { + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } } llama_token_data_array cur_p = { cur.data(), cur.size(), false }; @@ -457,4 +584,19 @@ void llama_sampling_accept( if (ctx_sampling->grammar != NULL && apply_grammar) { llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id); } + + if (ctx_sampling->params.token_healing_enabled && apply_grammar) { + std::string & th_prefix = ctx_sampling->token_healing_prefix; + if (!th_prefix.empty()) { + const std::string new_token_piece = llama_token_to_piece(ctx_main, id); + if (new_token_piece.size() < th_prefix.size()) { + // Shift prefix constraint (for multi step token healing) + th_prefix = th_prefix.substr(new_token_piece.size()); + } else { + // Prefix has been generated => no more constrained generation + th_prefix.clear(); + LOG("token_healing: done\n"); + } + } + } } diff --git a/common/sampling.h b/common/sampling.h index eeaa53b8bcd00..4c1172985b364 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -19,6 +19,13 @@ enum class llama_sampler_type : char { TEMPERATURE = 't' }; +enum class llama_token_healing_type : uint8_t { + ROLLBACK_LAST, // roll back last token with a single constrained decoding step + ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps + DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step + DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps +}; + // sampling parameters typedef struct llama_sampling_params { int32_t n_prev = 64; // number of previous tokens to remember @@ -62,6 +69,10 @@ typedef struct llama_sampling_params { std::vector penalty_prompt_tokens; bool use_penalty_prompt_tokens = false; + + llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST; + bool token_healing_enabled = false; + int token_healing_n_rollback = -1; // number of tokens to roll back } llama_sampling_params; // general sampler context @@ -78,6 +89,8 @@ struct llama_sampling_context { // internal grammar_parser::parse_state parsed_grammar; + std::string token_healing_prefix; // remaining prefix to constrain sampling + // TODO: replace with ring-buffer std::vector prev; std::vector cur; @@ -158,3 +171,18 @@ void llama_sampling_accept( struct llama_context * ctx_main, llama_token id, bool apply_grammar); + +// +// Token healing +// + +// Roll back `tokens` for constrained generation according to the token healing +// strategy. Returns the prefix for constrained generation. +std::string llama_token_healing_rollback( + const llama_context * ctx_main, + llama_token_healing_type th_type, + std::vector & tokens, + int max_to_remove = -1, + int * n_removed = nullptr); + +void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix); diff --git a/examples/main/README.md b/examples/main/README.md index 9396a34fa5a31..2dd2691ce1802 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -251,6 +251,19 @@ A more practical use case might be to prevent the generation of `\code{begin}` a Example usage: `--logit-bias 29905-inf` +### Token healing + +- `-th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}`: Set the token healing strategy (default: 0, 0 = disabled). 
+ +Token healing (a.k.a. token alignment) alleviates tokenization artifacts for text completion. + +- `-th 1`: Roll back the last token and constrain the bytes of the next token to start with the chopped off last token [0, 2]. +- `-th d1`: Roll back multiple tokens until there doesn't exist a token which can cover the prompt's suffix and do a single constrained decoding step [2]. +- `-th d`: Like `d1` but allow multiple decoding steps until the removed suffix is generated. +- `-th r{N}`: Like `d` but roll back `N` tokens, where `-th r3` is recommended [1]. + +Sources: [0](https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb), [1](https://arxiv.org/abs/2403.08688), [2](https://arxiv.org/abs/2402.01035). + ### RNG Seed - `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1, -1 = random seed). diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6e0635a66cd06..b3e47b36b74ba 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -291,6 +291,17 @@ int main(int argc, char ** argv) { LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); } + if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) { + sparams.token_healing_enabled = false; + LOG("token_healing: disabled due to custom suffix/conversation mode"); + } + std::string token_healing_prefix; + int token_healing_n_removed = 0; + if (!params.interactive_first && sparams.token_healing_enabled) { + token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, + sparams.token_healing_n_rollback, &token_healing_n_removed); + } + // Should not run without any tokens if (embd_inp.empty()) { if (add_bos) { @@ -315,7 +326,7 @@ int main(int argc, char ** argv) { std::vector original_inp = ::llama_tokenize(ctx, params.prompt, true, true); LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str()); - original_prompt_len = original_inp.size(); + original_prompt_len = original_inp.size() - token_healing_n_removed; guidance_offset = (int)guidance_inp.size() - original_prompt_len; LOG("original_prompt_len: %s", log_tostr(original_prompt_len)); LOG("guidance_offset: %s", log_tostr(guidance_offset)); @@ -510,6 +521,7 @@ int main(int argc, char ** argv) { int n_consumed = 0; int n_session_consumed = 0; int n_past_guidance = 0; + int n_bytes_to_skip = 0; // to skip printing when generating token healing prefix std::vector input_tokens; g_input_tokens = &input_tokens; std::vector output_tokens; g_output_tokens = &output_tokens; @@ -536,6 +548,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); } + llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix); if (llama_model_has_encoder(model)) { int enc_input_size = embd_inp.size(); @@ -770,7 +783,15 @@ int main(int argc, char ** argv) { const std::string token_str = llama_token_to_piece(ctx, id, params.special); // Console/Stream Output - fprintf(stdout, "%s", token_str.c_str()); + // Suppress printing while generating token healing prefix + if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int)token_str.size()) { + fprintf(stdout, "%s", token_str.substr(n_bytes_to_skip).c_str()); + n_bytes_to_skip = 0; + } else if (n_bytes_to_skip > 0) { + n_bytes_to_skip -= token_str.size(); + } else { + fprintf(stdout, "%s", token_str.c_str()); + } // Record Displayed Tokens To Log // Note: Generated tokens are 
created one by one hence this check @@ -862,6 +883,7 @@ int main(int argc, char ** argv) { assistant_ss << llama_token_to_piece(ctx, id, false); } + token_healing_n_removed = 0; if (n_past > 0 && is_interacting) { LOG("waiting for user input\n"); @@ -934,6 +956,17 @@ int main(int argc, char ** argv) { embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end()); + if (sparams.token_healing_enabled) { + // Limit token healing rollback to new tokens only (otherwise would need to shift everything) + const int n_new_tokens = embd_inp.size() - original_size; + const int max_to_remove = sparams.token_healing_n_rollback < 0 + ? n_new_tokens + : std::min(sparams.token_healing_n_rollback, n_new_tokens); + token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, + max_to_remove, &token_healing_n_removed); + n_bytes_to_skip = token_healing_prefix.size(); + } + for (size_t i = original_size; i < embd_inp.size(); ++i) { const llama_token token = embd_inp[i]; output_tokens.push_back(token); @@ -943,7 +976,7 @@ int main(int argc, char ** argv) { // reset assistant message assistant_ss.str(""); - n_remain -= line_inp.size(); + n_remain -= line_inp.size() + token_healing_n_removed; LOG("n_remain: %d\n", n_remain); } else { LOG("empty line, passing control back\n"); @@ -955,6 +988,10 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { llama_sampling_reset(ctx_sampling); + if (token_healing_n_removed > 0) { + // Set new prefix after an interaction + llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix); + } } is_interacting = false; } From db9c018891772624d9da0a92816ce68920c67c50 Mon Sep 17 00:00:00 2001 From: mare5x Date: Sat, 29 Jun 2024 13:02:30 +0200 Subject: [PATCH 02/10] token healing : change dynamic rollback Dynamic rollback now starts checking prefixes based on the length of the longest token. 
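For example, using the hypothetical vocab from the comment removed below: if "abc" is
tokenized into ["a", "b", "c"], no vocab token starts with "bc", but "abcd" is in the
vocab, then the previous loop stopped as soon as the accumulated suffix ("bc") no longer
prefixed any token, so only "c" was rolled back. The rollback now first removes up to the
byte length of the longest vocab token and then keeps the longest removed suffix that
still prefixes some token, so "abc" is healed as well.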
--- common/sampling.cpp | 136 ++++++++++++++++++++++++++++------------- examples/main/main.cpp | 2 +- 2 files changed, 95 insertions(+), 43 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 02795a1827149..b407df45cace0 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -13,14 +13,15 @@ static bool startswith(const std::string & str, const std::string & prefix) { static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) { const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main)); for (llama_token token_id = 0; token_id < n_vocab; ++token_id) { - if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) { + std::string token = llama_token_to_piece(ctx_main, token_id); + if (startswith(token, prefix)) { return true; } } return false; } -static std::vector token_healing_find_prefix( +static std::vector token_healing_get_candidates( const llama_context * ctx_main, const std::string & prefix, const bool include_partial_prefix) { @@ -38,6 +39,85 @@ static std::vector token_healing_find_prefix( return candidates; } +static size_t get_max_token_length(const llama_context * ctx_main) { + const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + size_t len = 0; + for (llama_token token_id = 0; token_id < n_vocab; ++token_id) { + std::string token = llama_token_to_piece(ctx_main, token_id); + len = std::max(len, token.size()); + } + return len; +} + +struct token_healing_info { + std::string prefix; + int n_tokens_removed; +}; + +token_healing_info llama_token_healing_get_prefix( + const llama_context * ctx_main, + const llama_token_healing_type th_type, + const std::vector & tokens, + int max_to_remove) { + if (tokens.size() <= 1) { + return {"", 0}; + } + + const int n_ctx = tokens.size(); + max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove; + max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain + + int removed = 0; + std::string prefix; + + const llama_model * model = llama_get_model(ctx_main); + auto is_special_token = [&](const llama_token token_id) { + return llama_token_is_control(model, token_id) || llama_token_is_eog(model, token_id); + }; + + if (th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI) { + // The number of bytes to roll back cannot exceed the length of the longest token. + const size_t n_longest_token = get_max_token_length(ctx_main); + size_t len = 0; + while (removed < max_to_remove) { + const llama_token next_token_id = tokens[n_ctx - removed - 1]; + if (is_special_token(next_token_id)) { + break; + } + const size_t next_token_size = llama_token_to_piece(ctx_main, next_token_id).size(); + if (len + next_token_size > n_longest_token) { + break; + } + len += next_token_size; + removed += 1; + } + + while (removed > 0) { + prefix.clear(); + for (int i = n_ctx - removed; i < n_ctx; i++) { + prefix += llama_token_to_piece(ctx_main, tokens[i]); + } + if (token_healing_prefix_exists(ctx_main, prefix)) { + break; // Stop on longest valid prefix + } + removed -= 1; + } + } else { + // Roll back tokens a fixed amount and stop early if a special token is encountered. 
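+        // (ROLLBACK_LAST / ROLLBACK_MULTI: unlike the dynamic branch above, no vocab
+        //  prefix check is needed; the prefix is rebuilt from the removed tokens below.)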
+ while (removed < max_to_remove) { + const llama_token next_token_id = tokens[n_ctx - removed - 1]; + if (is_special_token(next_token_id)) { + break; + } + removed += 1; + } + for (int i = n_ctx - removed; i < n_ctx; i++) { + prefix += llama_token_to_piece(ctx_main, tokens[i]); + } + } + return {prefix, removed}; +} + // // Token healing (external) // @@ -48,56 +128,28 @@ std::string llama_token_healing_rollback( std::vector & tokens, int max_to_remove, int * n_removed) { - // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back. - // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt. if (n_removed != nullptr) { *n_removed = 0; } - if (tokens.size() <= 1) { - return ""; - } - const llama_model * model = llama_get_model(ctx_main); - const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI; - const int n_ctx = tokens.size(); - max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove; - max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain - int removed = 0; - std::string prefix; - // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt - // and stop early if a special token is encountered. - // NB. This doesn't handle cases where a long token is split many times, - // e.g. if "abc" is tokenized into ["a", "b", "c"] but "bc" is not a token (hypothetically), - // then "abc" will not be returned even if "abcd" exists in the vocab. - while (removed < max_to_remove) { - const llama_token next_token_id = tokens[n_ctx - removed - 1]; - if (llama_token_is_control(model, next_token_id) || llama_token_is_eog(model, next_token_id)) { - break; // Don't roll back e.g. <|endoftext|> - } - std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix; - if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) { - break; - } - removed += 1; - prefix = new_prefix; - } - if (removed == 0) { // E.g. if the last token is a special token - return ""; - } - // If constrained decoding would give back the original prompt, there is no need to modify the context + // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back. + // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt. + token_healing_info info = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove); + + // If constrained decoding would give back the original prompt, there is no need to modify the prompt. 
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI || th_type == llama_token_healing_type::DYNAMIC_MULTI; - const std::vector candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step); - LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed); - if (removed == 1 && candidates.size() == 1) { + const std::vector candidates = token_healing_get_candidates(ctx_main, info.prefix, is_multi_step); + LOG("token_healing: prefix = '%s' (%d tokens)\n", info.prefix.c_str(), info.n_tokens_removed); + if (info.n_tokens_removed == 1 && candidates.size() == 1) { LOG("token_healing: nothing to heal\n"); return ""; } // Finalize outputs if (n_removed != nullptr) { - *n_removed = removed; + *n_removed = info.n_tokens_removed; } - tokens.resize(n_ctx - removed); - return prefix; + tokens.resize(tokens.size() - info.n_tokens_removed); + return info.prefix; } void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) { @@ -507,7 +559,7 @@ static llama_token_data_array llama_sampling_prepare_impl( if (params.token_healing_enabled && !th_prefix.empty()) { const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI || th_type == llama_token_healing_type::DYNAMIC_MULTI; - std::vector th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step); + std::vector th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step); LOG("token_healing: prefix = '%s'\n", th_prefix.c_str()); for (const llama_token token_id : th_candidates) { diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b3e47b36b74ba..b08fec7dcfdd1 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -293,7 +293,7 @@ int main(int argc, char ** argv) { if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) { sparams.token_healing_enabled = false; - LOG("token_healing: disabled due to custom suffix/conversation mode"); + LOG("token healing: disabled due to custom suffix/conversation mode"); } std::string token_healing_prefix; int token_healing_n_removed = 0; From 414fc13248c446331f541b7099aff26290ef26fc Mon Sep 17 00:00:00 2001 From: mare5x Date: Sat, 29 Jun 2024 13:42:00 +0200 Subject: [PATCH 03/10] token healing : refactor to return struct --- common/sampling.cpp | 51 +++++++++++++++++------------------------- common/sampling.h | 20 ++++++++++------- examples/main/main.cpp | 25 ++++++++++----------- 3 files changed, 44 insertions(+), 52 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index b407df45cace0..bdcdde057f1d1 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -49,18 +49,13 @@ static size_t get_max_token_length(const llama_context * ctx_main) { return len; } -struct token_healing_info { - std::string prefix; - int n_tokens_removed; -}; - -token_healing_info llama_token_healing_get_prefix( - const llama_context * ctx_main, - const llama_token_healing_type th_type, - const std::vector & tokens, - int max_to_remove) { +static llama_token_healing_output llama_token_healing_get_prefix( + const llama_context * ctx_main, + const llama_token_healing_type th_type, + const std::vector & tokens, + int max_to_remove) { if (tokens.size() <= 1) { - return {"", 0}; + return {}; } const int n_ctx = tokens.size(); @@ -122,34 +117,28 @@ token_healing_info llama_token_healing_get_prefix( // Token healing (external) // -std::string llama_token_healing_rollback( - const llama_context * ctx_main, - 
llama_token_healing_type th_type, - std::vector & tokens, - int max_to_remove, - int * n_removed) { - if (n_removed != nullptr) { - *n_removed = 0; - } +llama_token_healing_output llama_token_healing_rollback( + const llama_context * ctx_main, + llama_token_healing_type th_type, + std::vector & tokens, + int max_to_remove) { // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back. // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt. - token_healing_info info = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove); + llama_token_healing_output out = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove); // If constrained decoding would give back the original prompt, there is no need to modify the prompt. const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI || th_type == llama_token_healing_type::DYNAMIC_MULTI; - const std::vector candidates = token_healing_get_candidates(ctx_main, info.prefix, is_multi_step); - LOG("token_healing: prefix = '%s' (%d tokens)\n", info.prefix.c_str(), info.n_tokens_removed); - if (info.n_tokens_removed == 1 && candidates.size() == 1) { + const std::vector candidates = token_healing_get_candidates(ctx_main, out.prefix, is_multi_step); + LOG("token_healing: prefix = '%s' (%d tokens)\n", out.prefix.c_str(), out.n_tokens_removed); + if (out.n_tokens_removed == 1 && candidates.size() == 1) { LOG("token_healing: nothing to heal\n"); - return ""; + return {}; } - // Finalize outputs - if (n_removed != nullptr) { - *n_removed = info.n_tokens_removed; - } - tokens.resize(tokens.size() - info.n_tokens_removed); - return info.prefix; + + // Finally, trim prompt tokens + tokens.resize(tokens.size() - out.n_tokens_removed); + return out; } void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) { diff --git a/common/sampling.h b/common/sampling.h index 4c1172985b364..094b40c8912c5 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -176,13 +176,17 @@ void llama_sampling_accept( // Token healing // -// Roll back `tokens` for constrained generation according to the token healing -// strategy. Returns the prefix for constrained generation. -std::string llama_token_healing_rollback( - const llama_context * ctx_main, - llama_token_healing_type th_type, - std::vector & tokens, - int max_to_remove = -1, - int * n_removed = nullptr); +struct llama_token_healing_output { + std::string prefix; + int n_tokens_removed; +}; + +// Roll back `tokens` for constrained generation according to the token healing strategy. +// Call `llama_token_healing_set_prefix` with the returned prefix before the first sampling. 
+llama_token_healing_output llama_token_healing_rollback( + const llama_context * ctx_main, + llama_token_healing_type th_type, + std::vector & tokens, + int max_to_remove = -1); void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b08fec7dcfdd1..6976b269773fa 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -295,11 +295,10 @@ int main(int argc, char ** argv) { sparams.token_healing_enabled = false; LOG("token healing: disabled due to custom suffix/conversation mode"); } - std::string token_healing_prefix; - int token_healing_n_removed = 0; + llama_token_healing_output token_healing_out{}; if (!params.interactive_first && sparams.token_healing_enabled) { - token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, - sparams.token_healing_n_rollback, &token_healing_n_removed); + token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, + sparams.token_healing_n_rollback); } // Should not run without any tokens @@ -326,7 +325,7 @@ int main(int argc, char ** argv) { std::vector original_inp = ::llama_tokenize(ctx, params.prompt, true, true); LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str()); - original_prompt_len = original_inp.size() - token_healing_n_removed; + original_prompt_len = original_inp.size() - token_healing_out.n_tokens_removed; guidance_offset = (int)guidance_inp.size() - original_prompt_len; LOG("original_prompt_len: %s", log_tostr(original_prompt_len)); LOG("guidance_offset: %s", log_tostr(guidance_offset)); @@ -548,7 +547,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); } - llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix); + llama_token_healing_set_prefix(ctx_sampling, token_healing_out.prefix); if (llama_model_has_encoder(model)) { int enc_input_size = embd_inp.size(); @@ -883,7 +882,8 @@ int main(int argc, char ** argv) { assistant_ss << llama_token_to_piece(ctx, id, false); } - token_healing_n_removed = 0; + token_healing_out = {}; + if (n_past > 0 && is_interacting) { LOG("waiting for user input\n"); @@ -962,9 +962,8 @@ int main(int argc, char ** argv) { const int max_to_remove = sparams.token_healing_n_rollback < 0 ? 
n_new_tokens : std::min(sparams.token_healing_n_rollback, n_new_tokens); - token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, - max_to_remove, &token_healing_n_removed); - n_bytes_to_skip = token_healing_prefix.size(); + token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, max_to_remove); + n_bytes_to_skip = token_healing_out.prefix.size(); } for (size_t i = original_size; i < embd_inp.size(); ++i) { @@ -976,7 +975,7 @@ int main(int argc, char ** argv) { // reset assistant message assistant_ss.str(""); - n_remain -= line_inp.size() + token_healing_n_removed; + n_remain -= line_inp.size() + token_healing_out.n_tokens_removed; LOG("n_remain: %d\n", n_remain); } else { LOG("empty line, passing control back\n"); @@ -988,9 +987,9 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { llama_sampling_reset(ctx_sampling); - if (token_healing_n_removed > 0) { + if (token_healing_out.n_tokens_removed > 0) { // Set new prefix after an interaction - llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix); + llama_token_healing_set_prefix(ctx_sampling, token_healing_out.prefix); } } is_interacting = false; From fc8773d3096fdbd1266c0e62d7136289fbae631f Mon Sep 17 00:00:00 2001 From: mare5x Date: Sun, 30 Jun 2024 20:14:18 +0200 Subject: [PATCH 04/10] token healing : handle more special tokens Infill tokens were being rolled back in certain cases. --- common/sampling.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index bdcdde057f1d1..b5c6b9ad3a7e5 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -67,7 +67,16 @@ static llama_token_healing_output llama_token_healing_get_prefix( const llama_model * model = llama_get_model(ctx_main); auto is_special_token = [&](const llama_token token_id) { - return llama_token_is_control(model, token_id) || llama_token_is_eog(model, token_id); + return llama_token_is_control(model, token_id) + || llama_token_bos (model) == token_id + || llama_token_eos (model) == token_id + || llama_token_cls (model) == token_id + || llama_token_sep (model) == token_id + || llama_token_pad (model) == token_id + || llama_token_prefix (model) == token_id + || llama_token_middle (model) == token_id + || llama_token_suffix (model) == token_id + || llama_token_eot (model) == token_id; }; if (th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI) { From d5eea137977d57741eff9913e3304c5f10fcef90 Mon Sep 17 00:00:00 2001 From: mare5x Date: Wed, 26 Jun 2024 17:12:57 +0200 Subject: [PATCH 05/10] server : add token healing support --- examples/server/README.md | 2 + examples/server/server.cpp | 77 ++++++++++++++++++++++++++++++++++---- 2 files changed, 72 insertions(+), 7 deletions(-) diff --git a/examples/server/README.md b/examples/server/README.md index e17595fe87f25..f2cea4741cac0 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -436,6 +436,8 @@ node index.js `json_schema`: Set a JSON schema for grammar-based sampling (e.g. `{"items": {"type": "string"}, "minItems": 10, "maxItems": 100}` of a list of strings, or `{}` for any JSON). See [tests](../../tests/test-json-schema-to-grammar.cpp) for supported features. Default: no JSON schema. + `token_healing`: Set token healing strategy. Default: `0`, which is disabled. + `seed`: Set the random number generator (RNG) seed. Default: `-1`, which is a random seed. 
`ignore_eos`: Ignore end of stream token and continue generating. Default: `false` diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 360f571e42867..0d556ac246404 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -185,6 +185,7 @@ struct server_slot { // stats size_t n_sent_text = 0; // number of sent text character size_t n_sent_token_probs = 0; + size_t n_th_prefix = 0; // size of remaining token healing prefix int64_t t_start_process_prompt; int64_t t_start_generation; @@ -206,6 +207,7 @@ struct server_slot { infill = false; ga_i = 0; n_past_se = 0; + n_th_prefix = 0; generated_token_probs.clear(); } @@ -1094,6 +1096,36 @@ struct server_context { } } + { + const auto & token_healing_str = data.find("token_healing"); + auto & th_enabled = slot.sparams.token_healing_enabled; + th_enabled = default_sparams.token_healing_enabled; + if (token_healing_str != data.end() && token_healing_str->is_string()) { + const auto value = token_healing_str->get(); + auto & th_type = slot.sparams.token_healing_type; + auto & th_n_rollback = slot.sparams.token_healing_n_rollback; + th_enabled = true; + /**/ if (value == "0" ) { th_enabled = false; } + else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; } + else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; } + else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; } + else if (value[0] == 'r' ) { + th_type = llama_token_healing_type::ROLLBACK_MULTI; + th_n_rollback = std::stoi(value.substr(1)); + if (th_n_rollback <= 0) { + th_enabled = false; + } + } else { th_enabled = false; } + + LOG_VERBOSE("token healing", { + {"id_slot", slot.id}, + {"enabled", th_enabled}, + {"type", th_type}, + {"n_rollback", th_n_rollback} + }); + } + } + { if (slot.ctx_sampling != nullptr) { llama_sampling_free(slot.ctx_sampling); @@ -1189,14 +1221,26 @@ struct server_context { } bool process_token(completion_token_output & result, server_slot & slot) { - // remember which tokens were sampled - used for repetition penalties during sampling const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special); slot.sampled = result.tok; - - // search stop word and delete it - slot.generated_text += token_str; slot.has_next_token = true; + // Suppress generating the token healing prefix to not repeat the input prompt's suffix + bool is_token_healing = false; + if (slot.n_th_prefix > 0) { + if (slot.n_th_prefix < token_str.size()) { + slot.generated_text += token_str.substr(slot.n_th_prefix); + slot.n_th_prefix = 0; + is_token_healing = false; // to send partial token text when streaming + } else { + slot.n_th_prefix -= token_str.size(); + is_token_healing = true; + } + } else { + slot.generated_text += token_str; + } + + // remember which tokens were sampled - used for repetition penalties during sampling if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) { // we can change penalty_prompt_tokens because it is always created from scratch each request slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok); @@ -1224,7 +1268,7 @@ struct server_context { break; } - if (!incomplete) { + if (!incomplete && !is_token_healing) { size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); const std::string str_test = slot.generated_text.substr(pos); @@ -1256,7 +1300,7 @@ struct server_context { } } - if (incomplete) { + if (incomplete || is_token_healing) { slot.has_next_token = true; } @@ -1361,7 +1405,8 @@ 
struct server_context { {"n_probs", slot.sparams.n_probs}, {"min_keep", slot.sparams.min_keep}, {"grammar", slot.sparams.grammar}, - {"samplers", samplers_sequence} + {"samplers", samplers_sequence}, + {"token_healing_enabled", slot.sparams.token_healing_enabled} }; } @@ -2106,6 +2151,21 @@ struct server_context { continue; } + // Roll back prompt tokens if token healing + llama_token_healing_output token_healing_out{}; + if (slot.sparams.token_healing_enabled) { + token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type, + prompt_tokens, slot.sparams.token_healing_n_rollback); + slot.n_th_prefix = token_healing_out.prefix.size(); + slot.n_prompt_tokens = prompt_tokens.size(); + LOG_VERBOSE("token healing prompt", { + {"id_slot", slot.id}, + {"id_task", slot.id_task}, + {"removed_suffix", token_healing_out.prefix}, + {"n_tokens_removed", token_healing_out.n_tokens_removed} + }); + } + if (slot.embedding) { // this prompt is too large to process - discard it if (slot.n_prompt_tokens > n_ubatch) { @@ -2156,6 +2216,9 @@ struct server_context { } llama_sampling_reset(slot.ctx_sampling); + if (slot.sparams.token_healing_enabled) { + llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix); + } if (!slot.params.cache_prompt) { slot.n_past_se = 0; From 3ba5c55bc47c97136fab135f07526d8bbd6be667 Mon Sep 17 00:00:00 2001 From: mare5x Date: Sun, 30 Jun 2024 22:30:15 +0200 Subject: [PATCH 06/10] server : token healing for infilling/FIM --- examples/server/server.cpp | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 0d556ac246404..01aac8ed53427 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2088,6 +2088,8 @@ struct server_context { slot.t_start_process_prompt = ggml_time_us(); slot.t_start_generation = 0; + llama_token_healing_output token_healing_out{}; + if (slot.infill) { const bool add_bos = llama_should_add_bos_token(model); bool suff_rm_leading_spc = true; @@ -2107,6 +2109,12 @@ struct server_context { prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model)); suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model)); + if (slot.sparams.token_healing_enabled) { + // For FIM roll back only the prefix part (i.e. cursor location) + token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type, + prefix_tokens, slot.sparams.token_healing_n_rollback); + } + auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens; auto embd_end = params.spm_infill ? 
prefix_tokens : suffix_tokens; if (add_bos) { @@ -2122,6 +2130,11 @@ struct server_context { prompt_tokens = embd_inp; } else { prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt + + if (slot.sparams.token_healing_enabled) { + token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type, + prompt_tokens, slot.sparams.token_healing_n_rollback); + } } slot.n_past = 0; @@ -2136,6 +2149,16 @@ struct server_context { {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())}, }); + if (slot.sparams.token_healing_enabled) { + slot.n_th_prefix = token_healing_out.prefix.size(); + LOG_VERBOSE("token healing prompt", { + {"id_slot", slot.id}, + {"id_task", slot.id_task}, + {"removed_suffix", token_healing_out.prefix}, + {"n_tokens_removed", token_healing_out.n_tokens_removed} + }); + } + // empty prompt passed -> release the slot and send empty response if (prompt_tokens.empty()) { LOG_INFO("empty prompt - releasing slot", { @@ -2151,21 +2174,6 @@ struct server_context { continue; } - // Roll back prompt tokens if token healing - llama_token_healing_output token_healing_out{}; - if (slot.sparams.token_healing_enabled) { - token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type, - prompt_tokens, slot.sparams.token_healing_n_rollback); - slot.n_th_prefix = token_healing_out.prefix.size(); - slot.n_prompt_tokens = prompt_tokens.size(); - LOG_VERBOSE("token healing prompt", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"removed_suffix", token_healing_out.prefix}, - {"n_tokens_removed", token_healing_out.n_tokens_removed} - }); - } - if (slot.embedding) { // this prompt is too large to process - discard it if (slot.n_prompt_tokens > n_ubatch) { From ea4abc9d8255d27dbcb3d87473579781d04a59d8 Mon Sep 17 00:00:00 2001 From: mare5x Date: Mon, 1 Jul 2024 11:51:39 +0200 Subject: [PATCH 07/10] token healing : refactor argument parsing Unify `main` and `server` token healing argument handling. 
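Both parsers now go through a single helper. A rough usage sketch (illustrative only, not
part of the diff; `value` stands for the user-supplied string and the declarations come
from common/sampling.h in this series):

    llama_token_healing_params th_params;
    if (!llama_token_healing_parse_params(value, th_params)) {
        // invalid setting, e.g. "r0" or an unrecognized string:
        // main flags the argument as invalid, server replies with ERROR_TYPE_INVALID_REQUEST
    }
    // e.g. value == "r3" -> enabled = true, type = ROLLBACK_MULTI, n_rollback = 3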
--- common/common.cpp | 15 +----------- common/sampling.cpp | 28 +++++++++++++++++++---- common/sampling.h | 13 ++++++++--- examples/main/main.cpp | 18 +++++++-------- examples/server/server.cpp | 47 +++++++++++++++----------------------- 5 files changed, 61 insertions(+), 60 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index e5eccc54e0b3b..141abaef8ca0b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1095,21 +1095,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } if (arg == "-th" || arg == "--token-healing") { CHECK_ARG - sparams.token_healing_enabled = true; - auto & th_type = sparams.token_healing_type; - auto & th_n_rollback = sparams.token_healing_n_rollback; std::string value(argv[i]); - /**/ if (value == "0" ) { sparams.token_healing_enabled = false; } - else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; } - else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; } - else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; } - else if (value[0] == 'r' ) { - th_type = llama_token_healing_type::ROLLBACK_MULTI; - th_n_rollback = std::stoi(value.substr(1)); - if (th_n_rollback <= 0) { - sparams.token_healing_enabled = false; - } - } else { invalid_param = true; } + invalid_param = !llama_token_healing_parse_params(value, sparams.token_healing); return true; } if (arg == "--override-kv") { diff --git a/common/sampling.cpp b/common/sampling.cpp index b5c6b9ad3a7e5..2d1610b39670e 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -154,6 +154,25 @@ void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const ctx_sampling->token_healing_prefix = prefix; } +bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params) { + th_params.enabled = true; + th_params.n_rollback = -1; + /**/ if (params == "0" ) { th_params.enabled = false; } + else if (params == "1" ) { th_params.type = llama_token_healing_type::ROLLBACK_LAST; } + else if (params == "d1") { th_params.type = llama_token_healing_type::DYNAMIC_ONCE; } + else if (params == "d" ) { th_params.type = llama_token_healing_type::DYNAMIC_MULTI; } + else if (params[0] == 'r' ) { + th_params.type = llama_token_healing_type::ROLLBACK_MULTI; + th_params.n_rollback = std::stoi(params.substr(1)); + if (th_params.n_rollback <= 0) { + return false; + } + } else { + return false; + } + return true; +} + // // Sampling // @@ -552,11 +571,10 @@ static llama_token_data_array llama_sampling_prepare_impl( cur.resize(n_vocab); // Constrain tokens based on the remaining token healing prefix (if any) - const auto & th_type = params.token_healing_type; const auto & th_prefix = ctx_sampling->token_healing_prefix; - if (params.token_healing_enabled && !th_prefix.empty()) { - const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI || - th_type == llama_token_healing_type::DYNAMIC_MULTI; + if (params.token_healing.enabled && !th_prefix.empty()) { + const bool is_multi_step = params.token_healing.type == llama_token_healing_type::ROLLBACK_MULTI || + params.token_healing.type == llama_token_healing_type::DYNAMIC_MULTI; std::vector th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step); LOG("token_healing: prefix = '%s'\n", th_prefix.c_str()); @@ -635,7 +653,7 @@ void llama_sampling_accept( llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id); } - if (ctx_sampling->params.token_healing_enabled && 
apply_grammar) { + if (ctx_sampling->params.token_healing.enabled && apply_grammar) { std::string & th_prefix = ctx_sampling->token_healing_prefix; if (!th_prefix.empty()) { const std::string new_token_piece = llama_token_to_piece(ctx_main, id); diff --git a/common/sampling.h b/common/sampling.h index 094b40c8912c5..a269ab11ea27a 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -26,6 +26,12 @@ enum class llama_token_healing_type : uint8_t { DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps }; +struct llama_token_healing_params { + bool enabled = false; + llama_token_healing_type type = llama_token_healing_type::DYNAMIC_MULTI; + int n_rollback = -1; // number of tokens to roll back +}; + // sampling parameters typedef struct llama_sampling_params { int32_t n_prev = 64; // number of previous tokens to remember @@ -70,9 +76,7 @@ typedef struct llama_sampling_params { std::vector penalty_prompt_tokens; bool use_penalty_prompt_tokens = false; - llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST; - bool token_healing_enabled = false; - int token_healing_n_rollback = -1; // number of tokens to roll back + llama_token_healing_params token_healing; } llama_sampling_params; // general sampler context @@ -190,3 +194,6 @@ llama_token_healing_output llama_token_healing_rollback( int max_to_remove = -1); void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix); + +// Helper for parsing token healing params from a string. +bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6976b269773fa..e8a0eefb9ce9a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -291,14 +291,14 @@ int main(int argc, char ** argv) { LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); } - if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) { - sparams.token_healing_enabled = false; + if (sparams.token_healing.enabled && (params.conversation || !params.input_suffix.empty())) { + sparams.token_healing.enabled = false; LOG("token healing: disabled due to custom suffix/conversation mode"); } llama_token_healing_output token_healing_out{}; - if (!params.interactive_first && sparams.token_healing_enabled) { - token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, - sparams.token_healing_n_rollback); + if (!params.interactive_first && sparams.token_healing.enabled) { + token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp, + sparams.token_healing.n_rollback); } // Should not run without any tokens @@ -956,13 +956,13 @@ int main(int argc, char ** argv) { embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end()); - if (sparams.token_healing_enabled) { + if (sparams.token_healing.enabled) { // Limit token healing rollback to new tokens only (otherwise would need to shift everything) const int n_new_tokens = embd_inp.size() - original_size; - const int max_to_remove = sparams.token_healing_n_rollback < 0 + const int max_to_remove = sparams.token_healing.n_rollback < 0 ? 
n_new_tokens - : std::min(sparams.token_healing_n_rollback, n_new_tokens); - token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, max_to_remove); + : std::min(sparams.token_healing.n_rollback, n_new_tokens); + token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp, max_to_remove); n_bytes_to_skip = token_healing_out.prefix.size(); } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 01aac8ed53427..ef2d7fa218c78 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1098,31 +1098,20 @@ struct server_context { { const auto & token_healing_str = data.find("token_healing"); - auto & th_enabled = slot.sparams.token_healing_enabled; - th_enabled = default_sparams.token_healing_enabled; if (token_healing_str != data.end() && token_healing_str->is_string()) { const auto value = token_healing_str->get(); - auto & th_type = slot.sparams.token_healing_type; - auto & th_n_rollback = slot.sparams.token_healing_n_rollback; - th_enabled = true; - /**/ if (value == "0" ) { th_enabled = false; } - else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; } - else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; } - else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; } - else if (value[0] == 'r' ) { - th_type = llama_token_healing_type::ROLLBACK_MULTI; - th_n_rollback = std::stoi(value.substr(1)); - if (th_n_rollback <= 0) { - th_enabled = false; - } - } else { th_enabled = false; } - + if (!llama_token_healing_parse_params(value, slot.sparams.token_healing)) { + send_error(task, "\"token_healing\" parse error", ERROR_TYPE_INVALID_REQUEST); + return false; + } LOG_VERBOSE("token healing", { {"id_slot", slot.id}, - {"enabled", th_enabled}, - {"type", th_type}, - {"n_rollback", th_n_rollback} + {"enabled", slot.sparams.token_healing.enabled}, + {"type", slot.sparams.token_healing.type}, + {"n_rollback", slot.sparams.token_healing.n_rollback} }); + } else { + slot.sparams.token_healing = default_sparams.token_healing; } } @@ -1406,7 +1395,7 @@ struct server_context { {"min_keep", slot.sparams.min_keep}, {"grammar", slot.sparams.grammar}, {"samplers", samplers_sequence}, - {"token_healing_enabled", slot.sparams.token_healing_enabled} + {"token_healing_enabled", slot.sparams.token_healing.enabled} }; } @@ -2109,10 +2098,10 @@ struct server_context { prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model)); suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model)); - if (slot.sparams.token_healing_enabled) { + if (slot.sparams.token_healing.enabled) { // For FIM roll back only the prefix part (i.e. cursor location) - token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type, - prefix_tokens, slot.sparams.token_healing_n_rollback); + token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type, + prefix_tokens, slot.sparams.token_healing.n_rollback); } auto embd_inp = params.spm_infill ? 
suffix_tokens : prefix_tokens; @@ -2131,9 +2120,9 @@ struct server_context { } else { prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt - if (slot.sparams.token_healing_enabled) { - token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type, - prompt_tokens, slot.sparams.token_healing_n_rollback); + if (slot.sparams.token_healing.enabled) { + token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type, + prompt_tokens, slot.sparams.token_healing.n_rollback); } } @@ -2149,7 +2138,7 @@ struct server_context { {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())}, }); - if (slot.sparams.token_healing_enabled) { + if (slot.sparams.token_healing.enabled) { slot.n_th_prefix = token_healing_out.prefix.size(); LOG_VERBOSE("token healing prompt", { {"id_slot", slot.id}, @@ -2224,7 +2213,7 @@ struct server_context { } llama_sampling_reset(slot.ctx_sampling); - if (slot.sparams.token_healing_enabled) { + if (slot.sparams.token_healing.enabled) { llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix); } From b3173681918c3ddd4611a955145bc96513ec7b63 Mon Sep 17 00:00:00 2001 From: mare5x Date: Mon, 1 Jul 2024 12:23:21 +0200 Subject: [PATCH 08/10] token healing : change argument order --- common/sampling.cpp | 2 +- common/sampling.h | 2 +- examples/main/main.cpp | 6 +++--- examples/server/server.cpp | 8 ++++---- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 2d1610b39670e..e9f828befbbe6 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -128,8 +128,8 @@ static llama_token_healing_output llama_token_healing_get_prefix( llama_token_healing_output llama_token_healing_rollback( const llama_context * ctx_main, - llama_token_healing_type th_type, std::vector & tokens, + llama_token_healing_type th_type, int max_to_remove) { // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back. // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt. diff --git a/common/sampling.h b/common/sampling.h index a269ab11ea27a..257c2aaeb59fb 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -189,8 +189,8 @@ struct llama_token_healing_output { // Call `llama_token_healing_set_prefix` with the returned prefix before the first sampling. llama_token_healing_output llama_token_healing_rollback( const llama_context * ctx_main, - llama_token_healing_type th_type, std::vector & tokens, + llama_token_healing_type th_type, int max_to_remove = -1); void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index e8a0eefb9ce9a..07144a7cb9fc8 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -297,8 +297,8 @@ int main(int argc, char ** argv) { } llama_token_healing_output token_healing_out{}; if (!params.interactive_first && sparams.token_healing.enabled) { - token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp, - sparams.token_healing.n_rollback); + token_healing_out = llama_token_healing_rollback(ctx, embd_inp, + sparams.token_healing.type, sparams.token_healing.n_rollback); } // Should not run without any tokens @@ -962,7 +962,7 @@ int main(int argc, char ** argv) { const int max_to_remove = sparams.token_healing.n_rollback < 0 ? 
n_new_tokens : std::min(sparams.token_healing.n_rollback, n_new_tokens); - token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp, max_to_remove); + token_healing_out = llama_token_healing_rollback(ctx, embd_inp, sparams.token_healing.type, max_to_remove); n_bytes_to_skip = token_healing_out.prefix.size(); } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index ef2d7fa218c78..a564e32abd9df 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2100,8 +2100,8 @@ struct server_context { if (slot.sparams.token_healing.enabled) { // For FIM roll back only the prefix part (i.e. cursor location) - token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type, - prefix_tokens, slot.sparams.token_healing.n_rollback); + token_healing_out = llama_token_healing_rollback(ctx, prefix_tokens, + slot.sparams.token_healing.type, slot.sparams.token_healing.n_rollback); } auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens; @@ -2121,8 +2121,8 @@ struct server_context { prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt if (slot.sparams.token_healing.enabled) { - token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type, - prompt_tokens, slot.sparams.token_healing.n_rollback); + token_healing_out = llama_token_healing_rollback(ctx, prompt_tokens, + slot.sparams.token_healing.type, slot.sparams.token_healing.n_rollback); } } From 940ab817849a170051784469a2875ab7b2bc239b Mon Sep 17 00:00:00 2001 From: mare5x Date: Mon, 8 Jul 2024 15:53:23 +0200 Subject: [PATCH 09/10] readme : list possible token healing values --- examples/server/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/README.md b/examples/server/README.md index f2cea4741cac0..208fddae50373 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -436,7 +436,7 @@ node index.js `json_schema`: Set a JSON schema for grammar-based sampling (e.g. `{"items": {"type": "string"}, "minItems": 10, "maxItems": 100}` of a list of strings, or `{}` for any JSON). See [tests](../../tests/test-json-schema-to-grammar.cpp) for supported features. Default: no JSON schema. - `token_healing`: Set token healing strategy. Default: `0`, which is disabled. + `token_healing`: Set the token healing strategy. Default: `0`, which is disabled. Possible values: `1` to replace one token, `d1` to replace the longest suffix with a single token, `d` to replace the longest suffix, `rN` to roll back N tokens (e.g. `r3`). See [here](../main/README.md#token-healing) for more details. `seed`: Set the random number generator (RNG) seed. Default: `-1`, which is a random seed. 
From b27f87d6da0c95a41ea5b0ec7ecd4aa4bc82f95f Mon Sep 17 00:00:00 2001 From: mare5x Date: Mon, 8 Jul 2024 16:18:19 +0200 Subject: [PATCH 10/10] token healing : fix rebase bug --- common/sampling.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index e9f828befbbe6..a999b908c9f2a 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -568,8 +568,6 @@ static llama_token_data_array llama_sampling_prepare_impl( llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale); } - cur.resize(n_vocab); - // Constrain tokens based on the remaining token healing prefix (if any) const auto & th_prefix = ctx_sampling->token_healing_prefix; if (params.token_healing.enabled && !th_prefix.empty()) { @@ -583,10 +581,12 @@ static llama_token_data_array llama_sampling_prepare_impl( } // N.B. We could also set token constraints by setting rejected tokens' logits to -inf + cur.clear(); for (const llama_token token_id : th_candidates) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); } } else { + cur.resize(n_vocab); for (llama_token token_id = 0; token_id < n_vocab; token_id++) { cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; }
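Taken together, the series exposes a small token healing API in common/sampling.h. A
minimal sketch of the intended call flow, mirroring the examples/main changes above
(illustrative only: `ctx`, `sparams` and `prompt` are assumed to exist, and prompt
decoding/error handling are omitted):

    std::vector<llama_token> embd_inp = ::llama_tokenize(ctx, prompt, true, true);

    llama_token_healing_output th_out{};
    if (sparams.token_healing.enabled) {
        // trims the rolled-back tokens off embd_inp and returns the removed suffix
        th_out = llama_token_healing_rollback(ctx, embd_inp, sparams.token_healing.type,
                                              sparams.token_healing.n_rollback);
    }

    llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
    llama_token_healing_set_prefix(ctx_sampling, th_out.prefix);

    // ... decode embd_inp, then sample as usual; the first step(s) are constrained to
    // tokens that can regenerate th_out.prefix, and llama_sampling_accept() shrinks the
    // remaining prefix until it has been fully covered.
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
    llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar = */ true);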