Commit d5eea13

server : add token healing support

1 parent: fc8773d

File tree

2 files changed: +72 / -7 lines

- examples/server/README.md
- examples/server/server.cpp

examples/server/README.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -436,6 +436,8 @@ node index.js
 
 `json_schema`: Set a JSON schema for grammar-based sampling (e.g. `{"items": {"type": "string"}, "minItems": 10, "maxItems": 100}` of a list of strings, or `{}` for any JSON). See [tests](../../tests/test-json-schema-to-grammar.cpp) for supported features. Default: no JSON schema.
 
+`token_healing`: Set token healing strategy. Default: `0`, which is disabled.
+
 `seed`: Set the random number generator (RNG) seed. Default: `-1`, which is a random seed.
 
 `ignore_eos`: Ignore end of stream token and continue generating. Default: `false`
```
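
A rough illustration of how a client might set the new field: the sketch below builds a `/completion` request body with nlohmann/json (the JSON library the server already uses). It is not part of this commit; the accepted values in the comment are taken from the parsing code added to server.cpp below, and `prompt`/`n_predict` are existing request fields used here only for context.

```cpp
// Illustrative sketch only (not part of this commit).
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    // Request body for POST /completion. "token_healing" selects a strategy:
    //   "0"  -> healing disabled (the default)
    //   "1"  -> ROLLBACK_LAST
    //   "d1" -> DYNAMIC_ONCE
    //   "d"  -> DYNAMIC_MULTI
    //   "rN" -> ROLLBACK_MULTI, rolling back N tokens (e.g. "r3")
    nlohmann::json body = {
        {"prompt",        "The quick brown fo"},
        {"n_predict",     16},
        {"token_healing", "d1"}
    };
    std::cout << body.dump(2) << std::endl; // payload to send to the server
    return 0;
}
```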

examples/server/server.cpp

Lines changed: 70 additions & 7 deletions

```diff
@@ -185,6 +185,7 @@ struct server_slot {
     // stats
     size_t n_sent_text = 0; // number of sent text character
     size_t n_sent_token_probs = 0;
+    size_t n_th_prefix = 0; // size of remaining token healing prefix
 
     int64_t t_start_process_prompt;
     int64_t t_start_generation;
@@ -206,6 +207,7 @@ struct server_slot {
         infill = false;
         ga_i = 0;
         n_past_se = 0;
+        n_th_prefix = 0;
 
         generated_token_probs.clear();
     }
```
```diff
@@ -1094,6 +1096,36 @@ struct server_context {
             }
         }
 
+        {
+            const auto & token_healing_str = data.find("token_healing");
+            auto & th_enabled = slot.sparams.token_healing_enabled;
+            th_enabled = default_sparams.token_healing_enabled;
+            if (token_healing_str != data.end() && token_healing_str->is_string()) {
+                const auto value = token_healing_str->get<std::string>();
+                auto & th_type       = slot.sparams.token_healing_type;
+                auto & th_n_rollback = slot.sparams.token_healing_n_rollback;
+                th_enabled = true;
+                /**/ if (value == "0" ) { th_enabled = false; }
+                else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
+                else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
+                else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
+                else if (value[0] == 'r' ) {
+                    th_type = llama_token_healing_type::ROLLBACK_MULTI;
+                    th_n_rollback = std::stoi(value.substr(1));
+                    if (th_n_rollback <= 0) {
+                        th_enabled = false;
+                    }
+                } else { th_enabled = false; }
+
+                LOG_VERBOSE("token healing", {
+                    {"id_slot",    slot.id},
+                    {"enabled",    th_enabled},
+                    {"type",       th_type},
+                    {"n_rollback", th_n_rollback}
+                });
+            }
+        }
+
         {
             if (slot.ctx_sampling != nullptr) {
                 llama_sampling_free(slot.ctx_sampling);
```
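
For reference, the value-to-strategy mapping above can be exercised in isolation. The sketch below is a standalone approximation, not the server code: `healing_type` and `healing_params` are local stand-ins for `llama_token_healing_type` and the fields added to the sampling params, and only the string handling is reproduced.

```cpp
// Standalone sketch mirroring the "token_healing" value parsing above.
#include <cassert>
#include <string>

enum class healing_type { ROLLBACK_LAST, ROLLBACK_MULTI, DYNAMIC_ONCE, DYNAMIC_MULTI };

struct healing_params {
    bool         enabled    = false;
    healing_type type       = healing_type::DYNAMIC_MULTI; // arbitrary default for the sketch
    int          n_rollback = -1;
};

static healing_params parse_token_healing(const std::string & value) {
    healing_params p;
    p.enabled = true;
    /**/ if (value == "0" ) { p.enabled = false; }
    else if (value == "1" ) { p.type = healing_type::ROLLBACK_LAST; }
    else if (value == "d1") { p.type = healing_type::DYNAMIC_ONCE; }
    else if (value == "d" ) { p.type = healing_type::DYNAMIC_MULTI; }
    else if (!value.empty() && value[0] == 'r') {
        p.type = healing_type::ROLLBACK_MULTI;
        p.n_rollback = std::stoi(value.substr(1)); // throws on non-numeric input, as above
        if (p.n_rollback <= 0) { p.enabled = false; }
    } else { p.enabled = false; }
    return p;
}

int main() {
    assert(!parse_token_healing("0").enabled);                              // explicitly disabled
    assert( parse_token_healing("1").type  == healing_type::ROLLBACK_LAST);
    assert( parse_token_healing("d1").type == healing_type::DYNAMIC_ONCE);
    assert( parse_token_healing("r3").n_rollback == 3);                     // roll back 3 tokens
    assert(!parse_token_healing("r0").enabled);                             // non-positive count disables
    assert(!parse_token_healing("banana").enabled);                         // unknown values disable
    return 0;
}
```
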
```diff
@@ -1189,14 +1221,26 @@ struct server_context {
     }
 
     bool process_token(completion_token_output & result, server_slot & slot) {
-        // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
         slot.sampled = result.tok;
-
-        // search stop word and delete it
-        slot.generated_text += token_str;
         slot.has_next_token = true;
 
+        // Suppress generating the token healing prefix to not repeat the input prompt's suffix
+        bool is_token_healing = false;
+        if (slot.n_th_prefix > 0) {
+            if (slot.n_th_prefix < token_str.size()) {
+                slot.generated_text += token_str.substr(slot.n_th_prefix);
+                slot.n_th_prefix = 0;
+                is_token_healing = false; // to send partial token text when streaming
+            } else {
+                slot.n_th_prefix -= token_str.size();
+                is_token_healing = true;
+            }
+        } else {
+            slot.generated_text += token_str;
+        }
+
+        // remember which tokens were sampled - used for repetition penalties during sampling
         if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
             // we can change penalty_prompt_tokens because it is always created from scratch each request
             slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
```
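
To make the suppression arithmetic concrete, here is a minimal sketch of just that branch, assuming for illustration that healing rolled back the 3-byte prompt suffix " fo" (so `n_th_prefix == 3`). It is not the server code; `slot_state` and `consume_piece` are local stand-ins for the slot fields and the corresponding part of `process_token()`.

```cpp
// Minimal sketch of the token-healing prefix suppression in process_token():
// the first n_th_prefix bytes of generated text duplicate the rolled-back
// prompt suffix, so they are skipped instead of being sent again.
#include <cassert>
#include <string>

struct slot_state {
    size_t      n_th_prefix = 0; // bytes of healing prefix still to consume
    std::string generated_text;
};

// Returns true while the sampled piece is still entirely inside the prefix.
static bool consume_piece(slot_state & slot, const std::string & token_str) {
    bool is_token_healing = false;
    if (slot.n_th_prefix > 0) {
        if (slot.n_th_prefix < token_str.size()) {
            slot.generated_text += token_str.substr(slot.n_th_prefix);
            slot.n_th_prefix = 0;
        } else {
            slot.n_th_prefix -= token_str.size();
            is_token_healing = true; // nothing new to stream yet
        }
    } else {
        slot.generated_text += token_str;
    }
    return is_token_healing;
}

int main() {
    slot_state slot;
    slot.n_th_prefix = 3; // rolled-back suffix " fo"

    // Piece no longer than the remaining prefix: fully suppressed.
    assert(consume_piece(slot, " f") == true);
    assert(slot.generated_text.empty() && slot.n_th_prefix == 1);

    // Piece extending past the prefix: only the new bytes are kept.
    assert(consume_piece(slot, "ox") == false);
    assert(slot.generated_text == "x" && slot.n_th_prefix == 0);

    // Prefix exhausted: pieces are appended verbatim from now on.
    assert(consume_piece(slot, " jumps") == false);
    assert(slot.generated_text == "x jumps");
    return 0;
}
```
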
```diff
@@ -1224,7 +1268,7 @@
                 break;
             }
 
-        if (!incomplete) {
+        if (!incomplete && !is_token_healing) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
 
             const std::string str_test = slot.generated_text.substr(pos);
@@ -1256,7 +1300,7 @@
             }
         }
 
-        if (incomplete) {
+        if (incomplete || is_token_healing) {
             slot.has_next_token = true;
         }
 
```
```diff
@@ -1361,7 +1405,8 @@
             {"n_probs", slot.sparams.n_probs},
             {"min_keep", slot.sparams.min_keep},
             {"grammar", slot.sparams.grammar},
-            {"samplers", samplers_sequence}
+            {"samplers", samplers_sequence},
+            {"token_healing_enabled", slot.sparams.token_healing_enabled}
         };
     }
 
```
```diff
@@ -2106,6 +2151,21 @@ struct server_context {
                 continue;
             }
 
+            // Roll back prompt tokens if token healing
+            llama_token_healing_output token_healing_out{};
+            if (slot.sparams.token_healing_enabled) {
+                token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
+                                                                 prompt_tokens, slot.sparams.token_healing_n_rollback);
+                slot.n_th_prefix = token_healing_out.prefix.size();
+                slot.n_prompt_tokens = prompt_tokens.size();
+                LOG_VERBOSE("token healing prompt", {
+                    {"id_slot", slot.id},
+                    {"id_task", slot.id_task},
+                    {"removed_suffix", token_healing_out.prefix},
+                    {"n_tokens_removed", token_healing_out.n_tokens_removed}
+                });
+            }
+
             if (slot.embedding) {
                 // this prompt is too large to process - discard it
                 if (slot.n_prompt_tokens > n_ubatch) {
```
```diff
@@ -2156,6 +2216,9 @@ struct server_context {
             }
 
             llama_sampling_reset(slot.ctx_sampling);
+            if (slot.sparams.token_healing_enabled) {
+                llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
+            }
 
             if (!slot.params.cache_prompt) {
                 slot.n_past_se = 0;
```