Commit 0cffe93

logit_bias: apply configurable escalating EOG bias at low n_remain
Give EOG an increasing bias (growing with length, per token; could be per codepoint in the future), applied only after a configured amount has been generated.

Add an `n_remain` parameter to `sample_apply`, which is safer than having logit_bias maintain state for how many times it has been called (that would lead to wrong assumptions, e.g. when it is called multiple times per token).

New command line options (including a requested 'after' form in addition to 'remain'):

  -eog,    --eog-bias-per-tok N      when fewer than --start-eog-at-remain tokens are left to generate of the -n budget, add this bias to EOG for each subsequent token (default: 0.0)
  -remain, --start-eog-at-remain N   start applying the -eog bias when this many tokens remain of the -n budget (default: 0.0)
  -after,  --start-eog-after N       start applying the -eog bias after this many tokens have been generated (default: 1000000000.0); whichever of -remain and -after is reached first applies

Verified that the EOG bias is effective at avoiding overgeneration and is a reasonable supplement or alternative to editing the prompt; a *constant* EOG bias, already supported by the samplers, is likely to allow pathologically short outputs.
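
As a reference for the escalation rule described above, a minimal sketch (not the committed sampler code; the helper name is illustrative):

    // escalating EOG bias as a function of the remaining budget (sketch)
    // start_eog_at_remain is assumed to have already been raised to
    // max(start_eog_at_remain, n_predict - start_eog_after), as done in common_params_parse
    static float eog_bias(float n_remain, float start_eog_at_remain, float eog_bias_per_tok) {
        const float tokens_past_threshold = start_eog_at_remain - n_remain;
        return tokens_past_threshold > 0.0f ? tokens_past_threshold * eog_bias_per_tok : 0.0f;
    }
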
1 parent e434e69 commit 0cffe93

File tree

19 files changed (+171, -72 lines)


common/arg.cpp

Lines changed: 30 additions & 0 deletions
@@ -1205,6 +1205,15 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
         exit(1); // for other exceptions, we exit with status code 1
     }

+    float & pafter  = params.sampling.start_eog_after;
+    float & premain = params.sampling.start_eog_at_remain;
+    float const premain0 = premain;
+    float remain = params.n_predict - pafter;
+    if (premain < remain)
+        premain = remain;
+    if (params.sampling.eog_bias_per_tok)
+        LOG_INF("%s: n_predict=%d (first of start_eog_at_remain=%0.3g start_eog_after=%0.3g) => (remain=%0.3g) eog-bias-per-tok=%0.3g\n", __func__, (int) params.n_predict,
+                (double) premain0, (double) pafter, (double) premain, (double) params.sampling.eog_bias_per_tok);
     return true;
 }

@@ -1937,6 +1946,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"-eog", "--eog-bias-per-tok"}, "N",
+        string_format("when fewer than -start-eog-at-remain tokens are left to generate after -n, add this bias eog for each subsequent token (default: %.1f)", (double)params.sampling.eog_bias_per_tok),
+        [](common_params & params, const std::string & value) {
+            params.sampling.eog_bias_per_tok = std::stof(value);
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"-remain", "--start-eog-at-remain"}, "N",
+        string_format("start applying -eog bias when this many tokens remain of the -n max (default: %.1f)", (double)params.sampling.start_eog_at_remain),
+        [](common_params & params, const std::string & value) {
+            params.sampling.start_eog_at_remain = std::stof(value);
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"-after", "--start-eog-after"}, "N",
+        string_format("start applying -eog bias after this many tokens generated (default: %.1f); whichever happens first between -remain and -after applies", (double)params.sampling.start_eog_after),
+        [](common_params & params, const std::string & value) {
+            params.sampling.start_eog_after = std::stof(value);
+        }
+    ).set_sparam());
     add_opt(common_arg(
         {"--grammar"}, "GRAMMAR",
         string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()),

common/common.h

Lines changed: 7 additions & 0 deletions
@@ -178,6 +178,13 @@ struct common_params_sampling {

     std::vector<llama_logit_bias> logit_bias; // logit biases to apply

+    float eog_bias_per_tok    = 0; // escalating bias added to EOG per token, starting after:
+    /// this many remaining tokens (before applying eog_bias_per_tok) ...
+    float start_eog_at_remain = 0;
+    // or (whichever comes first) after start_eog_after tokens have been generated
+    /// (i.e. EOG logit bias = max(0, start_eog_at_remain - n_remain) * eog_bias_per_tok, where start_eog_at_remain has been raised to max(start_eog_at_remain, n_predict - start_eog_after))
+    float start_eog_after = 1e9;
+
     // print the parameters into a string
     std::string print() const;
 };
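
Spelling out the comment above with illustrative numbers (not from this commit): with n_predict = 512, start_eog_at_remain = 64 and eog_bias_per_tok = 0.25, the EOG bias is 0 while n_remain >= 64, then 0.25 at n_remain = 63, 0.5 at 62, and so on, reaching 15.75 when a single token remains.
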

common/sampling.cpp

Lines changed: 15 additions & 12 deletions
@@ -226,7 +226,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             llama_sampler_init_logit_bias(
                 llama_vocab_n_tokens(vocab),
                 params.logit_bias.size(),
-                params.logit_bias.data()));
+                params.logit_bias.data(),
+                params.eog_bias_per_tok,
+                params.start_eog_at_remain,
+                vocab));

     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {

@@ -335,18 +338,18 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
     }
 }

-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first, float n_remain) {
     gsmpl->set_logits(ctx, idx);

     auto & grmr  = gsmpl->grmr;
     auto & chain = gsmpl->chain;
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits

     if (grammar_first) {
-        llama_sampler_apply(grmr, &cur_p);
+        llama_sampler_apply(grmr, &cur_p, n_remain);
     }

-    llama_sampler_apply(chain, &cur_p);
+    llama_sampler_apply(chain, &cur_p, n_remain);

     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

@@ -361,7 +364,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
         llama_token_data single_token_data = { id, 1.0f, 0.0f };
         llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };

-        llama_sampler_apply(grmr, &single_token_data_array);
+        llama_sampler_apply(grmr, &single_token_data_array, n_remain);

         const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
         if (is_valid) {

@@ -373,23 +376,23 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
     gsmpl->set_logits(ctx, idx);

-    llama_sampler_apply(grmr, &cur_p);
-    llama_sampler_apply(chain, &cur_p);
+    llama_sampler_apply(grmr, &cur_p, n_remain);
+    llama_sampler_apply(chain, &cur_p, n_remain);

     GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");

     return cur_p.data[cur_p.selected].id;
 }

-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first, float n_remain) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

     std::vector<llama_token> result;
     result.reserve(idxs.size());

     size_t i = 0;
     for (; i < draft.size(); i++) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first, n_remain);

         common_sampler_accept(gsmpl, id, true);

@@ -401,7 +404,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     }

     if (i == draft.size()) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first, n_remain);

         common_sampler_accept(gsmpl, id, true);

@@ -411,13 +414,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     return result;
 }

-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first, float n_remain) {
     std::vector<int> idxs(draft.size() + 1);
     for (size_t i = 0; i < idxs.size(); ++i) {
         idxs[i] = i;
     }

-    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first, n_remain);
 }

 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
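
For callers, the pattern is to thread the remaining budget into every sample call. A minimal caller-side sketch (assuming an existing gsmpl, ctx, vocab and an n_predict budget; the real call sites are in the example programs below):

    int n_remain = n_predict;
    while (n_remain > 0) {
        --n_remain;
        // n_remain lets the logit-bias sampler escalate the EOG bias near the end of the budget
        const llama_token id = common_sampler_sample(gsmpl, ctx, /*idx =*/ 0, /*grammar_first =*/ false, n_remain);
        common_sampler_accept(gsmpl, id, /*accept_grammar =*/ true);
        if (llama_vocab_is_eog(vocab, id)) {
            break;
        }
        // ... decode the accepted token and continue ...
    }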

common/sampling.h

Lines changed: 3 additions & 3 deletions
@@ -58,7 +58,7 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
 // if grammar_first is true, the grammar is applied before the samplers (slower)
 // useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
 //
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false, float n_remain = 0);

 // generalized version of common_sampler_sample
 //

@@ -76,10 +76,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 //
 // returns at least 1 token, up to idxs.size()
 //
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false, float n_remain = 0);

 // assume idxs == [ 0, 1, 2, ..., draft.size() ]
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false, float n_remain = 0);

 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);

common/speculative.cpp

Lines changed: 2 additions & 2 deletions
@@ -238,12 +238,12 @@ llama_tokens common_speculative_gen_draft(
     llama_decode(ctx, batch);

     common_sampler_reset(smpl);
-
+    int n_remain = params.n_draft;
     // sample n_draft tokens from the draft model
     for (int i = 0; i < params.n_draft; ++i) {
         common_batch_clear(batch);

-        common_sampler_sample(smpl, ctx, 0, true);
+        common_sampler_sample(smpl, ctx, 0, true, --n_remain);

         const auto * cur_p = common_sampler_get_candidates(smpl);


examples/batched/batched.cpp

Lines changed: 3 additions & 1 deletion
@@ -162,7 +162,9 @@ int main(int argc, char ** argv) {

     const auto t_main_start = ggml_time_us();

+    int n_remain = n_predict;
     while (n_cur <= n_predict) {
+        --n_remain;
         // prepare the next batch
         common_batch_clear(batch);

@@ -173,7 +175,7 @@ int main(int argc, char ** argv) {
             continue;
         }

-        const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
+        const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i], n_remain);

         // is it an end of generation? -> mark the stream as finished
         if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {

examples/gritlm/gritlm.cpp

Lines changed: 2 additions & 2 deletions
@@ -108,7 +108,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std

     std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
     int32_t i_current_token = 0;
-
+    int n_remain = 32;
     while (true) {
         common_batch_clear(bat);
         {

@@ -122,7 +122,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std

         llama_decode(ctx, bat);

-        llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
+        llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1, --n_remain);

         if (token == eos_token) {
             break;

examples/lookahead/lookahead.cpp

Lines changed: 5 additions & 2 deletions
@@ -253,6 +253,7 @@ int main(int argc, char ** argv) {

         int seq_id_best = 0;

+        int n_remain = N;
         for (int v = 0; v < N; ++v) {
             int i_batch = 0;

@@ -274,8 +275,9 @@ int main(int argc, char ** argv) {
                 }
             }

+            --n_remain;
             // sample the next token
-            id = common_sampler_sample(smpl, ctx, i_batch);
+            id = common_sampler_sample(smpl, ctx, i_batch, n_remain);

             common_sampler_accept(smpl, id, true);

@@ -349,10 +351,11 @@ int main(int argc, char ** argv) {
                 tokens_j[j] = tokens_j[j + 1];
             }

+            unsigned constexpr NA = (unsigned)-1;
             if (v == 0) {
                 // sample from the last level
                 for (int i = 0; i < W; i++) {
-                    tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
+                    tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i, NA);
                 }
             } else {
                 for (int i = 0; i < W; i++) {

examples/lookup/lookup.cpp

Lines changed: 2 additions & 1 deletion
@@ -117,7 +117,8 @@ int main(int argc, char ** argv){
         int i_dft = 0;
         while (true) {
             // sample from the target model
-            llama_token id = common_sampler_sample(smpl, ctx, i_dft);
+            unsigned const n_remain = params.n_predict - n_predict;
+            llama_token id = common_sampler_sample(smpl, ctx, i_dft, n_remain);

             common_sampler_accept(smpl, id, true);


examples/passkey/passkey.cpp

Lines changed: 3 additions & 1 deletion
@@ -217,10 +217,12 @@ int main(int argc, char ** argv) {

     const auto t_main_start = ggml_time_us();

+    int n_remain = n_len - n_cur;
     while (n_cur <= n_len) {
+        --n_remain;
         // sample the next token
         {
-            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
+            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1, n_remain);

             // is it an end of generation?
             if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
