Commit 559d5c2

lookup: Use tree of sequences instead of single sequence ggml-org#8648

Applies the lookup-tree PR by Johannes Gaessler, fixes a llama_batch that had not been renamed to common_batch, and updates lookup.cpp.

1 parent 6f95483 commit 559d5c2
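For orientation, here is a hedged sketch of how the drafting entry point changes for callers. The surrounding sampling loop and the lookup.cpp changes that consume these drafts are not part of this excerpt; id_last and the cache variables are illustrative names:

    // Before: one linear draft sequence, seeded with the last sampled token.
    std::vector<llama_token> draft = {id_last};
    common_ngram_cache_draft(inp, draft, n_draft, ngram_min, ngram_max,
                             nc_context, nc_dynamic, nc_static);

    // After: a tree of draft sequences. The caller still seeds a single
    // one-token sequence (the function asserts drafts.size() == 1 and
    // drafts[0].size() == 1), but on return `drafts` may hold several
    // sequences that share a common prefix, together containing up to
    // n_draft drafted tokens.
    std::vector<llama_draft_t> drafts = {{id_last}};
    common_ngram_cache_draft(inp, drafts, n_draft, ngram_min, ngram_max,
                             nc_context, nc_dynamic, nc_static);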

File tree

4 files changed: +314 -112 lines changed

common/ngram-cache.cpp

Lines changed: 174 additions & 63 deletions
@@ -55,52 +55,101 @@ static llama_token get_token(const std::vector<llama_token> & inp, const std::ve
     return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
 }
 
-// If sample size or percentage are below these thresholds the draft is aborted early:
-constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2,  2,  1,  1};
-constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX]     = {66, 50, 50, 50};
+// Sample size and percentage must meet these thresholds to be added to the draft tree:
+constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 1,  1,  1,  1};
+constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX]     = {20, 20, 10, 10};
 constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4,  3,  2,  2};
-constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX]     = {75, 66, 66, 66};
+constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX]     = {50, 50, 50, 50};
+
+struct draft_candidate {
+    llama_draft_t draft;
+    float nll;
+    int nsampled;
+};
+
+struct compare_draft_candidate {
+    bool operator()(const draft_candidate & a, const draft_candidate & b){
+        if (a.nsampled > b.nsampled) {
+            return true;
+        }
+        if (a.nsampled < b.nsampled) {
+            return false;
+        }
+        return a.nll < b.nll;
+    }
+};
+
+// Helper function that tries to draft tokens from only the static ngram cache:
+static void try_draft(
+    common_ngram_cache & nc_static, const common_ngram & ngram_static,
+    const int * min_sample_size, const int * min_percent, const draft_candidate & cp,
+    const int ngram_min, std::vector<draft_candidate> & drafts_new) {
+
+    const int nsc = (ngram_min + common_ngram_STATIC) - (cp.draft.size() - 1);
+    if (nsc < (ngram_min + common_ngram_STATIC + 1)/2) {
+        return;
+    }
 
-// Helper function that tries to draft a token from only the static ngram cache:
-static llama_token try_draft(common_ngram_cache & nc_static, const common_ngram ngram_static) {
     common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
     if (part_static_it == nc_static.end()) {
-        return -1;
+        return;
     }
     const common_ngram_cache_part part_static = part_static_it->second;
 
-    int max_count_static  = 0;
     int sum_count_static  = 0;
-    llama_token max_token = -1;
 
     for (std::pair<llama_token, int> token_count_static : part_static) {
-        const llama_token token = token_count_static.first;
         const int32_t count_static = token_count_static.second;
 
-        if (count_static > max_count_static) {
-            max_token        = token;
-            max_count_static = count_static;
-        }
         sum_count_static += count_static;
     }
 
-    if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) {
-        return -1;
-    }
-    if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
-        return -1;
+    for (std::pair<llama_token, int> token_count_static : part_static) {
+        const llama_token token = token_count_static.first;
+        const int32_t count_static = token_count_static.second;
+
+        if (sum_count_static < min_sample_size[common_ngram_STATIC-1]) {
+            continue;
+        }
+        if (100*count_static < min_percent[common_ngram_STATIC-1]*sum_count_static) {
+            continue;;
+        }
+
+        draft_candidate cc;
+        for (const llama_token & t : cp.draft) {
+            cc.draft.push_back(t);
+        }
+        cc.draft.push_back(token);
+        cc.nll = cp.nll - logf(1.0f*count_static/sum_count_static);
+        cc.nsampled = nsc;
+
+        bool duplicate = false;
+        for (const draft_candidate & co : drafts_new) {
+            if (co.draft == cc.draft) {
+                duplicate = true;
+                break;
+            }
+        }
+        if (duplicate) {
+            continue;
+        }
+
+        drafts_new.push_back(cc);
     }
-    return max_token;
 }
 
-// Try to draft a token from primary cache (context/dynamic), validate with static cache:
-static llama_token try_draft(
+// Try to draft tokens from primary cache (context/dynamic), validate with static cache:
+static void try_draft(
     common_ngram_cache & nc_primary, const std::vector<common_ngram> & ngrams_primary, common_ngram_cache_part & part_static,
-    const int * min_sample_size, const int * min_percent) {
+    const int * min_sample_size, const int * min_percent, const draft_candidate & cp,
+    const int ngram_min, std::vector<draft_candidate> & drafts_new) {
 
-    llama_token drafted_token = -1;
+    for (int i = ngrams_primary.size()-1; i >= 0; --i) {
+        const int nsc = (ngram_min + i) - (cp.draft.size() - 1);
+        if (nsc < (ngram_min + i + 1)/2) {
+            break;
+        }
 
-    for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == -1; --i) {
         const common_ngram ngram_primary = ngrams_primary[i];
 
         common_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
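A minimal, self-contained sketch of the candidate bookkeeping introduced in the hunk above, assuming llama_token is a 32-bit token id and using made-up counts and token ids: a follow-up token that passes the thresholds extends the parent draft, adds -log(count/sum) to the parent's cumulative negative log-likelihood, and records the nsc value derived from ngram_min and the current draft length.

    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    typedef int32_t llama_token;                // assumption: llama_token is a 32-bit id
    typedef std::vector<llama_token> llama_draft_t;

    struct draft_candidate {
        llama_draft_t draft;
        float nll;      // cumulative negative log-likelihood of the drafted tokens
        int nsampled;
    };

    int main() {
        // Hypothetical parent: seed token 11 plus one already drafted token 42.
        draft_candidate cp = {{11, 42}, 0.35f, 4};

        // Hypothetical follow-up token 7, seen 30 times out of 40 in the cache part.
        const int32_t count = 30, sum_count = 40;

        // Same arithmetic as in try_draft(): nsc stands in for
        // (ngram_min + i) - (cp.draft.size() - 1) with ngram_min + i == 4 here.
        const int nsc = 4 - ((int) cp.draft.size() - 1);

        draft_candidate cc;
        cc.draft = cp.draft;                    // copy the parent draft ...
        cc.draft.push_back(7);                  // ... and extend it by the new token
        cc.nll = cp.nll - (float) std::log(1.0f*count/sum_count);
        cc.nsampled = nsc;

        printf("child: %zu tokens, nll=%.3f, nsampled=%d\n", cc.draft.size(), cc.nll, cc.nsampled);
        return 0;
    }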
@@ -109,10 +158,8 @@ static llama_token try_draft(
         }
         const common_ngram_cache_part part_primary = part_primary_it->second;
 
-        int max_count_primary = 0;
-        int max_count_static  = 0;
         int sum_count_primary = 0;
-        llama_token max_token = -1;
+        int sum_count_prod    = 0;
 
         for (std::pair<llama_token, int> token_count_primary : part_primary) {
             const llama_token token = token_count_primary.first;
@@ -122,44 +169,100 @@ static llama_token try_draft(
             const int32_t count_primary = token_count_primary.second;
             const int32_t count_static  = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
 
-            if (count_primary*count_static > max_count_primary*max_count_static) {
-                max_token         = token;
-                max_count_primary = count_primary;
-                max_count_static  = count_static;
-            }
             sum_count_primary += count_primary;
+            sum_count_prod    += count_primary*count_static;
         }
 
-        if (sum_count_primary < min_sample_size[i]) {
-            continue;
-        }
-        if (100*max_count_primary < min_percent[i]*sum_count_primary) {
-            continue;;
+        for (std::pair<llama_token, int> token_count_primary : part_primary) {
+            const llama_token token = token_count_primary.first;
+
+            common_ngram_cache_part::iterator token_count_static_it = part_static.find(token);
+
+            const int32_t count_primary = token_count_primary.second;
+            const int32_t count_static  = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
+            const int32_t count_prod    = count_primary*count_static;
+
+            if (sum_count_primary < min_sample_size[i]) {
+                continue;
+            }
+
+            if (100*count_prod < min_percent[i]*sum_count_prod) {
+                continue;
+            }
+
+            draft_candidate cc;
+            for (const llama_token & t : cp.draft) {
+                cc.draft.push_back(t);
+            }
+            cc.draft.push_back(token);
+            cc.nll = cp.nll - logf(1.0f*count_prod/sum_count_prod);
+            cc.nsampled = nsc;
+
+            bool duplicate = false;
+            for (const draft_candidate & co : drafts_new) {
+                if (co.draft == cc.draft) {
+                    duplicate = true;
+                    break;
+                }
+            }
+            if (duplicate) {
+                continue;
+            }
+
+            drafts_new.push_back(cc);
         }
-        drafted_token = max_token;
     }
-
-    return drafted_token;
 }
 
 void common_ngram_cache_draft(
-    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
+    std::vector<llama_token> & inp, std::vector<std::vector<llama_token>> & drafts, int n_draft, int ngram_min, int ngram_max,
     common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static
 ) {
-    GGML_ASSERT(draft.size() == 1);
+    if (n_draft == 0) {
+        return;
+    }
+
+    GGML_ASSERT(drafts.size() == 1);
+    GGML_ASSERT(drafts[0].size() == 1);
     const int inp_size = inp.size();
 
-    if (inp_size < LLAMA_NGRAM_STATIC) {
+    if (inp_size < std::max(ngram_max, common_ngram_STATIC)) {
         return;
     }
 
-    while ((int) draft.size()-1 < n_draft) {
-        llama_token drafted_token = -1;
+    // While building the tree, store drafts with potential children in a heap:
+    std::vector<draft_candidate> drafts_wip;
+
+    {
+        draft_candidate candidate;
+        candidate.draft.push_back(drafts[0][0]);
+        candidate.nll = 0.0f;
+        candidate.nsampled = LLAMA_NGRAM_MAX;
+        drafts_wip.push_back(candidate);
+    }
+
+    drafts.clear();
+    int i_draft = 0;
+
+    // Temporarily hold new drafts in vector, only add part of them in the last iteration to exactly meet n_draft.
+    std::vector<draft_candidate> drafts_new;
 
-        const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
+    while (i_draft + ((int) drafts_new.size()) < n_draft && !(drafts_wip.empty() && drafts_new.empty())) {
+        for (const draft_candidate & ndc : drafts_new) {
+            drafts_wip.push_back(ndc);
+            std::push_heap(drafts_wip.begin(), drafts_wip.end(), compare_draft_candidate());
+            i_draft++;
+        }
+        drafts_new.clear();
+
+        std::pop_heap(drafts_wip.begin(), drafts_wip.end(), compare_draft_candidate());
+        const draft_candidate cp = drafts_wip.back(); // cp = candidate parent
+        drafts_wip.pop_back();
+
+        const int ngram_start_static = inp_size-common_ngram_STATIC + cp.draft.size()-1;
         common_ngram ngram_static;
-        for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
-            ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
+        for (int j = ngram_start_static; j < ngram_start_static + common_ngram_STATIC; ++j) {
+            ngram_static.tokens[j-ngram_start_static] = get_token(inp, cp.draft, j);
         }
         common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
         common_ngram_cache_part part_static;
@@ -170,29 +273,37 @@ void common_ngram_cache_draft(
         // cd = context + dynamic
         std::vector<common_ngram> ngrams_cd;
        for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
-            const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
+            const int ngram_start_cd = inp_size-ngram_size_cd + cp.draft.size()-1;
             common_ngram ngram_cd;
             for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
-                ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
+                ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, cp.draft, j);
             }
             ngrams_cd.push_back(ngram_cd);
         }
-        if (drafted_token == -1) {
-            drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax);
-        }
-        if (drafted_token == -1) {
-            drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict);
-        }
-        if (drafted_token == -1) {
-            drafted_token = try_draft(nc_static, ngram_static);
+
+        try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax, cp, ngram_min, drafts_new);
+        try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_lax, cp, ngram_min, drafts_new);
+        try_draft(nc_static, ngram_static, draft_min_sample_size_strict, draft_min_percent_strict, cp, ngram_min, drafts_new);
+
+        if (drafts_new.empty()) {
+            drafts.push_back(cp.draft);
+            i_draft++;
        }
+    }
 
-        if (drafted_token == -1) {
+    for (const draft_candidate & dc : drafts_wip) { // dc = draft child
+        drafts.push_back(dc.draft);
+    }
+
+    std::sort(drafts_new.begin(), drafts_new.end(), compare_draft_candidate());
+
+    for (const draft_candidate & dc : drafts_new) {
+        drafts.push_back(dc.draft);
+        i_draft++;
+
+        if (i_draft >= n_draft) {
             break;
         }
-
-        LOG(" - draft candidate: token=%d\n", drafted_token);
-        draft.push_back(drafted_token);
     }
 }
 
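A self-contained sketch of the ordering this comparator defines, which is what the final std::sort over drafts_new relies on when only part of the new candidates fits into n_draft. The struct and comparator are copied from the diff; llama_token is assumed to be a 32-bit id and the sample values are invented.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    typedef int32_t llama_token;                // assumption: llama_token is a 32-bit id
    typedef std::vector<llama_token> llama_draft_t;

    struct draft_candidate {
        llama_draft_t draft;
        float nll;
        int nsampled;
    };

    struct compare_draft_candidate {
        bool operator()(const draft_candidate & a, const draft_candidate & b){
            if (a.nsampled > b.nsampled) {
                return true;
            }
            if (a.nsampled < b.nsampled) {
                return false;
            }
            return a.nll < b.nll;
        }
    };

    int main() {
        // Made-up candidates; only the resulting order matters here.
        std::vector<draft_candidate> drafts_new = {
            {{11, 42},    0.9f, 3},
            {{11, 42, 7}, 0.2f, 2},
            {{11, 13},    0.1f, 3},
        };

        // Candidates with a higher nsampled come first; ties are broken by lower nll.
        std::sort(drafts_new.begin(), drafts_new.end(), compare_draft_candidate());

        for (const draft_candidate & dc : drafts_new) {
            printf("nsampled=%d nll=%.2f draft_len=%zu\n", dc.nsampled, dc.nll, dc.draft.size());
        }
        return 0;
    }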

common/ngram-cache.h

Lines changed: 7 additions & 6 deletions

@@ -13,22 +13,22 @@
 // Data structures to map n-grams to empirical token probabilities:
 
 struct common_ngram {
-    llama_token tokens[LLAMA_NGRAM_MAX];
+    llama_token tokens[common_ngram_MAX];
 
     common_ngram() {
-        for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
+        for (int i = 0; i < common_ngram_MAX; ++i) {
             tokens[i] = -1;
         }
     }
 
     common_ngram(const llama_token * input, const int ngram_size) {
-        for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
+        for (int i = 0; i < common_ngram_MAX; ++i) {
             tokens[i] = i < ngram_size ? input[i] : -1;
         }
     }
 
     bool operator==(const common_ngram & other) const {
-        for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
+        for (int i = 0; i < common_ngram_MAX; ++i) {
             if (tokens[i] != other.tokens[i]) {
                 return false;
             }
@@ -47,7 +47,7 @@ struct common_token_hash_function {
 struct common_ngram_hash_function {
     size_t operator()(const common_ngram & ngram) const {
         size_t hash = common_token_hash_function{}(ngram.tokens[0]);
-        for (int i = 1; i < LLAMA_NGRAM_MAX; ++i) {
+        for (int i = 1; i < common_ngram_MAX; ++i) {
             hash ^= common_token_hash_function{}(ngram.tokens[i]);
         }
         return hash;
@@ -60,6 +60,7 @@ typedef std::unordered_map<llama_token, int32_t> common_ngram_cache_part;
 // n-gram -> empirical distribution of following tokens
 typedef std::unordered_map<common_ngram, common_ngram_cache_part, common_ngram_hash_function> common_ngram_cache;
 
+typedef std::vector<llama_token> llama_draft_t;
 
 // Update an ngram cache with tokens.
 // ngram_cache: the cache to modify.
@@ -82,7 +83,7 @@ void common_ngram_cache_update(
 // nc_dynamic: ngram cache based on previous user generations.
 // nc_static: ngram cache generated from a large text corpus, used for validation.
 void common_ngram_cache_draft(
-    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
+    std::vector<llama_token> & inp, std::vector<llama_draft_t> & drafts, int n_draft, int ngram_min, int ngram_max,
     common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static);
 
 // Save an ngram cache to a file.
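To picture the "tree of sequences" from the commit title in terms of the new llama_draft_t typedef, here is a small standalone sketch of what a caller might receive; the token ids are invented and llama_token is again assumed to be a 32-bit id.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    typedef int32_t llama_token;                // assumption: llama_token is a 32-bit id
    typedef std::vector<llama_token> llama_draft_t;

    int main() {
        // After common_ngram_cache_draft() the caller may receive several draft
        // sequences instead of one; sequences can share a common prefix, which is
        // what makes the set a tree rather than a single linear draft.
        std::vector<llama_draft_t> drafts = {
            {11, 42, 7},        // seed token 11, then drafted tokens 42 and 7
            {11, 42, 9},        // shares the prefix 11, 42 with the first draft
            {11, 13},
        };

        for (size_t s = 0; s < drafts.size(); ++s) {
            printf("draft %zu:", s);
            for (llama_token t : drafts[s]) {
                printf(" %d", (int) t);
            }
            printf("\n");
        }
        return 0;
    }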
