@@ -55,52 +55,101 @@ static llama_token get_token(const std::vector<llama_token> & inp, const std::ve
55
55
return i < inp.size () ? inp[i] : draft[1 + i - inp.size ()];
56
56
}
57
57
58
// Threshold tables that are handed to try_draft() as min_sample_size / min_percent.
// The primary-cache overload reads entry i for the i-th n-gram of ngrams_primary; the static-cache
// overload reads entry common_ngram_STATIC-1.
// In common_ngram_cache_draft() the lax pair is passed for nc_context and the strict pair for nc_static.
// NOTE(review): the nc_dynamic call passes draft_min_sample_size_strict together with draft_min_percent_lax,
// so draft_min_percent_strict is only used for nc_static - confirm this mix is intended.
- // If sample size or percentage are below these thresholds the draft is aborted early :
59
- constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2 , 2 , 1 , 1 };
60
- constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX] = {66 , 50 , 50 , 50 };
58
+ // Sample size and percentage must meet these thresholds to be added to the draft tree :
59
+ constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 1 , 1 , 1 , 1 };
60
+ constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX] = {20 , 20 , 10 , 10 };
61
61
constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4 , 3 , 2 , 2 };
62
- constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {75 , 66 , 66 , 66 };
62
+ constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {50 , 50 , 50 , 50 };
63
+
64
// One node of the draft tree while the tree is being built.
//   draft    - token sequence of this node. The root holds only the seed token taken from drafts[0][0];
//              every child is its parent's sequence with one more token appended.
//   nll      - running sum of -logf(share) over the appended tokens, where share is the count-based
//              fraction computed in try_draft(). The root starts at 0.0f.
//   nsampled - the nsc value computed in try_draft() when the node was created
//              (the root is initialised with LLAMA_NGRAM_MAX).
+ struct draft_candidate {
65
+ llama_draft_t draft;
66
+ float nll;
67
+ int nsampled;
68
+ };
69
+
70
// Ordering predicate for draft_candidate: a is ordered before b when a has the larger nsampled and,
// for equal nsampled, when a has the smaller nll.
// It is passed to std::push_heap / std::pop_heap on drafts_wip and to std::sort on drafts_new.
// NOTE(review): the standard heap algorithms keep the element that orders LAST under the comparator at
// the front, so pop_heap would hand back the lowest-nsampled / highest-nll candidate first, while
// std::sort puts the highest-nsampled / lowest-nll candidate first - confirm the heap order is intended.
// NOTE(review): operator() is not const-qualified; consider adding const.
+ struct compare_draft_candidate {
71
+ bool operator ()(const draft_candidate & a, const draft_candidate & b){
72
+ if (a.nsampled > b.nsampled ) {
73
+ return true ;
74
+ }
75
+ if (a.nsampled < b.nsampled ) {
76
+ return false ;
77
+ }
78
+ return a.nll < b.nll ;
79
+ }
80
+ };
81
+
82
// Static-cache overload of try_draft().
// Looks up ngram_static in nc_static and, for every continuation token that passes the thresholds,
// appends a child of cp (cp.draft plus that token) to drafts_new.
//   nc_static       - cache that is searched.
//   ngram_static    - key to look up.
//   min_sample_size - threshold table; only entry common_ngram_STATIC-1 is read here.
//   min_percent     - threshold table; only entry common_ngram_STATIC-1 is read here.
//   cp              - parent candidate that is being extended.
//   ngram_min       - only used to derive nsc below.
//   drafts_new      - output list; a candidate whose token sequence is already in the list is not added again.
// Returns nothing. drafts_new is left untouched when nsc is too small or ngram_static is not in the cache.
// (Lines prefixed with '-' are the previous single-token implementation that this change removes.)
+ // Helper function that tries to draft tokens from only the static ngram cache:
83
+ static void try_draft (
84
+ common_ngram_cache & nc_static, const common_ngram & ngram_static,
85
+ const int * min_sample_size, const int * min_percent, const draft_candidate & cp,
86
+ const int ngram_min, std::vector<draft_candidate> & drafts_new) {
87
+
88
// nsc = (ngram_min + common_ngram_STATIC) minus the number of tokens in cp.draft after the first one.
// NOTE(review): cp.draft.size() is presumably unsigned, so the subtraction would be carried out in
// unsigned arithmetic before the conversion to int - verify.
+ const int nsc = (ngram_min + common_ngram_STATIC) - (cp.draft .size () - 1 );
89
// Give up once nsc is below half (rounded up) of ngram_min + common_ngram_STATIC.
+ if (nsc < (ngram_min + common_ngram_STATIC + 1 )/2 ) {
90
+ return ;
91
+ }
63
92
64
- // Helper function that tries to draft a token from only the static ngram cache:
65
- static llama_token try_draft (common_ngram_cache & nc_static, const common_ngram ngram_static) {
66
93
common_ngram_cache::iterator part_static_it = nc_static.find (ngram_static);
67
94
if (part_static_it == nc_static.end ()) {
68
- return - 1 ;
95
+ return ;
69
96
}
70
97
// Declared by value, so this is a copy of the cache entry.
const common_ngram_cache_part part_static = part_static_it->second ;
71
98
72
- int max_count_static = 0 ;
73
99
int sum_count_static = 0 ;
74
- llama_token max_token = -1 ;
75
100
76
101
// First pass: total count over all continuation tokens of this n-gram.
for (std::pair<llama_token, int > token_count_static : part_static) {
77
- const llama_token token = token_count_static.first ;
78
102
const int32_t count_static = token_count_static.second ;
79
103
80
- if (count_static > max_count_static) {
81
- max_token = token;
82
- max_count_static = count_static;
83
- }
84
104
sum_count_static += count_static;
85
105
}
86
106
87
- if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1 ]) {
88
- return -1 ;
89
- }
90
- if (100 *max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1 ]*sum_count_static) {
91
- return -1 ;
107
// Second pass: emit one candidate per token whose share of the total meets the percentage threshold.
+ for (std::pair<llama_token, int > token_count_static : part_static) {
108
+ const llama_token token = token_count_static.first ;
109
+ const int32_t count_static = token_count_static.second ;
110
+
111
// NOTE(review): this check does not depend on the loop variable; it rejects either every token or none.
+ if (sum_count_static < min_sample_size[common_ngram_STATIC-1 ]) {
112
+ continue ;
113
+ }
114
+ if (100 *count_static < min_percent[common_ngram_STATIC-1 ]*sum_count_static) {
115
// NOTE(review): the second ';' on the next line is a stray empty statement.
+ continue ;;
116
+ }
117
+
118
// Child candidate = parent's tokens followed by this token.
+ draft_candidate cc;
119
+ for (const llama_token & t : cp.draft ) {
120
+ cc.draft .push_back (t);
121
+ }
122
+ cc.draft .push_back (token);
123
// Add -log(count/sum) for this token to the parent's nll.
+ cc.nll = cp.nll - logf (1 .0f *count_static/sum_count_static);
124
+ cc.nsampled = nsc;
125
+
126
// Linear scan: skip the child if the same token sequence is already in drafts_new
// (drafts_new is shared with the other try_draft() calls made for the same parent).
+ bool duplicate = false ;
127
+ for (const draft_candidate & co : drafts_new) {
128
+ if (co.draft == cc.draft ) {
129
+ duplicate = true ;
130
+ break ;
131
+ }
132
+ }
133
+ if (duplicate) {
134
+ continue ;
135
+ }
136
+
137
+ drafts_new.push_back (cc);
92
138
}
93
- return max_token;
94
139
}
95
140
96
// Primary-cache overload of try_draft() (called for the context and the dynamic cache).
// Walks ngrams_primary from the last entry to the first and, for each n-gram handled, appends children
// of cp (cp.draft plus one token) to drafts_new.
// Tokens are weighted by count_primary*count_static, where count_static is 100 times the token's count in
// part_static, or 1 when part_static has no entry for the token.
//   nc_primary      - cache that is searched.
//   ngrams_primary  - keys to look up; assumes entry i is the n-gram of size ngram_min+i, as built by the caller.
//   part_static     - static-cache entry used for the weighting described above.
//   min_sample_size - threshold table indexed by the position i in ngrams_primary.
//   min_percent     - threshold table indexed by the position i in ngrams_primary.
//   cp              - parent candidate that is being extended.
//   ngram_min       - only used to derive nsc below.
//   drafts_new      - output list; a candidate whose token sequence is already in the list is not added again.
// Returns nothing.
// (Lines prefixed with '-' belong to the previous implementation; '@@' lines mark context that is
// elided from this view.)
- // Try to draft a token from primary cache (context/dynamic), validate with static cache:
97
- static llama_token try_draft (
141
+ // Try to draft tokens from primary cache (context/dynamic), validate with static cache:
142
+ static void try_draft (
98
143
common_ngram_cache & nc_primary, const std::vector<common_ngram> & ngrams_primary, common_ngram_cache_part & part_static,
99
- const int * min_sample_size, const int * min_percent) {
144
+ const int * min_sample_size, const int * min_percent, const draft_candidate & cp,
145
+ const int ngram_min, std::vector<draft_candidate> & drafts_new) {
100
146
101
- llama_token drafted_token = -1 ;
147
+ for (int i = ngrams_primary.size ()-1 ; i >= 0 ; --i) {
148
// nsc = (ngram_min + i) minus the number of tokens in cp.draft after the first one.
// NOTE(review): cp.draft.size() is presumably unsigned, so the subtraction would be carried out in
// unsigned arithmetic before the conversion to int - verify.
+ const int nsc = (ngram_min + i) - (cp.draft .size () - 1 );
149
// Stop trying smaller i once nsc is below half (rounded up) of ngram_min + i.
+ if (nsc < (ngram_min + i + 1 )/2 ) {
150
+ break ;
151
+ }
102
152
103
- for (int i = ngrams_primary.size ()-1 ; i >= 0 && drafted_token == -1 ; --i) {
104
153
const common_ngram ngram_primary = ngrams_primary[i];
105
154
106
155
common_ngram_cache::iterator part_primary_it = nc_primary.find (ngram_primary);
@@ -109,10 +158,8 @@ static llama_token try_draft(
109
158
}
110
159
// Declared by value, so this is a copy of the cache entry.
const common_ngram_cache_part part_primary = part_primary_it->second ;
111
160
112
- int max_count_primary = 0 ;
113
- int max_count_static = 0 ;
114
161
int sum_count_primary = 0 ;
115
- llama_token max_token = - 1 ;
162
+ int sum_count_prod = 0 ;
116
163
117
164
// First pass: accumulate the plain primary counts and the primary*static products.
for (std::pair<llama_token, int > token_count_primary : part_primary) {
118
165
const llama_token token = token_count_primary.first ;
@@ -122,44 +169,100 @@ static llama_token try_draft(
122
169
const int32_t count_primary = token_count_primary.second ;
123
170
const int32_t count_static = token_count_static_it != part_static.end () ? 100 *token_count_static_it->second : 1 ;
124
171
125
- if (count_primary*count_static > max_count_primary*max_count_static) {
126
- max_token = token;
127
- max_count_primary = count_primary;
128
- max_count_static = count_static;
129
- }
130
172
sum_count_primary += count_primary;
173
// NOTE(review): the product and its running sum are plain int; large counts could overflow - verify the expected ranges.
+ sum_count_prod += count_primary*count_static;
131
174
}
132
175
133
- if (sum_count_primary < min_sample_size[i]) {
134
- continue ;
135
- }
136
- if (100 *max_count_primary < min_percent[i]*sum_count_primary) {
137
- continue ;;
176
// Second pass: emit one candidate per token whose share of the product sum meets the percentage threshold.
+ for (std::pair<llama_token, int > token_count_primary : part_primary) {
177
+ const llama_token token = token_count_primary.first ;
178
+
179
+ common_ngram_cache_part::iterator token_count_static_it = part_static.find (token);
180
+
181
+ const int32_t count_primary = token_count_primary.second ;
182
+ const int32_t count_static = token_count_static_it != part_static.end () ? 100 *token_count_static_it->second : 1 ;
183
+ const int32_t count_prod = count_primary*count_static;
184
+
185
// The sample-size threshold is applied to the plain primary counts.
// NOTE(review): this check does not depend on the inner loop variable; it rejects either every token or none.
+ if (sum_count_primary < min_sample_size[i]) {
186
+ continue ;
187
+ }
188
+
189
// The percentage threshold is applied to the products, not to the primary counts alone.
+ if (100 *count_prod < min_percent[i]*sum_count_prod) {
190
+ continue ;
191
+ }
192
+
193
// Child candidate = parent's tokens followed by this token.
+ draft_candidate cc;
194
+ for (const llama_token & t : cp.draft ) {
195
+ cc.draft .push_back (t);
196
+ }
197
+ cc.draft .push_back (token);
198
// Add -log(count_prod/sum_count_prod) for this token to the parent's nll.
+ cc.nll = cp.nll - logf (1 .0f *count_prod/sum_count_prod);
199
+ cc.nsampled = nsc;
200
+
201
// Linear scan: skip the child if the same token sequence is already in drafts_new.
+ bool duplicate = false ;
202
+ for (const draft_candidate & co : drafts_new) {
203
+ if (co.draft == cc.draft ) {
204
+ duplicate = true ;
205
+ break ;
206
+ }
207
+ }
208
+ if (duplicate) {
209
+ continue ;
210
+ }
211
+
212
+ drafts_new.push_back (cc);
138
213
}
139
- drafted_token = max_token;
140
214
}
141
-
142
- return drafted_token;
143
215
}
144
216
145
217
// Builds a set of draft token sequences (a tree, stored as one sequence per node that is emitted)
// from the three n-gram caches.
//   inp        - token history; n-gram keys are read from inp followed by the tokens of the candidate being extended.
//   drafts     - in: exactly one sequence that holds exactly one token (both asserted);
//                out: cleared and refilled with the drafted sequences, each starting with that token.
//   n_draft    - draft budget; 0 returns immediately and leaves drafts untouched.
//   ngram_min, ngram_max - range of n-gram sizes that are looked up in nc_context and nc_dynamic.
//   nc_context, nc_dynamic, nc_static - the caches that are queried through try_draft().
// (Lines prefixed with '-' belong to the previous single-sequence implementation; '@@' lines mark
// context that is elided from this view.)
void common_ngram_cache_draft (
146
- std::vector<llama_token> & inp, std::vector<llama_token> & draft , int n_draft, int ngram_min, int ngram_max,
218
+ std::vector<llama_token> & inp, std::vector<std::vector< llama_token>> & drafts , int n_draft, int ngram_min, int ngram_max,
147
219
common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static
148
220
) {
149
- GGML_ASSERT (draft.size () == 1 );
221
+ if (n_draft == 0 ) {
222
+ return ;
223
+ }
224
+
225
+ GGML_ASSERT (drafts.size () == 1 );
226
+ GGML_ASSERT (drafts[0 ].size () == 1 );
150
227
const int inp_size = inp.size ();
151
228
152
- if (inp_size < LLAMA_NGRAM_STATIC ) {
229
// NOTE(review): this return happens before drafts.clear(), so drafts keeps its single seed sequence here.
+ if (inp_size < std::max (ngram_max, common_ngram_STATIC) ) {
153
230
return ;
154
231
}
155
232
156
- while ((int ) draft.size ()-1 < n_draft) {
157
- llama_token drafted_token = -1 ;
233
+ // While building the tree, store drafts with potential children in a heap:
234
+ std::vector<draft_candidate> drafts_wip;
235
+
236
// Root of the tree: only the seed token, nll 0.
+ {
237
+ draft_candidate candidate;
238
+ candidate.draft .push_back (drafts[0 ][0 ]);
239
+ candidate.nll = 0 .0f ;
240
// NOTE(review): this line uses LLAMA_NGRAM_MAX while neighbouring added lines use common_ngram_STATIC
// (the removed lines used LLAMA_NGRAM_STATIC) - verify that both constants exist under these names.
+ candidate.nsampled = LLAMA_NGRAM_MAX;
241
+ drafts_wip.push_back (candidate);
242
+ }
243
+
244
+ drafts.clear ();
245
+ int i_draft = 0 ;
246
+
247
+ // Temporarily hold new drafts in vector, only add part of them in the last iteration to exactly meet n_draft.
248
+ std::vector<draft_candidate> drafts_new;
158
249
159
- const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size ()-1 ;
250
// Each iteration: move the children of the previous iteration into the heap (counting each one in
// i_draft), pop one candidate cp, and try to extend it from all three caches.
// The loop ends when accepting the pending children would reach n_draft, or when nothing is left to expand.
+ while (i_draft + ((int ) drafts_new.size ()) < n_draft && !(drafts_wip.empty () && drafts_new.empty ())) {
251
+ for (const draft_candidate & ndc : drafts_new) {
252
+ drafts_wip.push_back (ndc);
253
+ std::push_heap (drafts_wip.begin (), drafts_wip.end (), compare_draft_candidate ());
254
+ i_draft++;
255
+ }
256
+ drafts_new.clear ();
257
+
258
// pop_heap moves the front element of the heap to the back, from where it is copied and removed.
+ std::pop_heap (drafts_wip.begin (), drafts_wip.end (), compare_draft_candidate ());
259
+ const draft_candidate cp = drafts_wip.back (); // cp = candidate parent
260
+ drafts_wip.pop_back ();
261
+
262
+ const int ngram_start_static = inp_size-common_ngram_STATIC + cp.draft .size ()-1 ;
160
263
common_ngram ngram_static;
161
- for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC ; ++j) {
162
- ngram_static.tokens [j-ngram_start_static] = get_token (inp, draft, j);
264
+ for (int j = ngram_start_static; j < ngram_start_static + common_ngram_STATIC ; ++j) {
265
+ ngram_static.tokens [j-ngram_start_static] = get_token (inp, cp. draft , j);
163
266
}
164
267
common_ngram_cache::iterator part_static_it = nc_static.find (ngram_static);
165
268
common_ngram_cache_part part_static;
@@ -170,29 +273,37 @@ void common_ngram_cache_draft(
170
273
// cd = context + dynamic
171
274
std::vector<common_ngram> ngrams_cd;
172
275
// One key per n-gram size from ngram_min to ngram_max, so ngrams_cd[k] has size ngram_min+k.
for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
173
- const int ngram_start_cd = inp_size-ngram_size_cd + draft.size ()-1 ;
276
+ const int ngram_start_cd = inp_size-ngram_size_cd + cp. draft .size ()-1 ;
174
277
common_ngram ngram_cd;
175
278
for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
176
- ngram_cd.tokens [j-ngram_start_cd] = get_token (inp, draft, j);
279
+ ngram_cd.tokens [j-ngram_start_cd] = get_token (inp, cp. draft , j);
177
280
}
178
281
ngrams_cd.push_back (ngram_cd);
179
282
}
180
- if (drafted_token == - 1 ) {
181
- drafted_token = try_draft (nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax);
182
- }
183
- if (drafted_token == - 1 ) {
184
- drafted_token = try_draft (nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict);
185
- }
186
- if (drafted_token == - 1 ) {
187
- drafted_token = try_draft (nc_static, ngram_static) ;
283
+
// All three caches are always queried; each call may append children of cp to drafts_new.
284
+ try_draft (nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax, cp, ngram_min, drafts_new );
285
// NOTE(review): the next call combines draft_min_sample_size_strict with draft_min_percent_lax, whereas the
// removed code passed draft_min_percent_strict for nc_dynamic - confirm this is intended.
+ try_draft (nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_lax, cp, ngram_min, drafts_new);
286
+ try_draft (nc_static, ngram_static, draft_min_sample_size_strict, draft_min_percent_strict, cp, ngram_min, drafts_new);
287
+
288
// A parent that produced no children is emitted as a finished sequence.
// NOTE(review): a candidate was already counted in i_draft when it entered drafts_wip, so a leaf
// that came from drafts_new appears to be counted twice - confirm.
+ if (drafts_new. empty ()) {
289
+ drafts. push_back (cp. draft );
290
+ i_draft++ ;
188
291
}
292
+ }
189
293
190
- if (drafted_token == -1 ) {
294
// Candidates still waiting in drafts_wip are emitted without being expanded, in the vector's current order.
+ for (const draft_candidate & dc : drafts_wip) { // dc = draft child
295
+ drafts.push_back (dc.draft );
296
+ }
297
+
298
// Children of the last iteration: ordered by compare_draft_candidate, then emitted until i_draft reaches n_draft.
+ std::sort (drafts_new.begin (), drafts_new.end (), compare_draft_candidate ());
299
+
300
+ for (const draft_candidate & dc : drafts_new) {
301
+ drafts.push_back (dc.draft );
302
+ i_draft++;
303
+
304
+ if (i_draft >= n_draft) {
191
305
break ;
192
306
}
193
-
194
- LOG (" - draft candidate: token=%d\n " , drafted_token);
195
- draft.push_back (drafted_token);
196
307
}
197
308
}
198
309
0 commit comments