From 8c9784c65ddabc8043ccba249c17a83c6a29c334 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Sat, 20 Jul 2024 13:38:59 +0200
Subject: [PATCH] lookup: single sequence -> tree of sequences

---
 common/ngram-cache.cpp           | 235 +++++++++++++++++++++++--------
 common/ngram-cache.h             |   3 +-
 examples/lookup/lookup-stats.cpp |  37 +++--
 examples/lookup/lookup.cpp       | 139 +++++++++++++-----
 4 files changed, 308 insertions(+), 106 deletions(-)

diff --git a/common/ngram-cache.cpp b/common/ngram-cache.cpp
index 3ca112ef1613d..81f8ff755ccc3 100644
--- a/common/ngram-cache.cpp
+++ b/common/ngram-cache.cpp
@@ -52,52 +52,101 @@ static llama_token get_token(const std::vector<llama_token> & inp, const std::ve
     return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
 }
 
-// If sample size or percentage are below these thresholds the draft is aborted early:
-constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2,  2,  1,  1};
-constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX]     = {66, 50, 50, 50};
+// Sample size and percentage must meet these thresholds to be added to the draft tree:
+constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 1,  1,  1,  1};
+constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX]     = {20, 20, 10, 10};
 
 constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4,  3,  2,  2};
-constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX]     = {75, 66, 66, 66};
+constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX]     = {50, 50, 50, 50};
+
+struct draft_candidate {
+    llama_draft_t draft;
+    float nll;
+    int nsampled;
+};
+
+struct compare_draft_candidate {
+    bool operator()(const draft_candidate & a, const draft_candidate & b){
+        if (a.nsampled > b.nsampled) {
+            return true;
+        }
+        if (a.nsampled < b.nsampled) {
+            return false;
+        }
+        return a.nll < b.nll;
+    }
+};
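+
+// Ordering used for draft candidates: a candidate compares ahead of another if its
+// latest token was drafted from an n-gram containing more sampled (non-drafted)
+// context tokens (nsampled), with ties broken by a lower cumulative negative
+// log-likelihood (nll). llama_ngram_cache_draft uses this ordering both for its heap
+// of work-in-progress candidates and for sorting the leftover candidates at the end.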
+
+// Helper function that tries to draft tokens from only the static ngram cache:
+static void try_draft(
+    llama_ngram_cache & nc_static, const llama_ngram & ngram_static,
+    const int * min_sample_size, const int * min_percent, const draft_candidate & cp,
+    const int ngram_min, std::vector<draft_candidate> & drafts_new) {
+
+    const int nsc = (ngram_min + LLAMA_NGRAM_STATIC) - (cp.draft.size() - 1);
+    if (nsc < (ngram_min + LLAMA_NGRAM_STATIC + 1)/2) {
+        return;
+    }
 
-// Helper function that tries to draft a token from only the static ngram cache:
-static llama_token try_draft(llama_ngram_cache & nc_static, const llama_ngram ngram_static) {
     llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
     if (part_static_it == nc_static.end()) {
-        return -1;
+        return;
     }
     const llama_ngram_cache_part part_static = part_static_it->second;
 
-    int max_count_static  = 0;
     int sum_count_static  = 0;
-    llama_token max_token = -1;
 
     for (std::pair<llama_token, int32_t> token_count_static : part_static) {
-        const llama_token token = token_count_static.first;
         const int32_t count_static = token_count_static.second;
 
-        if (count_static > max_count_static) {
-            max_token        = token;
-            max_count_static = count_static;
-        }
         sum_count_static += count_static;
     }
 
-    if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) {
-        return -1;
-    }
-    if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
-        return -1;
+    for (std::pair<llama_token, int32_t> token_count_static : part_static) {
+        const llama_token token    = token_count_static.first;
+        const int32_t count_static = token_count_static.second;
+
+        if (sum_count_static < min_sample_size[LLAMA_NGRAM_STATIC-1]) {
+            continue;
+        }
+        if (100*count_static < min_percent[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
+            continue;
+        }
+
+        draft_candidate cc;
+        for (const llama_token & t : cp.draft) {
+            cc.draft.push_back(t);
+        }
+        cc.draft.push_back(token);
+        cc.nll      = cp.nll - logf(1.0f*count_static/sum_count_static);
+        cc.nsampled = nsc;
+
+        bool duplicate = false;
+        for (const draft_candidate & co : drafts_new) {
+            if (co.draft == cc.draft) {
+                duplicate = true;
+                break;
+            }
+        }
+        if (duplicate) {
+            continue;
+        }
+
+        drafts_new.push_back(cc);
     }
-
-    return max_token;
 }
 
-// Try to draft a token from primary cache (context/dynamic), validate with static cache:
-static llama_token try_draft(
+// Try to draft tokens from primary cache (context/dynamic), validate with static cache:
+static void try_draft(
     llama_ngram_cache & nc_primary, const std::vector<llama_ngram> & ngrams_primary, llama_ngram_cache_part & part_static,
-    const int * min_sample_size, const int * min_percent) {
+    const int * min_sample_size, const int * min_percent, const draft_candidate & cp,
+    const int ngram_min, std::vector<draft_candidate> & drafts_new) {
 
-    llama_token drafted_token = -1;
+    for (int i = ngrams_primary.size()-1; i >= 0; --i) {
+        const int nsc = (ngram_min + i) - (cp.draft.size() - 1);
+        if (nsc < (ngram_min + i + 1)/2) {
+            break;
+        }
 
-    for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == -1; --i) {
         const llama_ngram ngram_primary = ngrams_primary[i];
 
         llama_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
@@ -106,10 +155,8 @@ static llama_token try_draft(
         }
         const llama_ngram_cache_part part_primary = part_primary_it->second;
 
-        int max_count_primary = 0;
-        int max_count_static  = 0;
         int sum_count_primary = 0;
-        llama_token max_token = -1;
+        int sum_count_prod    = 0;
 
         for (std::pair<llama_token, int32_t> token_count_primary : part_primary) {
             const llama_token token = token_count_primary.first;
@@ -119,44 +166,100 @@ static llama_token try_draft(
             const int32_t count_primary = token_count_primary.second;
             const int32_t count_static  = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
 
-            if (count_primary*count_static > max_count_primary*max_count_static) {
-                max_token         = token;
-                max_count_primary = count_primary;
-                max_count_static  = count_static;
-            }
             sum_count_primary += count_primary;
+            sum_count_prod    += count_primary*count_static;
         }
 
-        if (sum_count_primary < min_sample_size[i]) {
-            continue;
-        }
-        if (100*max_count_primary < min_percent[i]*sum_count_primary) {
-            continue;;
+        for (std::pair<llama_token, int32_t> token_count_primary : part_primary) {
+            const llama_token token = token_count_primary.first;
+
+            llama_ngram_cache_part::iterator token_count_static_it = part_static.find(token);
+
+            const int32_t count_primary = token_count_primary.second;
+            const int32_t count_static  = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
+            const int32_t count_prod    = count_primary*count_static;
+
+            if (sum_count_primary < min_sample_size[i]) {
+                continue;
+            }
+
+            if (100*count_prod < min_percent[i]*sum_count_prod) {
+                continue;
+            }
+
+            draft_candidate cc;
+            for (const llama_token & t : cp.draft) {
+                cc.draft.push_back(t);
+            }
+            cc.draft.push_back(token);
+            cc.nll      = cp.nll - logf(1.0f*count_prod/sum_count_prod);
+            cc.nsampled = nsc;
+
+            bool duplicate = false;
+            for (const draft_candidate & co : drafts_new) {
+                if (co.draft == cc.draft) {
+                    duplicate = true;
+                    break;
+                }
+            }
+            if (duplicate) {
+                continue;
+            }
+
+            drafts_new.push_back(cc);
         }
-        drafted_token = max_token;
     }
-
-    return drafted_token;
 }
 
 void llama_ngram_cache_draft(
-    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
+    std::vector<llama_token> & inp, std::vector<std::vector<llama_token>> & drafts, int n_draft, int ngram_min, int ngram_max,
     llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static
 ) {
-    GGML_ASSERT(draft.size() == 1);
+    if (n_draft == 0) {
+        return;
+    }
+
+    GGML_ASSERT(drafts.size() == 1);
+    GGML_ASSERT(drafts[0].size() == 1);
 
     const int inp_size = inp.size();
 
-    if (inp_size < LLAMA_NGRAM_STATIC) {
+    if (inp_size < std::max(ngram_max, LLAMA_NGRAM_STATIC)) {
         return;
     }
 
-    while ((int) draft.size()-1 < n_draft) {
-        llama_token drafted_token = -1;
+    // While building the tree, store drafts with potential children in a heap:
+    std::vector<draft_candidate> drafts_wip;
+
+    {
+        draft_candidate candidate;
+        candidate.draft.push_back(drafts[0][0]);
+        candidate.nll = 0.0f;
+        candidate.nsampled = LLAMA_NGRAM_MAX;
+        drafts_wip.push_back(candidate);
+    }
+
+    drafts.clear();
+    int i_draft = 0;
+
+    // Temporarily hold new drafts in vector, only add part of them in the last iteration to exactly meet n_draft.
+    std::vector<draft_candidate> drafts_new;
 
-        const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
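+    // Tree construction: each iteration takes one work-in-progress candidate from the
+    // heap and tries to extend it by one token using the context, dynamic and static
+    // caches; the extensions are collected in drafts_new. A candidate that cannot be
+    // extended any further is finished and moved to the output drafts.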
+    while (i_draft + ((int) drafts_new.size()) < n_draft && !(drafts_wip.empty() && drafts_new.empty())) {
+        for (const draft_candidate & ndc : drafts_new) {
+            drafts_wip.push_back(ndc);
+            std::push_heap(drafts_wip.begin(), drafts_wip.end(), compare_draft_candidate());
+            i_draft++;
+        }
+        drafts_new.clear();
+
+        std::pop_heap(drafts_wip.begin(), drafts_wip.end(), compare_draft_candidate());
+        const draft_candidate cp = drafts_wip.back(); // cp = candidate parent
+        drafts_wip.pop_back();
+
+        const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + cp.draft.size()-1;
         llama_ngram ngram_static;
         for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
-            ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
+            ngram_static.tokens[j-ngram_start_static] = get_token(inp, cp.draft, j);
         }
         llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
         llama_ngram_cache_part part_static;
@@ -167,29 +270,37 @@ void llama_ngram_cache_draft(
 
         // cd = context + dynamic
         std::vector<llama_ngram> ngrams_cd;
         for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
-            const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
+            const int ngram_start_cd = inp_size-ngram_size_cd + cp.draft.size()-1;
             llama_ngram ngram_cd;
             for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
-                ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
+                ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, cp.draft, j);
             }
             ngrams_cd.push_back(ngram_cd);
         }
-        if (drafted_token == -1) {
-            drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax);
-        }
-        if (drafted_token == -1) {
-            drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict);
-        }
-        if (drafted_token == -1) {
-            drafted_token = try_draft(nc_static, ngram_static);
+
+        try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax, cp, ngram_min, drafts_new);
+        try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_lax, cp, ngram_min, drafts_new);
+        try_draft(nc_static, ngram_static, draft_min_sample_size_strict, draft_min_percent_strict, cp, ngram_min, drafts_new);
+
+        if (drafts_new.empty()) {
+            drafts.push_back(cp.draft);
+            i_draft++;
         }
+    }
 
-        if (drafted_token == -1) {
+    for (const draft_candidate & dc : drafts_wip) { // dc = draft child
+        drafts.push_back(dc.draft);
+    }
+
+    std::sort(drafts_new.begin(), drafts_new.end(), compare_draft_candidate());
+
+    for (const draft_candidate & dc : drafts_new) {
+        drafts.push_back(dc.draft);
+        i_draft++;
+
+        if (i_draft >= n_draft) {
             break;
         }
-
-        LOG(" - draft candidate: token=%d\n", drafted_token);
-        draft.push_back(drafted_token);
     }
 }
diff --git a/common/ngram-cache.h b/common/ngram-cache.h
index ab4c9b3766546..430020754491a 100644
--- a/common/ngram-cache.h
+++ b/common/ngram-cache.h
@@ -60,6 +60,7 @@ typedef std::unordered_map<llama_token, int32_t> llama_ngram_cache_part;
 
 // n-gram -> empirical distribution of following tokens
 typedef std::unordered_map<llama_ngram, llama_ngram_cache_part, llama_ngram_hash_function> llama_ngram_cache;
+typedef std::vector<llama_token> llama_draft_t;
 
 // Update an ngram cache with tokens.
 // ngram_cache: the cache to modify.
@@ -82,7 +83,7 @@ void llama_ngram_cache_update(
 // nc_dynamic: ngram cache based on previous user generations.
 // nc_static: ngram cache generated from a large text corpus, used for validation.
 void llama_ngram_cache_draft(
-    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
+    std::vector<llama_token> & inp, std::vector<llama_draft_t> & drafts, int n_draft, int ngram_min, int ngram_max,
     llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static);
 
 // Save an ngram cache to a file.
diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp
index 2fe67100e6c03..c51b2b2f2104f 100644
--- a/examples/lookup/lookup-stats.cpp
+++ b/examples/lookup/lookup-stats.cpp
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include <set>
 #include
 #include
 #include
@@ -80,22 +81,42 @@ int main(int argc, char ** argv){
 
         while ((int) pseudo_output.size() < n_ctx) {
             // Simulate drafting and decoding from draft:
-            std::vector<llama_token> draft;
-            draft.push_back(pseudo_output.back());
+            std::vector<llama_draft_t> drafts;
+            llama_draft_t draft0;
+            draft0.push_back(pseudo_output.back());
+            drafts.push_back(draft0);
 
             {
                 const int64_t t_start_draft_us = ggml_time_us();
-                llama_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
+                llama_ngram_cache_draft(
+                    pseudo_output, drafts, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
                 t_draft_us += ggml_time_us() - t_start_draft_us;
             }
+            GGML_ASSERT((int) drafts.size() <= n_draft || n_draft <= 0);
 
-            n_drafted += draft.size() - 1;
+            // FIXME wrong KV mask for converging sequences (does not seem to happen in practice).
+            for (int j = 1; j < n_draft + 1; ++j) {
+                std::set<llama_token> seen_tokens;
 
-            for (size_t j = 1; j < draft.size() && (int) pseudo_output.size() < n_ctx; ++j) {
+                for (const llama_draft_t & draft : drafts) {
+                    if (j < (int) draft.size() && seen_tokens.find(draft[j]) == seen_tokens.end()) {
+                        seen_tokens.emplace(draft[j]);
+                        n_drafted++;
+                    }
+                }
+            }
+
+            for (int j = 1; j < n_draft + 1 && (int) pseudo_output.size() < n_ctx; ++j) {
                 const llama_token ground_truth = inp_slice[pseudo_output.size()];
-                const llama_token drafted = draft[j];
 
-                if (ground_truth != drafted) {
+                bool ground_truth_in_drafts = false;
+                for (const llama_draft_t & draft : drafts) {
+                    if (j < (int) draft.size() && draft[j] == ground_truth) {
+                        ground_truth_in_drafts = true;
+                        break;
+                    }
+                }
+                if (!ground_truth_in_drafts) {
                     break;
                 }
 
@@ -119,7 +140,7 @@ int main(int argc, char ** argv){
                 }
             }
 
-            draft.erase(draft.begin());
+            drafts.clear();
         }
 
         if (i_start > 0 && i_start / 100000 != (i_start - n_ctx) / 100000) {
diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp
index bb571bac4d778..54daf8983982e 100644
--- a/examples/lookup/lookup.cpp
+++ b/examples/lookup/lookup.cpp
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include <set>
 #include
 #include
 #include
@@ -21,6 +22,7 @@ int main(int argc, char ** argv){
     // max. number of additional tokens to draft if match is found
     const int n_draft = params.n_draft;
+    const int n_seq   = std::max(n_draft, 1);
 
     const bool dump_kv_cache = params.dump_kv_cache;
 
@@ -108,9 +110,12 @@ int main(int argc, char ** argv){
 
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
 
-    std::vector<llama_token> draft;
+    std::vector<llama_draft_t> drafts;
 
-    llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
+    llama_batch batch_tgt = llama_batch_init(max_context_size, 0, n_seq);
+    std::vector<std::vector<int>> sampling_idx_store;
+    sampling_idx_store.resize(n_seq);
+    sampling_idx_store[0].push_back(0);
 
     // debug
     struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);
@@ -124,13 +129,11 @@ int main(int argc, char ** argv){
             llama_kv_cache_dump_view_seqs(kvc_view, 40);
         }
 
-        // print current draft sequence
-        LOG("drafted %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str());
-
         int i_dft = 0;
+        int seq_best = 0;
         while (true) {
             // sample from the target model
-            llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
+            llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, sampling_idx_store[seq_best][i_dft]);
 
             llama_sampling_accept(ctx_sampling, ctx, id, true);
 
@@ -147,24 +150,32 @@ int main(int argc, char ** argv){
             ++n_predict;
 
             // check if the target token matches the draft
-            if (i_dft < (int) draft.size() && id == draft[i_dft]) {
-                LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str());
-                ++n_accept;
-                ++n_past;
-                ++i_dft;
-                inp.push_back(id);
-                {
-                    // Update context ngram cache with the newly accepted token:
-                    const int64_t t_start_draft_us = ggml_time_us();
-                    llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false);
-                    t_draft_us += ggml_time_us() - t_start_draft_us;
-                }
-
-                if (params.use_color) {
-                    // color accepted draft token
-                    printf("\033[34m%s\033[0m", token_str.c_str());
-                    fflush(stdout);
+            bool accepted = false;
+            for (int j = 0; j < (int) drafts.size() && !has_eos && !drafts.empty(); ++j) {
+                if (i_dft + 1 < (int) drafts[j].size() && id == drafts[j][i_dft + 1]) {
+                    LOG("draft success: (%d, '%s'), seq_id=%d\n", id, token_str.c_str(), j);
+                    ++n_accept;
+                    ++n_past;
+                    ++i_dft;
+                    inp.push_back(id);
+                    seq_best = j;
+                    {
+                        // Update context ngram cache with the newly accepted token:
+                        const int64_t t_start_draft_us = ggml_time_us();
+                        llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false);
+                        t_draft_us += ggml_time_us() - t_start_draft_us;
+                    }
+
+                    if (params.use_color) {
+                        // color accepted draft token
+                        printf("\033[34m%s\033[0m", token_str.c_str());
+                        fflush(stdout);
+                    }
+                    accepted = true;
+                    break;
                 }
+            }
+            if (accepted) {
                 continue;
             }
 
@@ -174,10 +185,10 @@ int main(int argc, char ** argv){
 
             fflush(stdout);
 
-            LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
+            LOG("sampled: (%d, '%s')\n", id, token_str.c_str());
 
-            draft.clear();
-            draft.push_back(id);
+            drafts.clear();
+            drafts.push_back({id});
             inp.push_back(id);
             {
                 // Update context ngram cache with the newly accepted token:
@@ -194,29 +205,87 @@ int main(int argc, char ** argv){
 
         // KV cache management
        // clean the cache of draft tokens that weren't accepted
+        if (seq_best != 0 && i_dft > 0) {
+            llama_kv_cache_seq_cp(ctx, seq_best, 0, n_past-i_dft, n_past);
+        }
+
+        llama_kv_cache_seq_keep(ctx, 0);
         llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
 
         llama_batch_clear(batch_tgt);
-        llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
+        for (int j = 0; j < n_seq; ++j) {
+            sampling_idx_store[j].clear();
+        }
 
         // Draft already contains a single token sampled from the model:
-        GGML_ASSERT(draft.size() == 1);
-        GGML_ASSERT(draft[0] == inp.back());
+        GGML_ASSERT(drafts.size() == 1);
+        GGML_ASSERT(drafts[0].size() == 1);
+        GGML_ASSERT(drafts[0][0] == inp.back());
 
         const int64_t t_start_draft_us = ggml_time_us();
 
-        llama_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
+        llama_ngram_cache_draft(inp, drafts, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
+
+        for (int j = 1; j < (int) drafts.size(); ++j) {
+            llama_kv_cache_seq_cp(ctx, 0, j, -1, -1);
+        }
+
+        int draft_max = 0;
+        for (const llama_draft_t & draft : drafts) {
+            draft_max = std::max(draft_max, (int) draft.size());
+        }
+
+        if (draft_max > 1) {
+            LOG("drafts:\n");
+            for (const llama_draft_t & draft : drafts) {
+                LOG(" - %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str());
+            }
+        }
+
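+        // Batch the draft tree for the target model: at every draft position, a token
+        // that is shared by several sequences is added to the batch only once, with all
+        // of the sequence ids that contain it; each sequence records the batch index of
+        // its token so that the matching logits can be sampled later.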
+        // FIXME wrong KV mask for converging sequences (does not seem to happen in practice).
+        for (int i = 0; i < draft_max; ++i) {
+            std::set<llama_token> seen_tokens;
+
+            while (true) {
+                llama_token current_token = -1;
+                std::vector<llama_seq_id> current_seq_ids;
+
+                for (int j = 0; j < (int) drafts.size(); ++j) {
+                    if (i >= (int) drafts[j].size()) {
+                        continue;
+                    }
+
+                    if (current_token == -1) {
+                        if (seen_tokens.find(drafts[j][i]) != seen_tokens.end()) {
+                            continue;
+                        }
 
-        for (size_t i = 1; i < draft.size(); ++i) {
-            llama_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
+                        current_token = drafts[j][i];
+                        seen_tokens.emplace(current_token);
+                    }
+
+                    if (drafts[j][i] != current_token) {
+                        continue;
+                    }
+
+                    current_seq_ids.push_back(j);
+                }
+
+                if (current_token == -1) {
+                    break;
+                }
+
+                for (const llama_seq_id & sid : current_seq_ids) {
+                    sampling_idx_store[sid].push_back(batch_tgt.n_tokens);
+                }
+                llama_batch_add(batch_tgt, current_token, n_past + i, current_seq_ids, true);
+                n_drafted++;
+            }
         }
+        n_drafted--; // one of the added tokens came from sampling, not drafting
 
         t_draft_us += ggml_time_us() - t_start_draft_us;
-        n_drafted += draft.size() - 1;
 
         llama_decode(ctx, batch_tgt);
         ++n_past;
-
-        draft.erase(draft.begin());
     }
 
     auto t_dec_end = ggml_time_us();