@@ -295,11 +295,10 @@ int main(int argc, char ** argv) {
         sparams.token_healing_enabled = false;
         LOG("token healing: disabled due to custom suffix/conversation mode");
     }
-    std::string token_healing_prefix;
-    int token_healing_n_removed = 0;
+    llama_token_healing_output token_healing_out{};
     if (!params.interactive_first && sparams.token_healing_enabled) {
-        token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
-                                                            sparams.token_healing_n_rollback, &token_healing_n_removed);
+        token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
+                                                         sparams.token_healing_n_rollback);
     }
 
     // Should not run without any tokens
@@ -326,7 +325,7 @@ int main(int argc, char ** argv) {
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
-        original_prompt_len = original_inp.size() - token_healing_n_removed;
+        original_prompt_len = original_inp.size() - token_healing_out.n_tokens_removed;
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
         LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
         LOG("guidance_offset: %s", log_tostr(guidance_offset));
@@ -548,7 +547,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
     }
-    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
+    llama_token_healing_set_prefix(ctx_sampling, token_healing_out.prefix);
 
     if (llama_model_has_encoder(model)) {
         int enc_input_size = embd_inp.size();
@@ -883,7 +882,8 @@ int main(int argc, char ** argv) {
             assistant_ss << llama_token_to_piece(ctx, id, false);
         }
 
-        token_healing_n_removed = 0;
+        token_healing_out = {};
+
         if (n_past > 0 && is_interacting) {
             LOG("waiting for user input\n");
 
@@ -962,9 +962,8 @@ int main(int argc, char ** argv) {
                     const int max_to_remove = sparams.token_healing_n_rollback < 0
                                                   ? n_new_tokens
                                                   : std::min(sparams.token_healing_n_rollback, n_new_tokens);
-                    token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
-                                                                        max_to_remove, &token_healing_n_removed);
-                    n_bytes_to_skip = token_healing_prefix.size();
+                    token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, max_to_remove);
+                    n_bytes_to_skip = token_healing_out.prefix.size();
                 }
 
                 for (size_t i = original_size; i < embd_inp.size(); ++i) {
@@ -976,7 +975,7 @@ int main(int argc, char ** argv) {
                 // reset assistant message
                 assistant_ss.str("");
 
-                n_remain -= line_inp.size() + token_healing_n_removed;
+                n_remain -= line_inp.size() + token_healing_out.n_tokens_removed;
                 LOG("n_remain: %d\n", n_remain);
             } else {
                 LOG("empty line, passing control back\n");
@@ -988,9 +987,9 @@ int main(int argc, char ** argv) {
         if (n_past > 0) {
             if (is_interacting) {
                 llama_sampling_reset(ctx_sampling);
-                if (token_healing_n_removed > 0) {
+                if (token_healing_out.n_tokens_removed > 0) {
                     // Set new prefix after an interaction
-                    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
+                    llama_token_healing_set_prefix(ctx_sampling, token_healing_out.prefix);
                 }
             }
             is_interacting = false;
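For reference, a minimal sketch of the API shape this diff migrates to, inferred only from the call sites above; the actual definitions live elsewhere in the PR, so the field types, the `llama_token_healing_type` parameter, and the const-ness of `ctx` are assumptions here, not the PR's verbatim declarations:

// Sketch only: inferred from usage in this diff.
// token_healing_out.prefix is used as a std::string (n_bytes_to_skip = prefix.size())
// and token_healing_out.n_tokens_removed as an integer counter.
struct llama_token_healing_output {
    std::string prefix;        // bytes of the rolled-back tokens; passed to llama_token_healing_set_prefix()
    int n_tokens_removed = 0;  // how many prompt tokens the rollback removed
};

// Old API: returned the prefix string and reported the removed-token count
// through a pointer out-parameter. New API: bundles both results in the struct.
llama_token_healing_output llama_token_healing_rollback(
    const llama_context      * ctx,
    llama_token_healing_type   th_type,      // name assumed from sparams.token_healing_type
    std::vector<llama_token> & tokens,
    int                        max_to_remove);

Returning a small aggregate instead of threading an out-parameter also lets call sites reset all rollback state in one statement, as the conversation-mode hunk does with `token_healing_out = {};`.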