Commit 6ffd4e9

server : pre-calculate EOG logit biases (ggml-org#14721)
ggml-ci
1 parent e4841d2 commit 6ffd4e9

3 files changed: +17 -15 lines changed

common/common.cpp

Lines changed: 12 additions & 6 deletions
@@ -1005,15 +1005,21 @@ struct common_init_result common_init_from_params(common_params & params) {
         params.sampling.ignore_eos = false;
     }
 
-    if (params.sampling.ignore_eos) {
-        for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-            if (llama_vocab_is_eog(vocab, i)) {
-                LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
-                params.sampling.logit_bias.push_back({i, -INFINITY});
-            }
+    // initialize once
+    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+        if (llama_vocab_is_eog(vocab, i)) {
+            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
+            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
         }
     }
 
+    if (params.sampling.ignore_eos) {
+        // add EOG biases to the active set of logit biases
+        params.sampling.logit_bias.insert(
+                params.sampling.logit_bias.end(),
+                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
+    }
+
     if (params.sampling.penalty_last_n == -1) {
         LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
         params.sampling.penalty_last_n = llama_n_ctx(lctx);
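Note: the hunk above moves the vocabulary scan out of the ignore_eos branch. The list of EOG (end-of-generation) tokens is now computed exactly once at initialization and stored in logit_bias_eog, and the active logit_bias vector only has that pre-built list appended when ignore_eos is actually set. A minimal sketch of the same pattern, using only the llama.h vocab helpers already visible in the diff (llama_vocab_n_tokens, llama_vocab_is_eog, llama_logit_bias); the make_eog_biases helper name is an illustration, not part of the commit:

#include <cmath>   // INFINITY
#include <vector>

#include "llama.h" // llama_vocab, llama_token, llama_logit_bias

// Scan the vocabulary once and collect a {token, -INFINITY} entry for every
// end-of-generation (EOG) token; done at init time, not per request.
static std::vector<llama_logit_bias> make_eog_biases(const llama_vocab * vocab) {
    std::vector<llama_logit_bias> eog;
    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
        if (llama_vocab_is_eog(vocab, i)) {
            eog.push_back({i, -INFINITY});
        }
    }
    return eog;
}

Whenever a caller later asks to ignore EOG tokens, the pre-built list is spliced into the active biases with a single vector::insert, as the common.cpp hunk above and the server.cpp hunk further down both do, instead of re-walking the whole vocabulary per request.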

common/common.h

Lines changed: 2 additions & 1 deletion
@@ -177,7 +177,8 @@ struct common_params_sampling {
     std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
     std::set<llama_token> preserved_tokens;
 
-    std::vector<llama_logit_bias> logit_bias; // logit biases to apply
+    std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
+    std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
 
     // print the parameters into a string
     std::string print() const;

tools/server/server.cpp

Lines changed: 3 additions & 8 deletions
@@ -473,12 +473,9 @@ struct server_task {
 
         params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
         if (params.sampling.ignore_eos) {
-            for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-                if (llama_vocab_is_eog(vocab, i)) {
-                    //SRV_DBG("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(ctx, i).c_str(), -INFINITY);
-                    params.sampling.logit_bias.push_back({i, -INFINITY});
-                }
-            }
+            params.sampling.logit_bias.insert(
+                    params.sampling.logit_bias.end(),
+                    defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
         }
     }
 
@@ -1906,7 +1903,6 @@ struct server_context {
 
     bool clean_kv_cache = true;
     bool add_bos_token = true;
-    bool has_eos_token = false;
 
     int32_t n_ctx; // total context for all clients / slots
 
@@ -1965,7 +1961,6 @@ struct server_context {
         n_ctx = llama_n_ctx(ctx);
 
         add_bos_token = llama_vocab_get_add_bos(vocab);
-        has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
 
         if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) {
             SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
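For context on the -INFINITY value used throughout this commit: a logit bias is added to the corresponding token's raw logit before sampling, so an entry of {token, -INFINITY} drives that token's probability to zero and the sampler can never emit it; that is how ignore_eos keeps generation from stopping at EOG tokens. The snippet below is a self-contained illustration of that effect only, not the server's actual sampling path; the logit_bias struct and greedy_pick helper are stand-ins defined just for this example:

#include <cmath>
#include <cstdio>
#include <vector>

// Stand-in for llama_logit_bias: a token id and the bias added to its logit.
struct logit_bias { int token; float bias; };

// Apply the biases, then pick the highest-logit token (greedy sampling for brevity).
static int greedy_pick(std::vector<float> logits, const std::vector<logit_bias> & biases) {
    for (const auto & b : biases) {
        logits[b.token] += b.bias; // adding -INFINITY forbids the token outright
    }
    int best = 0;
    for (int i = 1; i < (int) logits.size(); i++) {
        if (logits[i] > logits[best]) {
            best = i;
        }
    }
    return best;
}

int main() {
    std::vector<float> logits = {0.1f, 2.5f, 1.0f};         // token 1 would normally win
    std::vector<logit_bias> eog = {{1, -INFINITY}};         // pretend token 1 is an EOG token
    printf("picked token %d\n", greedy_pick(logits, eog));  // prints "picked token 2"
    return 0;
}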
