Skip to content

lookup: Use tree of sequences instead of single sequence #8648

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 173 additions & 62 deletions common/ngram-cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,52 +52,101 @@ static llama_token get_token(const std::vector<llama_token> & inp, const std::ve
return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
}

// If sample size or percentage are below these thresholds the draft is aborted early:
constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2, 2, 1, 1};
constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX] = {66, 50, 50, 50};
// Sample size and percentage must meet these thresholds to be added to the draft tree:
constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 1, 1, 1, 1};
constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX] = {20, 20, 10, 10};
constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4, 3, 2, 2};
constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {75, 66, 66, 66};
constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {50, 50, 50, 50};

// A candidate sequence of drafted tokens, scored so candidates can be ordered
// while building the draft tree.
struct draft_candidate {
llama_draft_t draft; // the drafted token sequence
float nll; // accumulated negative log-likelihood of the drafted tokens (lower = more probable)
int nsampled; // effective sample size backing this draft — presumably larger means more reliable; confirm against how nsc is computed in try_draft
};

// Strict weak ordering for draft candidates: a candidate with more samples
// orders first; among equal sample counts, the lower (better) nll orders first.
struct compare_draft_candidate {
bool operator()(const draft_candidate & a, const draft_candidate & b){
// Primary key: sample count (descending).
if (a.nsampled != b.nsampled) {
return a.nsampled > b.nsampled;
}
// Tie-break: negative log-likelihood (ascending).
return a.nll < b.nll;
}
};

// Helper function that tries to draft tokens from only the static ngram cache:
static void try_draft(
llama_ngram_cache & nc_static, const llama_ngram & ngram_static,
const int * min_sample_size, const int * min_percent, const draft_candidate & cp,
const int ngram_min, std::vector<draft_candidate> & drafts_new) {

const int nsc = (ngram_min + LLAMA_NGRAM_STATIC) - (cp.draft.size() - 1);
if (nsc < (ngram_min + LLAMA_NGRAM_STATIC + 1)/2) {
return;
}

// Helper function that tries to draft a token from only the static ngram cache:
static llama_token try_draft(llama_ngram_cache & nc_static, const llama_ngram ngram_static) {
llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
if (part_static_it == nc_static.end()) {
return -1;
return;
}
const llama_ngram_cache_part part_static = part_static_it->second;

int max_count_static = 0;
int sum_count_static = 0;
llama_token max_token = -1;

for (std::pair<llama_token, int> token_count_static : part_static) {
const llama_token token = token_count_static.first;
const int32_t count_static = token_count_static.second;

if (count_static > max_count_static) {
max_token = token;
max_count_static = count_static;
}
sum_count_static += count_static;
}

if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) {
return -1;
}
if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
return -1;
for (std::pair<llama_token, int> token_count_static : part_static) {
const llama_token token = token_count_static.first;
const int32_t count_static = token_count_static.second;

if (sum_count_static < min_sample_size[LLAMA_NGRAM_STATIC-1]) {
continue;
}
if (100*count_static < min_percent[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
continue;;
}

draft_candidate cc;
for (const llama_token & t : cp.draft) {
cc.draft.push_back(t);
}
cc.draft.push_back(token);
cc.nll = cp.nll - logf(1.0f*count_static/sum_count_static);
cc.nsampled = nsc;

bool duplicate = false;
for (const draft_candidate & co : drafts_new) {
if (co.draft == cc.draft) {
duplicate = true;
break;
}
}
if (duplicate) {
continue;
}

drafts_new.push_back(cc);
}
return max_token;
}

// Try to draft a token from primary cache (context/dynamic), validate with static cache:
static llama_token try_draft(
// Try to draft tokens from primary cache (context/dynamic), validate with static cache:
static void try_draft(
llama_ngram_cache & nc_primary, const std::vector<llama_ngram> & ngrams_primary, llama_ngram_cache_part & part_static,
const int * min_sample_size, const int * min_percent) {
const int * min_sample_size, const int * min_percent, const draft_candidate & cp,
const int ngram_min, std::vector<draft_candidate> & drafts_new) {

llama_token drafted_token = -1;
for (int i = ngrams_primary.size()-1; i >= 0; --i) {
const int nsc = (ngram_min + i) - (cp.draft.size() - 1);
if (nsc < (ngram_min + i + 1)/2) {
break;
}

for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == -1; --i) {
const llama_ngram ngram_primary = ngrams_primary[i];

llama_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
Expand All @@ -106,10 +155,8 @@ static llama_token try_draft(
}
const llama_ngram_cache_part part_primary = part_primary_it->second;

int max_count_primary = 0;
int max_count_static = 0;
int sum_count_primary = 0;
llama_token max_token = -1;
int sum_count_prod = 0;

for (std::pair<llama_token, int> token_count_primary : part_primary) {
const llama_token token = token_count_primary.first;
Expand All @@ -119,44 +166,100 @@ static llama_token try_draft(
const int32_t count_primary = token_count_primary.second;
const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;

if (count_primary*count_static > max_count_primary*max_count_static) {
max_token = token;
max_count_primary = count_primary;
max_count_static = count_static;
}
sum_count_primary += count_primary;
sum_count_prod += count_primary*count_static;
}

if (sum_count_primary < min_sample_size[i]) {
continue;
}
if (100*max_count_primary < min_percent[i]*sum_count_primary) {
continue;;
for (std::pair<llama_token, int> token_count_primary : part_primary) {
const llama_token token = token_count_primary.first;

llama_ngram_cache_part::iterator token_count_static_it = part_static.find(token);

const int32_t count_primary = token_count_primary.second;
const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
const int32_t count_prod = count_primary*count_static;

if (sum_count_primary < min_sample_size[i]) {
continue;
}

if (100*count_prod < min_percent[i]*sum_count_prod) {
continue;
}

draft_candidate cc;
for (const llama_token & t : cp.draft) {
cc.draft.push_back(t);
}
cc.draft.push_back(token);
cc.nll = cp.nll - logf(1.0f*count_prod/sum_count_prod);
cc.nsampled = nsc;

bool duplicate = false;
for (const draft_candidate & co : drafts_new) {
if (co.draft == cc.draft) {
duplicate = true;
break;
}
}
if (duplicate) {
continue;
}

drafts_new.push_back(cc);
}
drafted_token = max_token;
}

return drafted_token;
}

void llama_ngram_cache_draft(
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
std::vector<llama_token> & inp, std::vector<std::vector<llama_token>> & drafts, int n_draft, int ngram_min, int ngram_max,
llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static
) {
GGML_ASSERT(draft.size() == 1);
if (n_draft == 0) {
return;
}

GGML_ASSERT(drafts.size() == 1);
GGML_ASSERT(drafts[0].size() == 1);
const int inp_size = inp.size();

if (inp_size < LLAMA_NGRAM_STATIC) {
if (inp_size < std::max(ngram_max, LLAMA_NGRAM_STATIC)) {
return;
}

while ((int) draft.size()-1 < n_draft) {
llama_token drafted_token = -1;
// While building the tree, store drafts with potential children in a heap:
std::vector<draft_candidate> drafts_wip;

{
draft_candidate candidate;
candidate.draft.push_back(drafts[0][0]);
candidate.nll = 0.0f;
candidate.nsampled = LLAMA_NGRAM_MAX;
drafts_wip.push_back(candidate);
}

drafts.clear();
int i_draft = 0;

// Temporarily hold new drafts in vector, only add part of them in the last iteration to exactly meet n_draft.
std::vector<draft_candidate> drafts_new;

const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
while (i_draft + ((int) drafts_new.size()) < n_draft && !(drafts_wip.empty() && drafts_new.empty())) {
for (const draft_candidate & ndc : drafts_new) {
drafts_wip.push_back(ndc);
std::push_heap(drafts_wip.begin(), drafts_wip.end(), compare_draft_candidate());
i_draft++;
}
drafts_new.clear();

std::pop_heap(drafts_wip.begin(), drafts_wip.end(), compare_draft_candidate());
const draft_candidate cp = drafts_wip.back(); // cp = candidate parent
drafts_wip.pop_back();

const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + cp.draft.size()-1;
llama_ngram ngram_static;
for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
ngram_static.tokens[j-ngram_start_static] = get_token(inp, cp.draft, j);
}
llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
llama_ngram_cache_part part_static;
Expand All @@ -167,29 +270,37 @@ void llama_ngram_cache_draft(
// cd = context + dynamic
std::vector<llama_ngram> ngrams_cd;
for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
const int ngram_start_cd = inp_size-ngram_size_cd + cp.draft.size()-1;
llama_ngram ngram_cd;
for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, cp.draft, j);
}
ngrams_cd.push_back(ngram_cd);
}
if (drafted_token == -1) {
drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax);
}
if (drafted_token == -1) {
drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict);
}
if (drafted_token == -1) {
drafted_token = try_draft(nc_static, ngram_static);

try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax, cp, ngram_min, drafts_new);
try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_lax, cp, ngram_min, drafts_new);
try_draft(nc_static, ngram_static, draft_min_sample_size_strict, draft_min_percent_strict, cp, ngram_min, drafts_new);

if (drafts_new.empty()) {
drafts.push_back(cp.draft);
i_draft++;
}
}

if (drafted_token == -1) {
for (const draft_candidate & dc : drafts_wip) { // dc = draft child
drafts.push_back(dc.draft);
}

std::sort(drafts_new.begin(), drafts_new.end(), compare_draft_candidate());

for (const draft_candidate & dc : drafts_new) {
drafts.push_back(dc.draft);
i_draft++;

if (i_draft >= n_draft) {
break;
}

LOG(" - draft candidate: token=%d\n", drafted_token);
draft.push_back(drafted_token);
}
}

Expand Down
3 changes: 2 additions & 1 deletion common/ngram-cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ typedef std::unordered_map<llama_token, int32_t> llama_ngram_cache_part;
// n-gram -> empirical distribution of following tokens
typedef std::unordered_map<llama_ngram, llama_ngram_cache_part, llama_ngram_hash_function> llama_ngram_cache;

typedef std::vector<llama_token> llama_draft_t;

// Update an ngram cache with tokens.
// ngram_cache: the cache to modify.
Expand All @@ -82,7 +83,7 @@ void llama_ngram_cache_update(
// nc_dynamic: ngram cache based on previous user generations.
// nc_static: ngram cache generated from a large text corpus, used for validation.
void llama_ngram_cache_draft(
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
std::vector<llama_token> & inp, std::vector<llama_draft_t> & drafts, int n_draft, int ngram_min, int ngram_max,
llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static);

// Save an ngram cache to a file.
Expand Down
37 changes: 29 additions & 8 deletions examples/lookup/lookup-stats.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <set>
#include <string>
#include <vector>
#include <unordered_map>
Expand Down Expand Up @@ -80,22 +81,42 @@ int main(int argc, char ** argv){

while ((int) pseudo_output.size() < n_ctx) {
// Simulate drafting and decoding from draft:
std::vector<llama_token> draft;
draft.push_back(pseudo_output.back());
std::vector<llama_draft_t> drafts;
llama_draft_t draft0;
draft0.push_back(pseudo_output.back());
drafts.push_back(draft0);

{
const int64_t t_start_draft_us = ggml_time_us();
llama_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
llama_ngram_cache_draft(
pseudo_output, drafts, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
t_draft_us += ggml_time_us() - t_start_draft_us;
}
GGML_ASSERT((int) drafts.size() <= n_draft || n_draft <= 0);

n_drafted += draft.size() - 1;
// FIXME wrong KV mask for converging sequences (does not seem to happen in practice).
for (int j = 1; j < n_draft + 1; ++j) {
std::set<llama_token> seen_tokens;

for (size_t j = 1; j < draft.size() && (int) pseudo_output.size() < n_ctx; ++j) {
for (const llama_draft_t & draft : drafts) {
if (j < (int) draft.size() && seen_tokens.find(draft[j]) == seen_tokens.end()) {
seen_tokens.emplace(draft[j]);
n_drafted++;
}
}
}

for (int j = 1; j < n_draft + 1 && (int) pseudo_output.size() < n_ctx; ++j) {
const llama_token ground_truth = inp_slice[pseudo_output.size()];
const llama_token drafted = draft[j];

if (ground_truth != drafted) {
bool ground_truth_in_drafts = false;
for (const llama_draft_t & draft : drafts) {
if (j < (int) draft.size() && draft[j] == ground_truth) {
ground_truth_in_drafts = true;
break;
}
}
if (!ground_truth_in_drafts) {
break;
}

Expand All @@ -119,7 +140,7 @@ int main(int argc, char ** argv){
}
}

draft.erase(draft.begin());
drafts.clear();

}
if (i_start > 0 && i_start / 100000 != (i_start - n_ctx) / 100000) {
Expand Down
Loading
Loading