From c77bb3203c0fbda0e4021226f4add0c852d18398 Mon Sep 17 00:00:00 2001 From: mare5x Date: Tue, 30 Apr 2024 13:38:14 +0200 Subject: [PATCH 1/6] examples : add simple token healing example --- examples/CMakeLists.txt | 1 + examples/simple-token-healing/CMakeLists.txt | 11 + examples/simple-token-healing/README.md | 70 ++++ .../simple-token-healing-1.cpp | 232 ++++++++++++ .../simple-token-healing.cpp | 353 ++++++++++++++++++ 5 files changed, 667 insertions(+) create mode 100644 examples/simple-token-healing/CMakeLists.txt create mode 100644 examples/simple-token-healing/README.md create mode 100644 examples/simple-token-healing/simple-token-healing-1.cpp create mode 100644 examples/simple-token-healing/simple-token-healing.cpp diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index f421769cc2f0a..11cbd8f61dc07 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -38,6 +38,7 @@ else() add_subdirectory(retrieval) add_subdirectory(save-load-state) add_subdirectory(simple) + add_subdirectory(simple-token-healing) add_subdirectory(passkey) add_subdirectory(speculative) add_subdirectory(lookahead) diff --git a/examples/simple-token-healing/CMakeLists.txt b/examples/simple-token-healing/CMakeLists.txt new file mode 100644 index 0000000000000..1d41611dd8e76 --- /dev/null +++ b/examples/simple-token-healing/CMakeLists.txt @@ -0,0 +1,11 @@ +set(TARGET simple-token-healing) +add_executable(${TARGET} simple-token-healing.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) + +set(TARGET simple-token-healing-1) +add_executable(${TARGET} simple-token-healing-1.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/simple-token-healing/README.md b/examples/simple-token-healing/README.md new file mode 100644 index 0000000000000..7e546986688de --- /dev/null +++ b/examples/simple-token-healing/README.md @@ -0,0 +1,70 @@ +# llama.cpp/example/simple-token-healing + +This example extends [simple](../simple/README.md) with [token healing](https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb). + +Without token healing: +```bash +./simple ./models/phi-2/ggml-model-q4_0.gguf "print('Hel" +... +main: n_len = 32, n_ctx = 2048, n_kv_req = 32 + +print('Helping the customer') +... +``` + +Heal the last token (`1`): +```bash +./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hel" 1 +... +token_healing: prefix = 'Hel' (1 tokens) + [ 12621] 'Hel' + [ 15496] 'Hello' + [ 22087] 'Help' + [ 28254] 'Hell' + [ 47429] 'Helper' + +main: n_len = 32, n_ctx = 2048, n_kv_req = 32 + +print('Hello, World!') +... +``` + +Backtrack multiple tokens until there doesn't exist a token which can cover the prompt's suffix (`n`): +```bash +./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" n +... +token_healing: prefix = ' worl' (2 tokens) + [ 995] ' world' + [ 8688] ' worldwide' + [ 11621] ' worlds' + [ 29081] ' worldview' + [ 43249] ' worldly' + +main: n_len = 32, n_ctx = 2048, n_kv_req = 32 + +print('Hello, world!') +... +``` + +Backtrack multiple tokens but don't constrain the decoding to a single token (`m`): +```bash +./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" m +... 
+token_healing: prefix = ' worl' (2 tokens) + +main: n_len = 32, n_ctx = 2048, n_kv_req = 32 + +print('Hello, +token_healing: prefix = ' worl' + [ 220] ' ' + [ 266] ' w' + [ 476] ' wor' + [ 995] ' world' + [ 8688] ' worldwide' + [ 11621] ' worlds' + [ 24486] ' wo' + [ 29081] ' worldview' + [ 43249] ' worldly' + world!') +... +``` diff --git a/examples/simple-token-healing/simple-token-healing-1.cpp b/examples/simple-token-healing/simple-token-healing-1.cpp new file mode 100644 index 0000000000000..6febeb38f6ff6 --- /dev/null +++ b/examples/simple-token-healing/simple-token-healing-1.cpp @@ -0,0 +1,232 @@ +#include "common.h" +#include "llama.h" + +#include +#include +#include +#include + +static std::vector heal_last_token(const llama_context * ctx, const std::vector & tokens_list) { + const llama_token last_token_id = tokens_list.back(); + const llama_model * model = llama_get_model(ctx); + const int32_t n_vocab = llama_n_vocab(model); + + // Don't roll back e.g. <|endoftext|> (set parse_special=true in llama_tokenize) + if (llama_token_get_type(model, last_token_id) != LLAMA_TOKEN_TYPE_NORMAL) { + return {}; + } + + const std::string last_piece = llama_token_to_piece(ctx, last_token_id); + fprintf(stderr, "token_healing: prefix = '%s'\n", last_piece.c_str()); + + fprintf(stderr, "token_healing: candidates:\n"); + fprintf(stderr, " [%6d] '%s'\n", last_token_id, last_piece.c_str()); + std::vector candidates = { last_token_id }; + for (llama_token token_id = 0; token_id < n_vocab; ++token_id) { + if (token_id == last_token_id) { + continue; + } + std::string token_piece = llama_token_to_piece(ctx, token_id); + if (token_piece.rfind(last_piece, 0) != std::string::npos) { + candidates.push_back(token_id); + fprintf(stderr, " [%6d] '%s'\n", token_id, token_piece.c_str()); + } + } + if (candidates.size() == 1) { + // No healing necessary if the last token is the only candidate. + return {}; + } + return candidates; +} + +int main(int argc, char ** argv) { + gpt_params params; + + if (argc == 1 || argv[1][0] == '-') { + printf("usage: %s MODEL_PATH [PROMPT]\n" , argv[0]); + return 1 ; + } + + if (argc >= 2) { + params.model = argv[1]; + } + + if (argc >= 3) { + params.prompt = argv[2]; + } + + if (params.prompt.empty()) { + params.prompt = "Hello my name is"; + } + + // total length of the sequence including the prompt + const int n_len = 32; + + // init LLM + + llama_backend_init(); + llama_numa_init(params.numa); + + // initialize the model + + llama_model_params model_params = llama_model_default_params(); + + // model_params.n_gpu_layers = 99; // offload all layers to the GPU + + llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); + + if (model == NULL) { + fprintf(stderr , "%s: error: unable to load model\n" , __func__); + return 1; + } + + // initialize the context + + llama_context_params ctx_params = llama_context_default_params(); + + ctx_params.seed = 1234; + ctx_params.n_ctx = 2048; + ctx_params.n_threads = params.n_threads; + ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + + llama_context * ctx = llama_new_context_with_model(model, ctx_params); + + if (ctx == NULL) { + fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); + return 1; + } + + // tokenize the prompt + + std::vector tokens_list; + tokens_list = ::llama_tokenize(ctx, params.prompt, true); + + // Roll back the last token and constrain tokens to generate in the next step to match the removed last token. 
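+    // For example (see the README): with the prompt "print('Hel", the trailing token 'Hel' is
+    // removed and the first sampled token is restricted to pieces that start with "Hel"
+    // ('Hel', 'Hello', 'Help', 'Hell', 'Helper'), letting the completion cross the prompt boundary.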
+ std::vector token_healing_candidates = heal_last_token(ctx, tokens_list); + if (!token_healing_candidates.empty()) { + tokens_list.pop_back(); + } + if (tokens_list.empty()) { + // If we remove the first token, llama_decode would crash with an empty sequence, so add bos. + tokens_list.emplace_back(llama_token_bos(model)); + } + + const int n_ctx = llama_n_ctx(ctx); + const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size()); + + LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req); + + // make sure the KV cache is big enough to hold all the prompt and generated tokens + if (n_kv_req > n_ctx) { + LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__); + LOG_TEE("%s: either reduce n_len or increase n_ctx\n", __func__); + return 1; + } + + // print the prompt token-by-token + + fprintf(stderr, "\n"); + + for (auto id : tokens_list) { + fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str()); + } + + fflush(stderr); + + // create a llama_batch with size 512 + // we use this object to submit token data for decoding + + llama_batch batch = llama_batch_init(512, 0, 1); + + // evaluate the initial prompt + for (size_t i = 0; i < tokens_list.size(); i++) { + llama_batch_add(batch, tokens_list[i], i, { 0 }, false); + } + + // llama_decode will output logits only for the last token of the prompt + batch.logits[batch.n_tokens - 1] = true; + + if (llama_decode(ctx, batch) != 0) { + LOG_TEE("%s: llama_decode() failed\n", __func__); + return 1; + } + + // main loop + + int n_cur = batch.n_tokens; + int n_decode = 0; + + const auto t_main_start = ggml_time_us(); + + while (n_cur <= n_len) { + // sample the next token + { + auto n_vocab = llama_n_vocab(model); + auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); + + std::vector candidates; + candidates.reserve(n_vocab); + + if (n_decode == 0 && !token_healing_candidates.empty()) { + for (const llama_token token_id : token_healing_candidates) { + candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); + } + } else { + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); + } + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + + // sample the most likely token + const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + + // is it an end of generation? 
+ if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { + LOG_TEE("\n"); + + break; + } + + LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str()); + fflush(stdout); + + // prepare the next batch + llama_batch_clear(batch); + + // push this new token for next evaluation + llama_batch_add(batch, new_token_id, n_cur, { 0 }, true); + + n_decode += 1; + } + + n_cur += 1; + + // evaluate the current batch with the transformer model + if (llama_decode(ctx, batch)) { + fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); + return 1; + } + } + + LOG_TEE("\n"); + + const auto t_main_end = ggml_time_us(); + + LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", + __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); + + llama_print_timings(ctx); + + fprintf(stderr, "\n"); + + llama_batch_free(batch); + + llama_free(ctx); + llama_free_model(model); + + llama_backend_free(); + + return 0; +} diff --git a/examples/simple-token-healing/simple-token-healing.cpp b/examples/simple-token-healing/simple-token-healing.cpp new file mode 100644 index 0000000000000..48f736a0e2f03 --- /dev/null +++ b/examples/simple-token-healing/simple-token-healing.cpp @@ -0,0 +1,353 @@ +#include "common.h" +#include "llama.h" + +#include +#include +#include +#include + +#define TH_VERBOSE // print token healing candidates + +enum class token_healing_type : uint8_t { + LAST, // replace last token only + MULTI_ONCE, // replace multiple last tokens with a single token + MULTI // replace multiple last tokens with multiple decoding steps +}; + +struct token_healing_context { + std::string prefix; // remaining prefix to generate (the input prompt's suffix) + + std::vector vocab; // map token id to token piece + // TODO consider using a prefix tree +}; + +static inline bool startswith(const std::string & str, const std::string & prefix) { + return str.rfind(prefix, 0) != std::string::npos; +} + +static bool token_healing_prefix_exists(const token_healing_context * th_ctx, const std::string & prefix) { + for (const std::string & token : th_ctx->vocab) { + if (startswith(token, prefix)) { + return true; + } + } + return false; +} + +static std::vector token_healing_find_prefix( + const token_healing_context * th_ctx, + const std::string & prefix, + const bool include_partial_prefix) { + // Example: prefix=" world" -> " world", " worldwide", ... + // If `include_partial_prefix`, include also: " w", " wo", ... 
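+    // Partial prefixes only matter for multi-step healing: a shorter token such as " w" can be
+    // emitted first, and the remaining text "orl" keeps constraining the following decoding steps.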
+ std::vector candidates; + const auto & vocab = th_ctx->vocab; + for (size_t token_id = 0; token_id < vocab.size(); ++token_id) { + if (startswith(vocab[token_id], prefix) + || (include_partial_prefix && startswith(prefix, vocab[token_id]))) { + candidates.push_back((llama_token)token_id); + } + } + return candidates; +} + +static token_healing_context * token_healing_init(const llama_context * ctx) { + auto * th_ctx = new token_healing_context; + const llama_model * model = llama_get_model(ctx); + const int32_t n_vocab = llama_n_vocab(model); + std::vector & vocab = th_ctx->vocab; + vocab.reserve(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + vocab.emplace_back(llama_token_to_piece(ctx, token_id, true)); + } + return th_ctx; +} + +static void token_healing_free(token_healing_context * th_ctx) { + delete th_ctx; +} + +static int token_healing_start( + const llama_context * ctx, + std::vector & tokens_list, + const token_healing_type th_type, + token_healing_context * th_ctx) { + if (tokens_list.empty()) { + return 0; + } + const llama_model * model = llama_get_model(ctx); + const int n_ctx = tokens_list.size(); + const int max_to_remove = (th_type == token_healing_type::LAST) ? 1 : n_ctx; + int n_removed = 0; + std::string prefix; + // Backtrack tokens until there does not exist a token that can cover the prompt + while (n_removed < max_to_remove) { + const llama_token next_token = tokens_list[n_ctx - n_removed - 1]; + if (llama_token_get_type(model, next_token) != LLAMA_TOKEN_TYPE_NORMAL) { + // Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize) + break; + } + std::string new_prefix = llama_token_to_piece(ctx, next_token) + prefix; + if (!token_healing_prefix_exists(th_ctx, new_prefix)) { + break; + } + n_removed += 1; + prefix = new_prefix; + } + th_ctx->prefix = prefix; + + if (n_removed == 0) { + return 0; + } + const std::vector candidates = token_healing_find_prefix(th_ctx, prefix, false); + fprintf(stderr, "token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed); + if (n_removed == 1 && candidates.size() == 1) { + fprintf(stderr, "token_healing: nothing to heal\n"); + return 0; + } +#ifdef TH_VERBOSE + if (th_type != token_healing_type::MULTI) { + for (const llama_token token_id : candidates) { + fprintf(stderr, " [%6d] '%s'\n", token_id, th_ctx->vocab[token_id].c_str()); + } + } +#endif + for (int i = 0; i < n_removed; ++i) { + tokens_list.pop_back(); + } + if (tokens_list.empty()) { + // If the first token was removed, llama_decode would crash with an empty sequence, so add bos. 
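+        // (This can happen when the whole prompt is rolled back, e.g. a one-token prompt;
+        // the constrained decoding below still regenerates the removed text.)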
+ tokens_list.emplace_back(llama_token_bos(model)); + } + return n_removed; +} + +int main(int argc, char ** argv) { + gpt_params params; + + if (argc == 1 || argv[1][0] == '-') { + printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|n|m]\n" , argv[0]); + return 1 ; + } + + if (argc >= 2) { + params.model = argv[1]; + } + + if (argc >= 3) { + params.prompt = argv[2]; + } + + bool token_healing_enabled = true; + auto th_type = token_healing_type::LAST; + if (argc >= 4) { + std::string value(argv[3]); + /**/ if (value == "0") { token_healing_enabled = false; } + else if (value == "1") { th_type = token_healing_type::LAST; } + else if (value == "n") { th_type = token_healing_type::MULTI_ONCE; } + else if (value == "m") { th_type = token_healing_type::MULTI; } + else { + printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|n|m]\n" , argv[0]); + return 1; + } + } + + if (params.prompt.empty()) { + params.prompt = "Hello my name is"; + } + + // total length of the sequence including the prompt + const int n_len = 32; + + // init LLM + + llama_backend_init(); + llama_numa_init(params.numa); + + // initialize the model + + llama_model_params model_params = llama_model_default_params(); + + // model_params.n_gpu_layers = 99; // offload all layers to the GPU + + llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); + + if (model == NULL) { + fprintf(stderr , "%s: error: unable to load model\n" , __func__); + return 1; + } + + // initialize the context + + llama_context_params ctx_params = llama_context_default_params(); + + ctx_params.seed = 1234; + ctx_params.n_ctx = 2048; + ctx_params.n_threads = params.n_threads; + ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + + llama_context * ctx = llama_new_context_with_model(model, ctx_params); + + if (ctx == NULL) { + fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); + return 1; + } + + // tokenize the prompt + + std::vector tokens_list; + tokens_list = ::llama_tokenize(ctx, params.prompt, true); + + token_healing_context * th_ctx = nullptr; + if (token_healing_enabled) { + th_ctx = token_healing_init(ctx); + int th_n_tokens_removed = token_healing_start(ctx, tokens_list, th_type, th_ctx); + if (th_n_tokens_removed == 0) { + token_healing_enabled = false; + } + } + + const int n_ctx = llama_n_ctx(ctx); + const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size()); + + LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req); + + // make sure the KV cache is big enough to hold all the prompt and generated tokens + if (n_kv_req > n_ctx) { + LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__); + LOG_TEE("%s: either reduce n_len or increase n_ctx\n", __func__); + return 1; + } + + // print the prompt token-by-token + + fprintf(stderr, "\n"); + + for (auto id : tokens_list) { + fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str()); + } + + fflush(stderr); + + // create a llama_batch with size 512 + // we use this object to submit token data for decoding + + llama_batch batch = llama_batch_init(512, 0, 1); + + // evaluate the initial prompt + for (size_t i = 0; i < tokens_list.size(); i++) { + llama_batch_add(batch, tokens_list[i], i, { 0 }, false); + } + + // llama_decode will output logits only for the last token of the prompt + batch.logits[batch.n_tokens - 1] = true; + + if (llama_decode(ctx, batch) != 0) { + LOG_TEE("%s: llama_decode() 
failed\n", __func__); + return 1; + } + + // main loop + + int n_cur = batch.n_tokens; + int n_decode = 0; + + const auto t_main_start = ggml_time_us(); + + while (n_cur <= n_len) { + // sample the next token + { + auto n_vocab = llama_n_vocab(model); + auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); + + std::vector candidates; + candidates.reserve(n_vocab); + + if (token_healing_enabled) { + // Constrain tokens based on the remaining token healing prefix + // N.B. We could also set token constraints by setting rejected tokens' logits to -inf + std::vector th_candidates; + if (th_type == token_healing_type::LAST || th_type == token_healing_type::MULTI_ONCE) { + th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, false); + } else { + th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, true); +#ifdef TH_VERBOSE + fprintf(stderr, "\ntoken_healing: prefix = '%s'\n", th_ctx->prefix.c_str()); + for (const llama_token token_id : th_candidates) { + fprintf(stderr, " [%6d] '%s'\n", token_id, th_ctx->vocab[token_id].c_str()); + } +#endif + } + for (const llama_token token_id: th_candidates) { + candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); + } + } else { + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); + } + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + + // sample the most likely token + const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + + // is it an end of generation? + if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { + LOG_TEE("\n"); + break; + } + + std::string new_token_piece = llama_token_to_piece(ctx, new_token_id); + LOG_TEE("%s", new_token_piece.c_str()); + fflush(stdout); + + if (token_healing_enabled) { + if (new_token_piece.size() < th_ctx->prefix.size()) { + // Shift prefix constraint (for multi step token healing) + th_ctx->prefix = th_ctx->prefix.substr(new_token_piece.size()); + } else { + th_ctx->prefix.clear(); + token_healing_enabled = false; + } + } + + // prepare the next batch + llama_batch_clear(batch); + + // push this new token for next evaluation + llama_batch_add(batch, new_token_id, n_cur, { 0 }, true); + + n_decode += 1; + } + + n_cur += 1; + + // evaluate the current batch with the transformer model + if (llama_decode(ctx, batch)) { + fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); + return 1; + } + } + + LOG_TEE("\n"); + + const auto t_main_end = ggml_time_us(); + + LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", + __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); + + llama_print_timings(ctx); + + fprintf(stderr, "\n"); + + token_healing_free(th_ctx); + llama_batch_free(batch); + + llama_free(ctx); + llama_free_model(model); + + llama_backend_free(); + + return 0; +} From 88ef908c9020b42ad829b8fa3af60ed6c9a1575d Mon Sep 17 00:00:00 2001 From: mare5x Date: Tue, 30 Apr 2024 20:04:35 +0200 Subject: [PATCH 2/6] examples : more roll back options for token healing --- examples/simple-token-healing/README.md | 51 +++++++++++--- .../simple-token-healing.cpp | 66 +++++++++++-------- 2 files changed, 83 insertions(+), 34 deletions(-) diff --git a/examples/simple-token-healing/README.md b/examples/simple-token-healing/README.md index 7e546986688de..533c118bd8c48 100644 --- 
a/examples/simple-token-healing/README.md +++ b/examples/simple-token-healing/README.md @@ -1,10 +1,13 @@ # llama.cpp/example/simple-token-healing -This example extends [simple](../simple/README.md) with [token healing](https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb). +This example extends [simple](../simple/README.md) with token healing (aka. token alignment). -Without token healing: +`usage: ./simple-token-healing MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|d1|d|r[N]]` + +## Examples +`0`: Without token healing (same as running `./simple ...`): ```bash -./simple ./models/phi-2/ggml-model-q4_0.gguf "print('Hel" +./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hel" 0 ... main: n_len = 32, n_ctx = 2048, n_kv_req = 32 @@ -12,7 +15,7 @@ print('Helping the customer') ... ``` -Heal the last token (`1`): +`1`: Roll back the last token and constrain the bytes of the next token to start with the chopped off last token [0, 2]: ```bash ./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hel" 1 ... @@ -29,9 +32,9 @@ print('Hello, World!') ... ``` -Backtrack multiple tokens until there doesn't exist a token which can cover the prompt's suffix (`n`): +`d1`: Roll back multiple tokens until there doesn't exist a token which can cover the prompt's suffix and do a single constrained decoding step [2]: ```bash -./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" n +./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" d1 ... token_healing: prefix = ' worl' (2 tokens) [ 995] ' world' @@ -46,9 +49,9 @@ print('Hello, world!') ... ``` -Backtrack multiple tokens but don't constrain the decoding to a single token (`m`): +`d`: Roll back multiple tokens until there doesn't exist a token which can cover the prompt's suffix but allow multiple decoding steps: ```bash -./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" m +./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" d ... token_healing: prefix = ' worl' (2 tokens) @@ -68,3 +71,35 @@ token_healing: prefix = ' worl' world!') ... ``` + +`r[N]`: Roll back `N` tokens and constrain the decoding to the bytes of those tokens (multiple decoding steps) [1]. +The paper [1] recommends `N=3`: +```bash +./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" r3 +... +token_healing: prefix = ', worl' (3 tokens) + +main: n_len = 32, n_ctx = 2048, n_kv_req = 32 + +print('Hello +token_healing: prefix = ', worl' + [ 11] ',' +, +token_healing: prefix = ' worl' + [ 220] ' ' + [ 266] ' w' + [ 476] ' wor' + [ 995] ' world' + [ 8688] ' worldwide' + [ 11621] ' worlds' + [ 24486] ' wo' + [ 29081] ' worldview' + [ 43249] ' worldly' + world!') +... 
+``` + +## Sources +- [0] https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb +- [1] https://arxiv.org/abs/2403.08688 +- [2] https://arxiv.org/abs/2402.01035 diff --git a/examples/simple-token-healing/simple-token-healing.cpp b/examples/simple-token-healing/simple-token-healing.cpp index 48f736a0e2f03..79b1693ad91d6 100644 --- a/examples/simple-token-healing/simple-token-healing.cpp +++ b/examples/simple-token-healing/simple-token-healing.cpp @@ -9,19 +9,20 @@ #define TH_VERBOSE // print token healing candidates enum class token_healing_type : uint8_t { - LAST, // replace last token only - MULTI_ONCE, // replace multiple last tokens with a single token - MULTI // replace multiple last tokens with multiple decoding steps + ROLLBACK_LAST, // roll back last token with a single constrained decoding step + ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps + DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step + DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps }; struct token_healing_context { std::string prefix; // remaining prefix to generate (the input prompt's suffix) - std::vector vocab; // map token id to token piece + std::vector vocab; // map token id to token piece // TODO consider using a prefix tree }; -static inline bool startswith(const std::string & str, const std::string & prefix) { +static bool startswith(const std::string & str, const std::string & prefix) { return str.rfind(prefix, 0) != std::string::npos; } @@ -67,28 +68,31 @@ static void token_healing_free(token_healing_context * th_ctx) { delete th_ctx; } -static int token_healing_start( +static int token_healing_heal( const llama_context * ctx, std::vector & tokens_list, const token_healing_type th_type, - token_healing_context * th_ctx) { + token_healing_context * th_ctx, + int n_rollback = 1) { if (tokens_list.empty()) { return 0; } const llama_model * model = llama_get_model(ctx); + const bool is_dynamic = th_type == token_healing_type::DYNAMIC_ONCE || th_type == token_healing_type::DYNAMIC_MULTI; const int n_ctx = tokens_list.size(); - const int max_to_remove = (th_type == token_healing_type::LAST) ? 1 : n_ctx; + const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx); int n_removed = 0; std::string prefix; - // Backtrack tokens until there does not exist a token that can cover the prompt + // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt + // and stop early if a special token is encountered while (n_removed < max_to_remove) { - const llama_token next_token = tokens_list[n_ctx - n_removed - 1]; - if (llama_token_get_type(model, next_token) != LLAMA_TOKEN_TYPE_NORMAL) { + const llama_token next_token_id = tokens_list[n_ctx - n_removed - 1]; + if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) { // Don't roll back e.g. 
<|endoftext|> (if parse_special=true in llama_tokenize) break; } - std::string new_prefix = llama_token_to_piece(ctx, next_token) + prefix; - if (!token_healing_prefix_exists(th_ctx, new_prefix)) { + std::string new_prefix = th_ctx->vocab[next_token_id] + prefix; + if (is_dynamic && !token_healing_prefix_exists(th_ctx, new_prefix)) { break; } n_removed += 1; @@ -99,14 +103,17 @@ static int token_healing_start( if (n_removed == 0) { return 0; } - const std::vector candidates = token_healing_find_prefix(th_ctx, prefix, false); + // If constrained decoding would give back the original prompt, there is no need to modify the context + const bool is_multi_decoding = th_type == token_healing_type::DYNAMIC_MULTI || th_type == token_healing_type::ROLLBACK_MULTI; + const std::vector candidates = token_healing_find_prefix(th_ctx, prefix, is_multi_decoding); fprintf(stderr, "token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed); if (n_removed == 1 && candidates.size() == 1) { fprintf(stderr, "token_healing: nothing to heal\n"); return 0; } #ifdef TH_VERBOSE - if (th_type != token_healing_type::MULTI) { + if (!is_multi_decoding) { + // Other healing types get printed during decoding for (const llama_token token_id : candidates) { fprintf(stderr, " [%6d] '%s'\n", token_id, th_ctx->vocab[token_id].c_str()); } @@ -126,8 +133,8 @@ int main(int argc, char ** argv) { gpt_params params; if (argc == 1 || argv[1][0] == '-') { - printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|n|m]\n" , argv[0]); - return 1 ; + printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|d1|d|r[N]]\n" , argv[0]); + return 1; } if (argc >= 2) { @@ -139,15 +146,22 @@ int main(int argc, char ** argv) { } bool token_healing_enabled = true; - auto th_type = token_healing_type::LAST; + auto th_type = token_healing_type::DYNAMIC_MULTI; + int th_n_rollback = 1; if (argc >= 4) { std::string value(argv[3]); - /**/ if (value == "0") { token_healing_enabled = false; } - else if (value == "1") { th_type = token_healing_type::LAST; } - else if (value == "n") { th_type = token_healing_type::MULTI_ONCE; } - else if (value == "m") { th_type = token_healing_type::MULTI; } - else { - printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|n|m]\n" , argv[0]); + /**/ if (value == "0" ) { token_healing_enabled = false; } + else if (value == "1" ) { th_type = token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; } + else if (value == "d1") { th_type = token_healing_type::DYNAMIC_ONCE; } + else if (value == "d" ) { th_type = token_healing_type::DYNAMIC_MULTI; } + else if (value[0] == 'r' ) { + th_type = token_healing_type::ROLLBACK_MULTI; + th_n_rollback = std::stoi(value.substr(1)); + if (th_n_rollback <= 0) { + token_healing_enabled = false; + } + } else { + printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|d1|d|r[N]]\n" , argv[0]); return 1; } } @@ -201,7 +215,7 @@ int main(int argc, char ** argv) { token_healing_context * th_ctx = nullptr; if (token_healing_enabled) { th_ctx = token_healing_init(ctx); - int th_n_tokens_removed = token_healing_start(ctx, tokens_list, th_type, th_ctx); + int th_n_tokens_removed = token_healing_heal(ctx, tokens_list, th_type, th_ctx, th_n_rollback); if (th_n_tokens_removed == 0) { token_healing_enabled = false; } @@ -267,7 +281,7 @@ int main(int argc, char ** argv) { // Constrain tokens based on the remaining token healing prefix // N.B. 
We could also set token constraints by setting rejected tokens' logits to -inf std::vector th_candidates; - if (th_type == token_healing_type::LAST || th_type == token_healing_type::MULTI_ONCE) { + if (th_type == token_healing_type::ROLLBACK_LAST || th_type == token_healing_type::DYNAMIC_ONCE) { th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, false); } else { th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, true); From 951b6593b2ec2ebf27cf4dd03c81f8a35326c358 Mon Sep 17 00:00:00 2001 From: mare5x Date: Fri, 3 May 2024 13:50:31 +0200 Subject: [PATCH 3/6] main : first attempt at token healing in `main` --- common/common.cpp | 25 ++++ common/sampling.cpp | 136 +++++++++++++++++- common/sampling.h | 23 +++ examples/main/main.cpp | 7 + .../simple-token-healing.cpp | 33 ++--- 5 files changed, 200 insertions(+), 24 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 243b88abf1aab..7f1d136053511 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1288,6 +1288,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa sparams.grammar = json_schema_to_grammar(json::parse(argv[i])); return true; } + if (arg == "-th" || arg == "--token-healing") { + if (++i >= argc) { + invalid_param = true; + return true; + } + sparams.token_healing_enabled = true; + auto & th_type = sparams.token_healing_type; + auto & th_n_rollback = sparams.token_healing_n_rollback; + std::string value(argv[i]); + /**/ if (value == "0" ) { sparams.token_healing_enabled = false; } + else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; } + else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; } + else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; } + else if (value[0] == 'r' ) { + th_type = llama_token_healing_type::ROLLBACK_MULTI; + th_n_rollback = std::stoi(value.substr(1)); + if (th_n_rollback <= 0) { + sparams.token_healing_enabled = false; + } + } else { invalid_param = true; } + return true; + } if (arg == "--override-kv") { if (++i >= argc) { invalid_param = true; @@ -1480,6 +1502,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -j SCHEMA, --json-schema SCHEMA\n"); printf(" JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n"); printf(" For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n"); + printf(" -th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}\n"); + printf(" Token healing type. (default: 0, disabled)\n"); + printf(" 1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens\n"); printf(" --cfg-negative-prompt PROMPT\n"); printf(" negative prompt to use for guidance. 
(default: empty)\n"); printf(" --cfg-negative-prompt-file FNAME\n"); diff --git a/common/sampling.cpp b/common/sampling.cpp index cc83600d9926e..5549369e8b74b 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -2,6 +2,96 @@ #include "sampling.h" #include +// +// Token healing (internal) +// + +static bool startswith(const std::string & str, const std::string & prefix) { + return str.rfind(prefix, 0) != std::string::npos; +} + +static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) { + const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + for (llama_token token_id = 0; token_id < n_vocab; ++token_id) { + if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) { + return true; + } + } + return false; +} + +static std::vector token_healing_find_prefix( + const llama_context * ctx_main, + const std::string & prefix, + const bool include_partial_prefix) { + // Example: prefix=" world" -> " world", " worldwide", ... + // If `include_partial_prefix`, include also: " w", " wo", ... + std::vector candidates; + const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + for (llama_token token_id = 0; token_id < n_vocab; ++token_id) { + std::string token = llama_token_to_piece(ctx_main, token_id); + if (startswith(token, prefix) || + (include_partial_prefix && startswith(prefix, token))) { + candidates.push_back(token_id); + } + } + return candidates; +} + +// +// Token healing (external) +// + +std::string llama_token_healing_prepare( + const llama_context * ctx_main, + llama_token_healing_type th_type, + std::vector & tokens, + int n_rollback) { + if (tokens.empty()) { + return ""; + } + const llama_model * model = llama_get_model(ctx_main); + const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI; + const int n_ctx = tokens.size(); + const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx); + int n_removed = 0; + std::string prefix; + // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt + // and stop early if a special token is encountered + while (n_removed < max_to_remove) { + const llama_token next_token_id = tokens[n_ctx - n_removed - 1]; + if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) { + // Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize) + break; + } + std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix; + if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) { + break; + } + n_removed += 1; + prefix = new_prefix; + } + + if (n_removed == 0) { // E.g. 
if the last token is a special token + return ""; + } + // If constrained decoding would give back the original prompt, there is no need to modify the context + const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI || + th_type == llama_token_healing_type::DYNAMIC_MULTI; + const std::vector candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step); + LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed); + if (n_removed == 1 && candidates.size() == 1) { + LOG("token_healing: nothing to heal\n"); + return ""; + } + tokens.resize(n_ctx - n_removed); + return prefix; +} + +// +// Sampling +// + struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { struct llama_sampling_context * result = new llama_sampling_context(); @@ -33,6 +123,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root")); } + result->token_healing_prefix.clear(); + result->prev.resize(params.n_prev); llama_sampling_set_rng_seed(result, params.seed); @@ -62,6 +154,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) { grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root")); } + ctx->token_healing_prefix.clear(); + std::fill(ctx->prev.begin(), ctx->prev.end(), 0); ctx->cur.clear(); } @@ -119,7 +213,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) { } std::string llama_sampling_order_print(const llama_sampling_params & params) { - std::string result = "CFG -> Penalties "; + std::string result = "(Token healing) -> CFG -> Penalties "; if (params.mirostat == 0) { for (auto sampler_type : params.samplers_sequence) { const auto sampler_type_name = sampler_type_to_name_string(sampler_type); @@ -297,12 +391,33 @@ static llama_token_data_array llama_sampling_prepare_impl( cur.clear(); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + // Constrain tokens based on the remaining token healing prefix (if any) + const auto & th_type = params.token_healing_type; + const auto & th_prefix = ctx_sampling->token_healing_prefix; + if (params.token_healing_enabled && !th_prefix.empty()) { + const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI || + th_type == llama_token_healing_type::DYNAMIC_MULTI; + std::vector th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step); + + LOG("token_healing: prefix = '%s'\n", th_prefix.c_str()); + for (const llama_token token_id : th_candidates) { + LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str()); + } + + // N.B. We could also set token constraints by setting rejected tokens' logits to -inf + for (const llama_token token_id: th_candidates) { + cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); + } + } else { + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); + } } llama_token_data_array cur_p = { cur.data(), cur.size(), false }; + // TODO should we skip penalties and grammar while token healing? + // apply penalties const auto& penalty_tokens = params.use_penalty_prompt_tokens ? 
params.penalty_prompt_tokens : prev; const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n); @@ -361,4 +476,19 @@ void llama_sampling_accept( if (ctx_sampling->grammar != NULL && apply_grammar) { llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id); } + + if (ctx_sampling->params.token_healing_enabled && apply_grammar) { + std::string & th_prefix = ctx_sampling->token_healing_prefix; + if (!th_prefix.empty()) { + const std::string new_token_piece = llama_token_to_piece(ctx_main, id); + if (new_token_piece.size() < th_prefix.size()) { + // Shift prefix constraint (for multi step token healing) + th_prefix = th_prefix.substr(new_token_piece.size()); + } else { + // Prefix has been generated => no more constrained generation + th_prefix.clear(); + LOG("token_healing: done\n"); + } + } + } } diff --git a/common/sampling.h b/common/sampling.h index cf7081e3674f1..e2b870f00531b 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -19,6 +19,13 @@ enum class llama_sampler_type : char { TEMPERATURE = 't' }; +enum class llama_token_healing_type : uint8_t { + ROLLBACK_LAST, // roll back last token with a single constrained decoding step + ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps + DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step + DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps +}; + // sampling parameters typedef struct llama_sampling_params { int32_t n_prev = 64; // number of previous tokens to remember @@ -62,6 +69,10 @@ typedef struct llama_sampling_params { std::vector penalty_prompt_tokens; bool use_penalty_prompt_tokens = false; + + llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST; + bool token_healing_enabled = false; + int token_healing_n_rollback = 1; // number of tokens to roll back } llama_sampling_params; // general sampler context @@ -78,6 +89,8 @@ struct llama_sampling_context { // internal grammar_parser::parse_state parsed_grammar; + std::string token_healing_prefix; + // TODO: replace with ring-buffer std::vector prev; std::vector cur; @@ -152,3 +165,13 @@ void llama_sampling_accept( struct llama_context * ctx_main, llama_token id, bool apply_grammar); + +// +// Token healing +// + +std::string llama_token_healing_prepare( + const llama_context * ctx_main, + llama_token_healing_type th_type, + std::vector & tokens, + int n_rollback = 1); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 5c693657c8993..c9e6d2de9dfc1 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -264,6 +264,12 @@ int main(int argc, char ** argv) { LOG("prompt: \"%s\"\n", log_tostr(params.prompt)); LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + std::string token_healing_prefix; + if (sparams.token_healing_enabled) { + token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp, + sparams.token_healing_n_rollback); + } + // Should not run without any tokens if (embd_inp.empty()) { embd_inp.push_back(llama_token_bos(model)); @@ -520,6 +526,7 @@ int main(int argc, char ** argv) { } struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); + ctx_sampling->token_healing_prefix = token_healing_prefix; while ((n_remain != 0 && !is_antiprompt) || params.interactive) { // predict diff --git a/examples/simple-token-healing/simple-token-healing.cpp b/examples/simple-token-healing/simple-token-healing.cpp index 79b1693ad91d6..05091b9c33c62 
100644 --- a/examples/simple-token-healing/simple-token-healing.cpp +++ b/examples/simple-token-healing/simple-token-healing.cpp @@ -8,13 +8,6 @@ #define TH_VERBOSE // print token healing candidates -enum class token_healing_type : uint8_t { - ROLLBACK_LAST, // roll back last token with a single constrained decoding step - ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps - DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step - DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps -}; - struct token_healing_context { std::string prefix; // remaining prefix to generate (the input prompt's suffix) @@ -44,8 +37,8 @@ static std::vector token_healing_find_prefix( std::vector candidates; const auto & vocab = th_ctx->vocab; for (size_t token_id = 0; token_id < vocab.size(); ++token_id) { - if (startswith(vocab[token_id], prefix) - || (include_partial_prefix && startswith(prefix, vocab[token_id]))) { + if (startswith(vocab[token_id], prefix) || + (include_partial_prefix && startswith(prefix, vocab[token_id]))) { candidates.push_back((llama_token)token_id); } } @@ -71,14 +64,14 @@ static void token_healing_free(token_healing_context * th_ctx) { static int token_healing_heal( const llama_context * ctx, std::vector & tokens_list, - const token_healing_type th_type, + const llama_token_healing_type th_type, token_healing_context * th_ctx, int n_rollback = 1) { if (tokens_list.empty()) { return 0; } const llama_model * model = llama_get_model(ctx); - const bool is_dynamic = th_type == token_healing_type::DYNAMIC_ONCE || th_type == token_healing_type::DYNAMIC_MULTI; + const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI; const int n_ctx = tokens_list.size(); const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx); int n_removed = 0; @@ -104,7 +97,7 @@ static int token_healing_heal( return 0; } // If constrained decoding would give back the original prompt, there is no need to modify the context - const bool is_multi_decoding = th_type == token_healing_type::DYNAMIC_MULTI || th_type == token_healing_type::ROLLBACK_MULTI; + const bool is_multi_decoding = th_type == llama_token_healing_type::DYNAMIC_MULTI || th_type == llama_token_healing_type::ROLLBACK_MULTI; const std::vector candidates = token_healing_find_prefix(th_ctx, prefix, is_multi_decoding); fprintf(stderr, "token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed); if (n_removed == 1 && candidates.size() == 1) { @@ -119,9 +112,7 @@ static int token_healing_heal( } } #endif - for (int i = 0; i < n_removed; ++i) { - tokens_list.pop_back(); - } + tokens_list.resize(n_ctx - n_removed); if (tokens_list.empty()) { // If the first token was removed, llama_decode would crash with an empty sequence, so add bos. 
tokens_list.emplace_back(llama_token_bos(model)); @@ -146,16 +137,16 @@ int main(int argc, char ** argv) { } bool token_healing_enabled = true; - auto th_type = token_healing_type::DYNAMIC_MULTI; + auto th_type = llama_token_healing_type::DYNAMIC_MULTI; int th_n_rollback = 1; if (argc >= 4) { std::string value(argv[3]); /**/ if (value == "0" ) { token_healing_enabled = false; } - else if (value == "1" ) { th_type = token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; } - else if (value == "d1") { th_type = token_healing_type::DYNAMIC_ONCE; } - else if (value == "d" ) { th_type = token_healing_type::DYNAMIC_MULTI; } + else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; } + else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; } + else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; } else if (value[0] == 'r' ) { - th_type = token_healing_type::ROLLBACK_MULTI; + th_type = llama_token_healing_type::ROLLBACK_MULTI; th_n_rollback = std::stoi(value.substr(1)); if (th_n_rollback <= 0) { token_healing_enabled = false; @@ -281,7 +272,7 @@ int main(int argc, char ** argv) { // Constrain tokens based on the remaining token healing prefix // N.B. We could also set token constraints by setting rejected tokens' logits to -inf std::vector th_candidates; - if (th_type == token_healing_type::ROLLBACK_LAST || th_type == token_healing_type::DYNAMIC_ONCE) { + if (th_type == llama_token_healing_type::ROLLBACK_LAST || th_type == llama_token_healing_type::DYNAMIC_ONCE) { th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, false); } else { th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, true); From 7d0cc78bc32725fc4ffb29bf854754575990039a Mon Sep 17 00:00:00 2001 From: mare5x Date: Fri, 3 May 2024 19:50:00 +0200 Subject: [PATCH 4/6] main : better token healing support for interactive mode --- common/common.cpp | 2 +- common/sampling.cpp | 31 ++++++++++++++++++++----------- common/sampling.h | 5 +++-- examples/main/main.cpp | 23 +++++++++++++++++++++-- 4 files changed, 45 insertions(+), 16 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 7f1d136053511..b75cfdf952365 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1298,7 +1298,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa auto & th_n_rollback = sparams.token_healing_n_rollback; std::string value(argv[i]); /**/ if (value == "0" ) { sparams.token_healing_enabled = false; } - else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; } + else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; } else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; } else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; } else if (value[0] == 'r' ) { diff --git a/common/sampling.cpp b/common/sampling.cpp index 5549369e8b74b..03c2664bb35ec 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -46,20 +46,26 @@ std::string llama_token_healing_prepare( const llama_context * ctx_main, llama_token_healing_type th_type, std::vector & tokens, - int n_rollback) { + int max_to_remove, + int * n_removed) { + if (n_removed != nullptr) { + *n_removed = 0; + } if (tokens.empty()) { return ""; } + const llama_model * model = llama_get_model(ctx_main); const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI; const int n_ctx = 
tokens.size(); - const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx); - int n_removed = 0; + max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove; + max_to_remove = max_to_remove < 0 ? n_ctx : std::min(max_to_remove, n_ctx); + int removed = 0; std::string prefix; // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt // and stop early if a special token is encountered - while (n_removed < max_to_remove) { - const llama_token next_token_id = tokens[n_ctx - n_removed - 1]; + while (removed < max_to_remove) { + const llama_token next_token_id = tokens[n_ctx - removed - 1]; if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) { // Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize) break; @@ -68,23 +74,26 @@ std::string llama_token_healing_prepare( if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) { break; } - n_removed += 1; + removed += 1; prefix = new_prefix; } - - if (n_removed == 0) { // E.g. if the last token is a special token + if (removed == 0) { // E.g. if the last token is a special token return ""; } // If constrained decoding would give back the original prompt, there is no need to modify the context const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI || th_type == llama_token_healing_type::DYNAMIC_MULTI; const std::vector candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step); - LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed); - if (n_removed == 1 && candidates.size() == 1) { + LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed); + if (removed == 1 && candidates.size() == 1) { LOG("token_healing: nothing to heal\n"); return ""; } - tokens.resize(n_ctx - n_removed); + // Finalize outputs + if (n_removed != nullptr) { + *n_removed = removed; + } + tokens.resize(n_ctx - removed); return prefix; } diff --git a/common/sampling.h b/common/sampling.h index e2b870f00531b..90198bec98f9f 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -72,7 +72,7 @@ typedef struct llama_sampling_params { llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST; bool token_healing_enabled = false; - int token_healing_n_rollback = 1; // number of tokens to roll back + int token_healing_n_rollback = -1; // number of tokens to roll back } llama_sampling_params; // general sampler context @@ -174,4 +174,5 @@ std::string llama_token_healing_prepare( const llama_context * ctx_main, llama_token_healing_type th_type, std::vector & tokens, - int n_rollback = 1); + int max_to_remove = -1, + int * n_removed = nullptr); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index c9e6d2de9dfc1..aedc40334cb0a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -264,8 +264,12 @@ int main(int argc, char ** argv) { LOG("prompt: \"%s\"\n", log_tostr(params.prompt)); LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + if (sparams.token_healing_enabled && (params.instruct || params.chatml || !params.input_suffix.empty())) { + sparams.token_healing_enabled = false; + LOG("token_healing: disabled due to custom suffix"); + } std::string token_healing_prefix; - if (sparams.token_healing_enabled) { + if (!params.interactive_first && sparams.token_healing_enabled) { token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp, sparams.token_healing_n_rollback); } 
@@ -820,6 +824,7 @@ int main(int argc, char ** argv) { } } + int token_healing_n_removed = 0; if (n_past > 0 && is_interacting) { LOG("waiting for user input\n"); @@ -903,13 +908,23 @@ int main(int argc, char ** argv) { embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end()); } + if (sparams.token_healing_enabled) { + // Limit token healing rollback to new tokens only (otherwise would need to shift everything) + const int n_new_tokens = embd_inp.size() - original_size; + const int max_to_remove = sparams.token_healing_n_rollback < 0 + ? n_new_tokens + : std::min(sparams.token_healing_n_rollback, n_new_tokens); + token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp, + max_to_remove, &token_healing_n_removed); + } + for (size_t i = original_size; i < embd_inp.size(); ++i) { const llama_token token = embd_inp[i]; output_tokens.push_back(token); output_ss << llama_token_to_piece(ctx, token); } - n_remain -= line_inp.size(); + n_remain -= line_inp.size() + token_healing_n_removed; LOG("n_remain: %d\n", n_remain); } else { LOG("empty line, passing control back\n"); @@ -921,6 +936,10 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { llama_sampling_reset(ctx_sampling); + if (token_healing_n_removed > 0) { + // Set new prefix after an interaction + ctx_sampling->token_healing_prefix = token_healing_prefix; + } } is_interacting = false; } From d4cbccb1034e9fe42b07aaa608c254268ae66e94 Mon Sep 17 00:00:00 2001 From: mare5x Date: Fri, 3 May 2024 21:56:11 +0200 Subject: [PATCH 5/6] main : skip printing token healing prefix twice --- examples/main/main.cpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index aedc40334cb0a..fd26fc380b26d 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -509,6 +509,7 @@ int main(int argc, char ** argv) { int n_consumed = 0; int n_session_consumed = 0; int n_past_guidance = 0; + int n_bytes_to_skip = 0; // to skip printing when generating token healing prefix std::vector input_tokens; g_input_tokens = &input_tokens; std::vector output_tokens; g_output_tokens = &output_tokens; @@ -745,7 +746,16 @@ int main(int argc, char ** argv) { if (input_echo && display) { for (auto id : embd) { const std::string token_str = llama_token_to_piece(ctx, id); - printf("%s", token_str.c_str()); + + // Suppress printing while generating token healing prefix (only for interactive mode; kinda hacky...) 
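+                // e.g. a healed prefix of " worl" (5 bytes) was already echoed as part of the user's
+                // input, so if the first constrained token is " world", only the trailing "d" is printed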
+ if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int)token_str.size()) { + printf("%s", token_str.substr(n_bytes_to_skip).c_str()); + n_bytes_to_skip = 0; + } else if (n_bytes_to_skip > 0) { + n_bytes_to_skip -= token_str.size(); + } else { + printf("%s", token_str.c_str()); + } if (embd.size() > 1) { input_tokens.push_back(id); @@ -939,6 +949,7 @@ int main(int argc, char ** argv) { if (token_healing_n_removed > 0) { // Set new prefix after an interaction ctx_sampling->token_healing_prefix = token_healing_prefix; + n_bytes_to_skip = ctx_sampling->token_healing_prefix.size(); } } is_interacting = false; From 7b6fdc28191ba6505918e2f0a6a493c8e1a6b963 Mon Sep 17 00:00:00 2001 From: mare5x Date: Mon, 6 May 2024 21:25:12 +0200 Subject: [PATCH 6/6] main : small token healing cleanup --- common/sampling.cpp | 8 ++++---- common/sampling.h | 4 ++++ examples/main/main.cpp | 13 +++++++------ 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 03c2664bb35ec..7e7bf5ea1d144 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -97,6 +97,10 @@ std::string llama_token_healing_prepare( return prefix; } +void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) { + ctx_sampling->token_healing_prefix = prefix; +} + // // Sampling // @@ -132,8 +136,6 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root")); } - result->token_healing_prefix.clear(); - result->prev.resize(params.n_prev); llama_sampling_set_rng_seed(result, params.seed); @@ -425,8 +427,6 @@ static llama_token_data_array llama_sampling_prepare_impl( llama_token_data_array cur_p = { cur.data(), cur.size(), false }; - // TODO should we skip penalties and grammar while token healing? - // apply penalties const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev; const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n); diff --git a/common/sampling.h b/common/sampling.h index 90198bec98f9f..2aa7bc2bdd8b1 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -170,9 +170,13 @@ void llama_sampling_accept( // Token healing // +// Roll back `tokens` for constrained generation according to the token healing +// strategy. Returns the prefix for constrained generation. 
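+// Typical flow: prepare() trims the prompt tokens and returns the prefix to regenerate,
+// the caller stores it with llama_token_healing_set_prefix(), and llama_sampling_accept()
+// consumes it as constrained tokens are generated.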
std::string llama_token_healing_prepare( const llama_context * ctx_main, llama_token_healing_type th_type, std::vector & tokens, int max_to_remove = -1, int * n_removed = nullptr); + +void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index fd26fc380b26d..70834b01a8eca 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -269,9 +269,10 @@ int main(int argc, char ** argv) { LOG("token_healing: disabled due to custom suffix"); } std::string token_healing_prefix; + int token_healing_n_removed = 0; if (!params.interactive_first && sparams.token_healing_enabled) { token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp, - sparams.token_healing_n_rollback); + sparams.token_healing_n_rollback, &token_healing_n_removed); } // Should not run without any tokens @@ -293,7 +294,7 @@ int main(int argc, char ** argv) { std::vector original_inp = ::llama_tokenize(ctx, params.prompt, true, true); LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str()); - original_prompt_len = original_inp.size(); + original_prompt_len = original_inp.size() - token_healing_n_removed; guidance_offset = (int)guidance_inp.size() - original_prompt_len; LOG("original_prompt_len: %s", log_tostr(original_prompt_len)); LOG("guidance_offset: %s", log_tostr(guidance_offset)); @@ -531,7 +532,7 @@ int main(int argc, char ** argv) { } struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); - ctx_sampling->token_healing_prefix = token_healing_prefix; + llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix); while ((n_remain != 0 && !is_antiprompt) || params.interactive) { // predict @@ -834,7 +835,7 @@ int main(int argc, char ** argv) { } } - int token_healing_n_removed = 0; + token_healing_n_removed = 0; if (n_past > 0 && is_interacting) { LOG("waiting for user input\n"); @@ -926,6 +927,7 @@ int main(int argc, char ** argv) { : std::min(sparams.token_healing_n_rollback, n_new_tokens); token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp, max_to_remove, &token_healing_n_removed); + n_bytes_to_skip = token_healing_prefix.size(); } for (size_t i = original_size; i < embd_inp.size(); ++i) { @@ -948,8 +950,7 @@ int main(int argc, char ** argv) { llama_sampling_reset(ctx_sampling); if (token_healing_n_removed > 0) { // Set new prefix after an interaction - ctx_sampling->token_healing_prefix = token_healing_prefix; - n_bytes_to_skip = ctx_sampling->token_healing_prefix.size(); + llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix); } } is_interacting = false;
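Taken together, the series reduces token healing to three touch points in a caller: trim the prompt with `llama_token_healing_prepare`, seed the sampling context with `llama_token_healing_set_prefix`, and let `llama_sampling_accept` shrink the prefix as constrained tokens are produced. The sketch below is not part of the patches; it assumes the common API as modified by this series (plus the usual `common.h` helpers such as `llama_init_from_gpt_params`, `llama_tokenize`, and `llama_batch_add`) and shows how a standalone program might wire the pieces together:

```cpp
#include "common.h"
#include "sampling.h"
#include "llama.h"

#include <cstdio>
#include <string>
#include <tuple>
#include <vector>

int main(int argc, char ** argv) {
    gpt_params params;
    params.model  = argc > 1 ? argv[1] : "model.gguf";
    params.prompt = argc > 2 ? argv[2] : "print('Hel";

    llama_sampling_params & sparams = params.sparams;
    sparams.token_healing_enabled = true;
    sparams.token_healing_type    = llama_token_healing_type::DYNAMIC_MULTI; // same as `-th d`

    llama_backend_init();
    llama_model   * model = nullptr;
    llama_context * ctx   = nullptr;
    std::tie(model, ctx) = llama_init_from_gpt_params(params);
    if (model == nullptr || ctx == nullptr) {
        fprintf(stderr, "failed to load the model\n");
        return 1;
    }

    // Trim the prompt's tail and remember the prefix that constrained decoding must regenerate.
    std::vector<llama_token> inp = ::llama_tokenize(ctx, params.prompt, true);
    int n_removed = 0;
    const std::string th_prefix = llama_token_healing_prepare(
        ctx, sparams.token_healing_type, inp, sparams.token_healing_n_rollback, &n_removed);
    if (n_removed > 0) {
        fprintf(stderr, "token_healing: prefix = '%s' (%d tokens removed)\n", th_prefix.c_str(), n_removed);
    }

    llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
    llama_token_healing_set_prefix(ctx_sampling, th_prefix);

    // Decode the (possibly shortened) prompt, requesting logits only for its last token.
    llama_batch batch = llama_batch_init(512, 0, 1);
    for (size_t i = 0; i < inp.size(); ++i) {
        llama_batch_add(batch, inp[i], i, { 0 }, i == inp.size() - 1);
    }
    llama_decode(ctx, batch);

    // While the healing prefix is non-empty, sampling is constrained to tokens that continue it;
    // once the prefix has been regenerated, sampling proceeds unconstrained.
    int n_cur = batch.n_tokens;
    for (int i = 0; i < 32; ++i) {
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr, batch.n_tokens - 1);
        llama_sampling_accept(ctx_sampling, ctx, id, true); // also shrinks the healing prefix
        if (llama_token_is_eog(model, id)) {
            break;
        }
        printf("%s", llama_token_to_piece(ctx, id).c_str());
        fflush(stdout);

        llama_batch_clear(batch);
        llama_batch_add(batch, id, n_cur++, { 0 }, true);
        llama_decode(ctx, batch);
    }
    printf("\n");

    llama_sampling_free(ctx_sampling);
    llama_batch_free(batch);
    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```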