@@ -185,6 +185,7 @@ struct server_slot {
     // stats
     size_t n_sent_text = 0; // number of sent text character
     size_t n_sent_token_probs = 0;
+    size_t n_th_prefix = 0; // size of remaining token healing prefix
 
     int64_t t_start_process_prompt;
     int64_t t_start_generation;
@@ -206,6 +207,7 @@ struct server_slot {
         infill = false;
         ga_i = 0;
         n_past_se = 0;
+        n_th_prefix = 0;
 
         generated_token_probs.clear();
     }
@@ -1094,6 +1096,36 @@ struct server_context {
             }
         }
 
+        {
+            const auto & token_healing_str = data.find("token_healing");
+            auto & th_enabled = slot.sparams.token_healing_enabled;
+            th_enabled = default_sparams.token_healing_enabled;
+            if (token_healing_str != data.end() && token_healing_str->is_string()) {
+                const auto value = token_healing_str->get<std::string>();
+                auto & th_type = slot.sparams.token_healing_type;
+                auto & th_n_rollback = slot.sparams.token_healing_n_rollback;
+                th_enabled = true;
+                /* */ if (value == "0" ) { th_enabled = false; }
+                else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
+                else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
+                else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
+                else if (value[0] == 'r') {
+                    th_type = llama_token_healing_type::ROLLBACK_MULTI;
+                    th_n_rollback = std::stoi(value.substr(1));
+                    if (th_n_rollback <= 0) {
+                        th_enabled = false;
+                    }
+                } else { th_enabled = false; }
+
+                LOG_VERBOSE("token healing", {
+                    {"id_slot",    slot.id},
+                    {"enabled",    th_enabled},
+                    {"type",       th_type},
+                    {"n_rollback", th_n_rollback}
+                });
+            }
+        }
+
         {
             if (slot.ctx_sampling != nullptr) {
                 llama_sampling_free(slot.ctx_sampling);
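For reference, the "token_healing" value parsed in the hunk above is a short string code. Below is a minimal sketch of a request body built with nlohmann::json (the same json type the server parses `data` from); only the field name and its codes come from this diff, while the prompt string and the endpoint note are illustrative.

// Sketch: the "token_healing" codes accepted by the parsing above:
//   "0"  -> disabled
//   "1"  -> ROLLBACK_LAST  (roll back only the last prompt token)
//   "d1" -> DYNAMIC_ONCE
//   "d"  -> DYNAMIC_MULTI
//   "rN" -> ROLLBACK_MULTI (roll back N tokens, N > 0)
#include <nlohmann/json.hpp>

int main() {
    nlohmann::json data = {
        {"prompt",        "#include <cstd"}, // hypothetical prompt ending mid-token
        {"token_healing", "d1"}
    };
    // data.dump() would be POSTed to the server's /completion endpoint.
    return 0;
}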
@@ -1189,14 +1221,26 @@ struct server_context {
     }
 
     bool process_token(completion_token_output & result, server_slot & slot) {
-        // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
         slot.sampled = result.tok;
-
-        // search stop word and delete it
-        slot.generated_text += token_str;
         slot.has_next_token = true;
 
+        // Suppress generating the token healing prefix to not repeat the input prompt's suffix
+        bool is_token_healing = false;
+        if (slot.n_th_prefix > 0) {
+            if (slot.n_th_prefix < token_str.size()) {
+                slot.generated_text += token_str.substr(slot.n_th_prefix);
+                slot.n_th_prefix = 0;
+                is_token_healing = false; // to send partial token text when streaming
+            } else {
+                slot.n_th_prefix -= token_str.size();
+                is_token_healing = true;
+            }
+        } else {
+            slot.generated_text += token_str;
+        }
+
+        // remember which tokens were sampled - used for repetition penalties during sampling
         if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
             // we can change penalty_prompt_tokens because it is always created from scratch each request
             slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
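The block added above consumes the healed prefix before any text reaches the client, since those bytes merely re-produce the prompt suffix that was rolled back. A standalone sketch of that bookkeeping (simplified names; not the server code itself):

#include <cstddef>
#include <iostream>
#include <string>

static void append_piece(std::string & out, size_t & n_th_prefix, const std::string & piece) {
    if (n_th_prefix == 0) {
        out += piece;                          // healing done: emit normally
    } else if (n_th_prefix < piece.size()) {
        out += piece.substr(n_th_prefix);      // emit only the part past the healed prefix
        n_th_prefix = 0;
    } else {
        n_th_prefix -= piece.size();           // the whole piece is still inside the healed prefix
    }
}

int main() {
    std::string out;
    size_t n_th_prefix = 4;                    // e.g. the healed prefix is "cstd"
    append_piece(out, n_th_prefix, "cstdio>"); // only "io>" is kept
    append_piece(out, n_th_prefix, "\n");      // appended as-is
    std::cout << out;                          // prints "io>" followed by a newline
}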
@@ -1224,7 +1268,7 @@ struct server_context {
                 break;
             }
 
-        if (!incomplete) {
+        if (!incomplete && !is_token_healing) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
 
             const std::string str_test = slot.generated_text.substr(pos);
@@ -1256,7 +1300,7 @@ struct server_context {
             }
         }
 
-        if (incomplete) {
+        if (incomplete || is_token_healing) {
             slot.has_next_token = true;
         }
 
@@ -1361,7 +1405,8 @@ struct server_context {
             {"n_probs",                 slot.sparams.n_probs},
             {"min_keep",                slot.sparams.min_keep},
             {"grammar",                 slot.sparams.grammar},
-            {"samplers",                samplers_sequence}
+            {"samplers",                samplers_sequence},
+            {"token_healing_enabled",   slot.sparams.token_healing_enabled}
         };
     }
 
@@ -2106,6 +2151,21 @@ struct server_context {
                     continue;
                 }
 
+                // Roll back prompt tokens if token healing
+                llama_token_healing_output token_healing_out{};
+                if (slot.sparams.token_healing_enabled) {
+                    token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
+                                                                     prompt_tokens, slot.sparams.token_healing_n_rollback);
+                    slot.n_th_prefix = token_healing_out.prefix.size();
+                    slot.n_prompt_tokens = prompt_tokens.size();
+                    LOG_VERBOSE("token healing prompt", {
+                        {"id_slot",          slot.id},
+                        {"id_task",          slot.id_task},
+                        {"removed_suffix",   token_healing_out.prefix},
+                        {"n_tokens_removed", token_healing_out.n_tokens_removed}
+                    });
+                }
+
                 if (slot.embedding) {
                     // this prompt is too large to process - discard it
                     if (slot.n_prompt_tokens > n_ubatch) {
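The definition of llama_token_healing_output is not part of this diff; a sketch of its shape, assumed only from the two fields this hunk reads, would be:

#include <string>

// Assumption: shape implied by token_healing_out.prefix and
// token_healing_out.n_tokens_removed as used above.
struct llama_token_healing_output_sketch {
    std::string prefix;           // text of the prompt suffix that was rolled back
    int         n_tokens_removed; // number of prompt tokens removed by the rollback
};

int main() {
    // Hypothetical rollback result for a prompt ending in "...<cstd":
    llama_token_healing_output_sketch out { "cstd", 1 };
    return out.n_tokens_removed == 1 ? 0 : 1;
}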
@@ -2156,6 +2216,9 @@ struct server_context {
                 }
 
                 llama_sampling_reset(slot.ctx_sampling);
+                if (slot.sparams.token_healing_enabled) {
+                    llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
+                }
 
                 if (!slot.params.cache_prompt) {
                     slot.n_past_se = 0;
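Taken together, the hunks above roll back the prompt's trailing tokens, hand the removed text to the sampler via llama_token_healing_set_prefix, and suppress the re-generated prefix in process_token. A toy, self-contained illustration of that idea at the string level (no llama.cpp API involved; all values here are made up for the demo, and the sampler-side constraint lives in the companion sampling change, not this diff):

#include <cassert>
#include <iostream>
#include <string>

int main() {
    std::string prompt        = "#include <cstd";
    std::string healed_prefix = "cstd";                   // suffix rolled back from the prompt
    prompt.resize(prompt.size() - healed_prefix.size());  // prompt is now "#include <"

    // The sampler would be constrained to pieces consistent with "cstd";
    // suppose it produces "cstdio>" as the next piece.
    std::string piece = "cstdio>";
    assert(piece.compare(0, healed_prefix.size(), healed_prefix) == 0);

    // Only the part past the healed prefix is sent to the client,
    // mirroring the n_th_prefix bookkeeping in process_token().
    std::cout << piece.substr(healed_prefix.size()) << "\n"; // prints "io>"
}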