diff --git a/common/common.cpp b/common/common.cpp
index 560e20d080d0f..141abaef8ca0b 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1093,6 +1093,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
         return true;
     }
+    if (arg == "-th" || arg == "--token-healing") {
+        CHECK_ARG
+        std::string value(argv[i]);
+        invalid_param = !llama_token_healing_parse_params(value, sparams.token_healing);
+        return true;
+    }
     if (arg == "--override-kv") {
         CHECK_ARG
         if (!string_parse_kv_override(argv[i], params.kv_overrides)) {
@@ -1501,6 +1507,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
                                                             "if suffix/prefix are specified, template will be disabled\n"
                                                             "only commonly used templates are accepted:\n"
                                                             "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
+
+    options.push_back({ "main",        "-th,  --token-healing {0,1,d1,d,r{N}}",
+                                                            "Token healing type. (default: 0, disabled)\n"
+                                                            "1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens" });
     options.push_back({ "grammar" });
     options.push_back({ "*",           "       --grammar GRAMMAR",      "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
     options.push_back({ "*",           "       --grammar-file FNAME",   "file to read grammar from" });
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 079e405168dff..a999b908c9f2a 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -2,6 +2,181 @@
 #include "sampling.h"
 #include <random>
 
+//
+// Token healing (internal)
+//
+
+static bool startswith(const std::string & str, const std::string & prefix) {
+    return str.rfind(prefix, 0) != std::string::npos;
+}
+
+static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        if (startswith(token, prefix)) {
+            return true;
+        }
+    }
+    return false;
+}
+
+static std::vector<llama_token> token_healing_get_candidates(
+        const llama_context * ctx_main,
+        const std::string   & prefix,
+        const bool            include_partial_prefix) {
+    // Example: prefix=" world" -> " world", " worldwide", ...
+    // If `include_partial_prefix`, include also: " w", " wo", ...
+    std::vector<llama_token> candidates;
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        if (startswith(token, prefix) ||
+            (include_partial_prefix && startswith(prefix, token))) {
+            candidates.push_back(token_id);
+        }
+    }
+    return candidates;
+}
+
+static size_t get_max_token_length(const llama_context * ctx_main) {
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    size_t len = 0;
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        len = std::max(len, token.size());
+    }
+    return len;
+}
+
+static llama_token_healing_output llama_token_healing_get_prefix(
+        const llama_context * ctx_main,
+        const llama_token_healing_type th_type,
+        const std::vector<llama_token> & tokens,
+        int max_to_remove) {
+    if (tokens.size() <= 1) {
+        return {};
+    }
+
+    const int n_ctx = tokens.size();
+    max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
+    max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
+
+    int removed = 0;
+    std::string prefix;
+
+    const llama_model * model = llama_get_model(ctx_main);
+    auto is_special_token = [&](const llama_token token_id) {
+        return llama_token_is_control(model, token_id)
+            || llama_token_bos   (model) == token_id
+            || llama_token_eos   (model) == token_id
+            || llama_token_cls   (model) == token_id
+            || llama_token_sep   (model) == token_id
+            || llama_token_pad   (model) == token_id
+            || llama_token_prefix(model) == token_id
+            || llama_token_middle(model) == token_id
+            || llama_token_suffix(model) == token_id
+            || llama_token_eot   (model) == token_id;
+    };
+
+    if (th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI) {
+        // The number of bytes to roll back cannot exceed the length of the longest token.
+        const size_t n_longest_token = get_max_token_length(ctx_main);
+        size_t len = 0;
+        while (removed < max_to_remove) {
+            const llama_token next_token_id = tokens[n_ctx - removed - 1];
+            if (is_special_token(next_token_id)) {
+                break;
+            }
+            const size_t next_token_size = llama_token_to_piece(ctx_main, next_token_id).size();
+            if (len + next_token_size > n_longest_token) {
+                break;
+            }
+            len += next_token_size;
+            removed += 1;
+        }
+
+        while (removed > 0) {
+            prefix.clear();
+            for (int i = n_ctx - removed; i < n_ctx; i++) {
+                prefix += llama_token_to_piece(ctx_main, tokens[i]);
+            }
+            if (token_healing_prefix_exists(ctx_main, prefix)) {
+                break; // Stop on longest valid prefix
+            }
+            removed -= 1;
+        }
+    } else {
+        // Roll back tokens a fixed amount and stop early if a special token is encountered.
+        while (removed < max_to_remove) {
+            const llama_token next_token_id = tokens[n_ctx - removed - 1];
+            if (is_special_token(next_token_id)) {
+                break;
+            }
+            removed += 1;
+        }
+        for (int i = n_ctx - removed; i < n_ctx; i++) {
+            prefix += llama_token_to_piece(ctx_main, tokens[i]);
+        }
+    }
+    return {prefix, removed};
+}
+
+//
+// Token healing (external)
+//
+
+llama_token_healing_output llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        std::vector<llama_token> & tokens,
+        llama_token_healing_type th_type,
+        int max_to_remove) {
+    // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
+    // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
+    llama_token_healing_output out = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove);
+
+    // If constrained decoding would give back the original prompt, there is no need to modify the prompt.
+    const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                               th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const std::vector<llama_token> candidates = token_healing_get_candidates(ctx_main, out.prefix, is_multi_step);
+    LOG("token_healing: prefix = '%s' (%d tokens)\n", out.prefix.c_str(), out.n_tokens_removed);
+    if (out.n_tokens_removed == 1 && candidates.size() == 1) {
+        LOG("token_healing: nothing to heal\n");
+        return {};
+    }
+
+    // Finally, trim prompt tokens
+    tokens.resize(tokens.size() - out.n_tokens_removed);
+    return out;
+}
+
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
+    ctx_sampling->token_healing_prefix = prefix;
+}
+
+bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params) {
+    th_params.enabled = true;
+    th_params.n_rollback = -1;
+    /**/ if (params == "0" ) { th_params.enabled = false; }
+    else if (params == "1" ) { th_params.type = llama_token_healing_type::ROLLBACK_LAST; }
+    else if (params == "d1") { th_params.type = llama_token_healing_type::DYNAMIC_ONCE; }
+    else if (params == "d" ) { th_params.type = llama_token_healing_type::DYNAMIC_MULTI; }
+    else if (params[0] == 'r' ) {
+        th_params.type = llama_token_healing_type::ROLLBACK_MULTI;
+        th_params.n_rollback = std::stoi(params.substr(1));
+        if (th_params.n_rollback <= 0) {
+            return false;
+        }
+    } else {
+        return false;
+    }
+    return true;
+}
+
+//
+// Sampling
+//
+
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
@@ -72,6 +247,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
         ctx->grammar = grammar;
     }
 
+    ctx->token_healing_prefix.clear();
+
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
     ctx->n_valid = 0;
@@ -130,7 +307,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
 }
 
 std::string llama_sampling_order_print(const llama_sampling_params & params) {
-    std::string result = "CFG -> Penalties ";
+    std::string result = "(Token healing) -> CFG -> Penalties ";
     if (params.mirostat == 0) {
         for (auto sampler_type : params.samplers_sequence) {
             const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
@@ -391,10 +568,28 @@ static llama_token_data_array llama_sampling_prepare_impl(
         llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
     }
 
-    cur.resize(n_vocab);
+    // Constrain tokens based on the remaining token healing prefix (if any)
+    const auto & th_prefix = ctx_sampling->token_healing_prefix;
+    if (params.token_healing.enabled && !th_prefix.empty()) {
+        const bool is_multi_step = params.token_healing.type == llama_token_healing_type::ROLLBACK_MULTI ||
+                                   params.token_healing.type == llama_token_healing_type::DYNAMIC_MULTI;
+        std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);
+
+        LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
+        for (const llama_token token_id : th_candidates) {
+            LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
+        }
 
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+        // N.B. We could also set token constraints by setting rejected tokens' logits to -inf
+        cur.clear();
+        for (const llama_token token_id : th_candidates) {
+            cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        }
+    } else {
+        cur.resize(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+        }
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
@@ -457,4 +652,19 @@ void llama_sampling_accept(
     if (ctx_sampling->grammar != NULL && apply_grammar) {
         llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
     }
+
+    if (ctx_sampling->params.token_healing.enabled && apply_grammar) {
+        std::string & th_prefix = ctx_sampling->token_healing_prefix;
+        if (!th_prefix.empty()) {
+            const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
+            if (new_token_piece.size() < th_prefix.size()) {
+                // Shift prefix constraint (for multi step token healing)
+                th_prefix = th_prefix.substr(new_token_piece.size());
+            } else {
+                // Prefix has been generated => no more constrained generation
+                th_prefix.clear();
+                LOG("token_healing: done\n");
+            }
+        }
+    }
 }
diff --git a/common/sampling.h b/common/sampling.h
index eeaa53b8bcd00..257c2aaeb59fb 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -19,6 +19,19 @@ enum class llama_sampler_type : char {
     TEMPERATURE = 't'
 };
 
+enum class llama_token_healing_type : uint8_t {
+    ROLLBACK_LAST,  // roll back last token with a single constrained decoding step
+    ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
+    DYNAMIC_ONCE,   // dynamic roll back, single constrained decoding step
+    DYNAMIC_MULTI   // dynamic roll back, multiple constrained decoding steps
+};
+
+struct llama_token_healing_params {
+    bool enabled = false;
+    llama_token_healing_type type = llama_token_healing_type::DYNAMIC_MULTI;
+    int n_rollback = -1; // number of tokens to roll back
+};
+
 // sampling parameters
 typedef struct llama_sampling_params {
     int32_t     n_prev                = 64;       // number of previous tokens to remember
@@ -62,6 +75,8 @@ typedef struct llama_sampling_params {
 
     std::vector<llama_token> penalty_prompt_tokens;
     bool                     use_penalty_prompt_tokens = false;
+
+    llama_token_healing_params token_healing;
 } llama_sampling_params;
 
 // general sampler context
@@ -78,6 +93,8 @@ struct llama_sampling_context {
     // internal
     grammar_parser::parse_state parsed_grammar;
 
+    std::string token_healing_prefix; // remaining prefix to constrain sampling
+
     // TODO: replace with ring-buffer
     std::vector<llama_token>      prev;
     std::vector<llama_token_data> cur;
@@ -158,3 +175,25 @@ void llama_sampling_accept(
         struct llama_context * ctx_main,
         llama_token id,
         bool apply_grammar);
+
+//
+// Token healing
+//
+
+struct llama_token_healing_output {
+    std::string prefix;
+    int n_tokens_removed;
+};
+
+// Roll back `tokens` for constrained generation according to the token healing strategy.
+// Call `llama_token_healing_set_prefix` with the returned prefix before the first sampling.
+llama_token_healing_output llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        std::vector<llama_token> & tokens,
+        llama_token_healing_type th_type,
+        int max_to_remove = -1);
+
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
+
+// Helper for parsing token healing params from a string.
+bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params);
diff --git a/examples/main/README.md b/examples/main/README.md
index 9396a34fa5a31..2dd2691ce1802 100644
--- a/examples/main/README.md
+++ b/examples/main/README.md
@@ -251,6 +251,19 @@ A more practical use case might be to prevent the generation of `\code{begin}` a
 
 Example usage: `--logit-bias 29905-inf`
 
+### Token healing
+
+- `-th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}`: Set the token healing strategy (default: 0, 0 = disabled).
+
+Token healing (a.k.a. token alignment) alleviates tokenization artifacts for text completion.
+
+- `-th 1`: Roll back the last token and constrain the next token to start with the bytes of the chopped-off token [0, 2].
+- `-th d1`: Roll back multiple tokens until no single token can cover the prompt's suffix, then do a single constrained decoding step [2].
+- `-th d`: Like `d1` but allow multiple decoding steps until the removed suffix is generated.
+- `-th r{N}`: Like `d` but roll back exactly `N` tokens; `-th r3` is recommended [1].
+
+Sources: [0](https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb), [1](https://arxiv.org/abs/2403.08688), [2](https://arxiv.org/abs/2402.01035).
+
 ### RNG Seed
 
 - `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 6e0635a66cd06..07144a7cb9fc8 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -291,6 +291,16 @@ int main(int argc, char ** argv) {
         LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
     }
 
+    if (sparams.token_healing.enabled && (params.conversation || !params.input_suffix.empty())) {
+        sparams.token_healing.enabled = false;
+        LOG("token healing: disabled due to custom suffix/conversation mode\n");
+    }
+    llama_token_healing_output token_healing_out{};
+    if (!params.interactive_first && sparams.token_healing.enabled) {
+        token_healing_out = llama_token_healing_rollback(ctx, embd_inp,
+                sparams.token_healing.type, sparams.token_healing.n_rollback);
+    }
+
     // Should not run without any tokens
     if (embd_inp.empty()) {
         if (add_bos) {
@@ -315,7 +325,7 @@ int main(int argc, char ** argv) {
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
-        original_prompt_len = original_inp.size();
+        original_prompt_len = original_inp.size() - token_healing_out.n_tokens_removed;
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
         LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
         LOG("guidance_offset: %s", log_tostr(guidance_offset));
@@ -510,6 +520,7 @@ int main(int argc, char ** argv) {
     int n_consumed         = 0;
     int n_session_consumed = 0;
    int n_past_guidance    = 0;
+    int n_bytes_to_skip    = 0; // to skip printing when generating token healing prefix
 
    std::vector<int>   input_tokens;  g_input_tokens  = &input_tokens;
    std::vector<int>   output_tokens; g_output_tokens = &output_tokens;
@@ -536,6 +547,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
     }
+    llama_token_healing_set_prefix(ctx_sampling, token_healing_out.prefix);
 
     if (llama_model_has_encoder(model)) {
         int enc_input_size = embd_inp.size();
@@ -770,7 +782,15 @@ int main(int argc, char ** argv) {
             const std::string token_str = llama_token_to_piece(ctx, id, params.special);
 
             // Console/Stream Output
-            fprintf(stdout, "%s", token_str.c_str());
+            // Suppress printing while generating token healing prefix
+            if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int)token_str.size()) {
+                fprintf(stdout, "%s", token_str.substr(n_bytes_to_skip).c_str());
+                n_bytes_to_skip = 0;
+            } else if (n_bytes_to_skip > 0) {
+                n_bytes_to_skip -= token_str.size();
+            } else {
+                fprintf(stdout, "%s", token_str.c_str());
+            }
 
             // Record Displayed Tokens To Log
             // Note: Generated tokens are created one by one hence this check
@@ -862,6 +882,8 @@ int main(int argc, char ** argv) {
             assistant_ss << llama_token_to_piece(ctx, id, false);
         }
 
+        token_healing_out = {};
+
         if (n_past > 0 && is_interacting) {
             LOG("waiting for user input\n");
 
@@ -934,6 +956,16 @@ int main(int argc, char ** argv) {
                     embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
                     embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
 
+                    if (sparams.token_healing.enabled) {
+                        // Limit token healing rollback to new tokens only (otherwise would need to shift everything)
+                        const int n_new_tokens = embd_inp.size() - original_size;
+                        const int max_to_remove = sparams.token_healing.n_rollback < 0
+                            ? n_new_tokens
+                            : std::min(sparams.token_healing.n_rollback, n_new_tokens);
+                        token_healing_out = llama_token_healing_rollback(ctx, embd_inp, sparams.token_healing.type, max_to_remove);
+                        n_bytes_to_skip = token_healing_out.prefix.size();
+                    }
+
                     for (size_t i = original_size; i < embd_inp.size(); ++i) {
                         const llama_token token = embd_inp[i];
                         output_tokens.push_back(token);
@@ -943,7 +975,7 @@ int main(int argc, char ** argv) {
                     // reset assistant message
                     assistant_ss.str("");
 
-                    n_remain -= line_inp.size();
+                    n_remain -= line_inp.size() + token_healing_out.n_tokens_removed;
                     LOG("n_remain: %d\n", n_remain);
                 } else {
                     LOG("empty line, passing control back\n");
@@ -955,6 +987,10 @@ int main(int argc, char ** argv) {
             if (n_past > 0) {
                 if (is_interacting) {
                     llama_sampling_reset(ctx_sampling);
+                    if (token_healing_out.n_tokens_removed > 0) {
+                        // Set new prefix after an interaction
+                        llama_token_healing_set_prefix(ctx_sampling, token_healing_out.prefix);
+                    }
                 }
                 is_interacting = false;
             }
diff --git a/examples/server/README.md b/examples/server/README.md
index e17595fe87f25..208fddae50373 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -436,6 +436,8 @@ node index.js
 
     `json_schema`: Set a JSON schema for grammar-based sampling (e.g. `{"items": {"type": "string"}, "minItems": 10, "maxItems": 100}` of a list of strings, or `{}` for any JSON). See [tests](../../tests/test-json-schema-to-grammar.cpp) for supported features. Default: no JSON schema.
 
+    `token_healing`: Set the token healing strategy. Default: `0`, which is disabled. Possible values: `1` to replace one token, `d1` to replace the longest suffix with a single token, `d` to replace the longest suffix, `rN` to roll back N tokens (e.g. `r3`). See [here](../main/README.md#token-healing) for more details.
+
     `seed`: Set the random number generator (RNG) seed. Default: `-1`, which is a random seed.
 
     `ignore_eos`: Ignore end of stream token and continue generating. Default: `false`
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 360f571e42867..a564e32abd9df 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -185,6 +185,7 @@ struct server_slot {
     // stats
     size_t n_sent_text = 0; // number of sent text character
     size_t n_sent_token_probs = 0;
+    size_t n_th_prefix = 0; // size of remaining token healing prefix
 
     int64_t t_start_process_prompt;
     int64_t t_start_generation;
@@ -206,6 +207,7 @@ struct server_slot {
         infill         = false;
         ga_i           = 0;
         n_past_se      = 0;
+        n_th_prefix    = 0;
 
         generated_token_probs.clear();
     }
@@ -1094,6 +1096,25 @@ struct server_context {
             }
         }
 
+        {
+            const auto & token_healing_str = data.find("token_healing");
+            if (token_healing_str != data.end() && token_healing_str->is_string()) {
+                const auto value = token_healing_str->get<std::string>();
+                if (!llama_token_healing_parse_params(value, slot.sparams.token_healing)) {
+                    send_error(task, "\"token_healing\" parse error", ERROR_TYPE_INVALID_REQUEST);
+                    return false;
+                }
+                LOG_VERBOSE("token healing", {
+                    {"id_slot",    slot.id},
+                    {"enabled",    slot.sparams.token_healing.enabled},
+                    {"type",       slot.sparams.token_healing.type},
+                    {"n_rollback", slot.sparams.token_healing.n_rollback}
+                });
+            } else {
+                slot.sparams.token_healing = default_sparams.token_healing;
+            }
+        }
+
         {
             if (slot.ctx_sampling != nullptr) {
                 llama_sampling_free(slot.ctx_sampling);
@@ -1189,14 +1210,26 @@ struct server_context {
     }
 
     bool process_token(completion_token_output & result, server_slot & slot) {
-        // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
         slot.sampled = result.tok;
-
-        // search stop word and delete it
-        slot.generated_text += token_str;
         slot.has_next_token = true;
 
+        // Suppress generating the token healing prefix to not repeat the input prompt's suffix
+        bool is_token_healing = false;
+        if (slot.n_th_prefix > 0) {
+            if (slot.n_th_prefix < token_str.size()) {
+                slot.generated_text += token_str.substr(slot.n_th_prefix);
+                slot.n_th_prefix = 0;
+                is_token_healing = false; // to send partial token text when streaming
+            } else {
+                slot.n_th_prefix -= token_str.size();
+                is_token_healing = true;
+            }
+        } else {
+            slot.generated_text += token_str;
+        }
+
+        // remember which tokens were sampled - used for repetition penalties during sampling
        if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
            // we can change penalty_prompt_tokens because it is always created from scratch each request
            slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
@@ -1224,7 +1257,7 @@ struct server_context {
                 break;
             }
 
-        if (!incomplete) {
+        if (!incomplete && !is_token_healing) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
             const std::string str_test = slot.generated_text.substr(pos);
 
@@ -1256,7 +1289,7 @@ struct server_context {
             }
         }
 
-        if (incomplete) {
+        if (incomplete || is_token_healing) {
             slot.has_next_token = true;
         }
 
@@ -1361,7 +1394,8 @@ struct server_context {
             {"n_probs",                 slot.sparams.n_probs},
             {"min_keep",                slot.sparams.min_keep},
             {"grammar",                 slot.sparams.grammar},
-            {"samplers",                samplers_sequence}
+            {"samplers",                samplers_sequence},
+            {"token_healing_enabled",   slot.sparams.token_healing.enabled}
         };
     }
 
@@ -2043,6 +2077,8 @@ struct server_context {
                 slot.t_start_process_prompt = ggml_time_us();
                 slot.t_start_generation = 0;
 
+                llama_token_healing_output token_healing_out{};
+
                 if (slot.infill) {
                     const bool add_bos = llama_should_add_bos_token(model);
                    bool suff_rm_leading_spc = true;
@@ -2062,6 +2098,12 @@ struct server_context {
                     prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
                     suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
 
+                    if (slot.sparams.token_healing.enabled) {
+                        // For FIM roll back only the prefix part (i.e. cursor location)
+                        token_healing_out = llama_token_healing_rollback(ctx, prefix_tokens,
+                                slot.sparams.token_healing.type, slot.sparams.token_healing.n_rollback);
+                    }
+
                     auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
                     auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
                     if (add_bos) {
@@ -2077,6 +2119,11 @@ struct server_context {
                     prompt_tokens = embd_inp;
                 } else {
                     prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
+
+                    if (slot.sparams.token_healing.enabled) {
+                        token_healing_out = llama_token_healing_rollback(ctx, prompt_tokens,
+                                slot.sparams.token_healing.type, slot.sparams.token_healing.n_rollback);
+                    }
                 }
 
                 slot.n_past = 0;
@@ -2091,6 +2138,16 @@ struct server_context {
                     {"prompt_tokens",   tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
                 });
 
+                if (slot.sparams.token_healing.enabled) {
+                    slot.n_th_prefix = token_healing_out.prefix.size();
+                    LOG_VERBOSE("token healing prompt", {
+                        {"id_slot",          slot.id},
+                        {"id_task",          slot.id_task},
+                        {"removed_suffix",   token_healing_out.prefix},
+                        {"n_tokens_removed", token_healing_out.n_tokens_removed}
+                    });
+                }
+
                 // empty prompt passed -> release the slot and send empty response
                 if (prompt_tokens.empty()) {
                     LOG_INFO("empty prompt - releasing slot", {
@@ -2156,6 +2213,9 @@ struct server_context {
                     }
 
                     llama_sampling_reset(slot.ctx_sampling);
+                    if (slot.sparams.token_healing.enabled) {
+                        llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
+                    }
 
                     if (!slot.params.cache_prompt) {
                         slot.n_past_se = 0;
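
Usage sketch (illustrative only, not part of the patch): the fragment below shows how a standalone caller could wire up the helpers declared in common/sampling.h above (llama_token_healing_parse_params, llama_token_healing_rollback, llama_token_healing_set_prefix). It assumes an already-initialized llama_context * ctx, a std::string prompt, and a llama_sampling_params sparams; model loading, prompt decoding, and the generation loop are elided. The actual integrations live in examples/main/main.cpp and examples/server/server.cpp as shown in the hunks above.

    // Sketch under the assumptions stated above; mirrors what examples/main/main.cpp does.
    llama_sampling_params sparams;
    if (!llama_token_healing_parse_params("d", sparams.token_healing)) {
        fprintf(stderr, "invalid token healing strategy\n");
        return 1;
    }

    // Tokenize the prompt, then trim its suffix according to the chosen strategy.
    std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, prompt, true, true);
    llama_token_healing_output th{};
    if (sparams.token_healing.enabled) {
        th = llama_token_healing_rollback(ctx, prompt_tokens,
                sparams.token_healing.type, sparams.token_healing.n_rollback);
    }

    // Tell the sampling context which bytes still have to be regenerated.
    llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
    llama_token_healing_set_prefix(ctx_sampling, th.prefix);

    // ... decode prompt_tokens and sample as usual: while the prefix is non-empty,
    // llama_sampling_prepare restricts candidates to tokens matching it, and
    // llama_sampling_accept shrinks the remaining prefix as tokens are accepted.

On the command line the same behaviour corresponds to passing `-th d` to the main example, or `"token_healing": "d"` in a server completion request, as documented in the README hunks above.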