Commit 13885c7

main : add token healing

1 parent 272e3bd commit 13885c7

5 files changed: +249 -6 lines changed
common/common.cpp

Lines changed: 23 additions & 0 deletions
```diff
@@ -1093,6 +1093,25 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
         return true;
     }
+    if (arg == "-th" || arg == "--token-healing") {
+        CHECK_ARG
+        sparams.token_healing_enabled = true;
+        auto & th_type       = sparams.token_healing_type;
+        auto & th_n_rollback = sparams.token_healing_n_rollback;
+        std::string value(argv[i]);
+        /**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
+        else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
+        else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
+        else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
+        else if (value[0] == 'r' ) {
+            th_type = llama_token_healing_type::ROLLBACK_MULTI;
+            th_n_rollback = std::stoi(value.substr(1));
+            if (th_n_rollback <= 0) {
+                sparams.token_healing_enabled = false;
+            }
+        } else { invalid_param = true; }
+        return true;
+    }
     if (arg == "--override-kv") {
         CHECK_ARG
         if (!string_parse_kv_override(argv[i], params.kv_overrides)) {
@@ -1501,6 +1520,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
                 "if suffix/prefix are specified, template will be disabled\n"
                 "only commonly used templates are accepted:\n"
                 "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
+
+    options.push_back({ "main", "-th, --token-healing {0,1,d1,d,r{N}}",
+                        "Token healing type. (default: 0, disabled)\n"
+                        "1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens" });
     options.push_back({ "grammar" });
     options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
     options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" });
```

common/sampling.cpp

Lines changed: 145 additions & 3 deletions
```diff
@@ -2,6 +2,112 @@
 #include "sampling.h"
 #include <random>
 
+//
+// Token healing (internal)
+//
+
+static bool startswith(const std::string & str, const std::string & prefix) {
+    return str.rfind(prefix, 0) != std::string::npos;
+}
+
+static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
+            return true;
+        }
+    }
+    return false;
+}
+
+static std::vector<llama_token> token_healing_find_prefix(
+        const llama_context * ctx_main,
+        const std::string & prefix,
+        const bool include_partial_prefix) {
+    // Example: prefix=" world" -> " world", " worldwide", ...
+    // If `include_partial_prefix`, include also: " w", " wo", ...
+    std::vector<llama_token> candidates;
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        if (startswith(token, prefix) ||
+            (include_partial_prefix && startswith(prefix, token))) {
+            candidates.push_back(token_id);
+        }
+    }
+    return candidates;
+}
+
+//
+// Token healing (external)
+//
+
+std::string llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove,
+        int * n_removed) {
+    // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
+    // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
+    if (n_removed != nullptr) {
+        *n_removed = 0;
+    }
+    if (tokens.size() <= 1) {
+        return "";
+    }
+    const llama_model * model = llama_get_model(ctx_main);
+    const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const int n_ctx = tokens.size();
+    max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
+    max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
+    int removed = 0;
+    std::string prefix;
+    // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
+    // and stop early if a special token is encountered.
+    // NB. This doesn't handle cases where a long token is split many times,
+    // e.g. if "abc" is tokenized into ["a", "b", "c"] but "bc" is not a token (hypothetically),
+    // then "abc" will not be returned even if "abcd" exists in the vocab.
+    while (removed < max_to_remove) {
+        const llama_token next_token_id = tokens[n_ctx - removed - 1];
+        if (llama_token_is_control(model, next_token_id) || llama_token_is_eog(model, next_token_id)) {
+            break; // Don't roll back e.g. <|endoftext|>
+        }
+        std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
+        if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
+            break;
+        }
+        removed += 1;
+        prefix = new_prefix;
+    }
+    if (removed == 0) { // E.g. if the last token is a special token
+        return "";
+    }
+    // If constrained decoding would give back the original prompt, there is no need to modify the context
+    const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                               th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
+    LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
+    if (removed == 1 && candidates.size() == 1) {
+        LOG("token_healing: nothing to heal\n");
+        return "";
+    }
+    // Finalize outputs
+    if (n_removed != nullptr) {
+        *n_removed = removed;
+    }
+    tokens.resize(n_ctx - removed);
+    return prefix;
+}
+
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
+    ctx_sampling->token_healing_prefix = prefix;
+}
+
+//
+// Sampling
+//
+
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
@@ -72,6 +178,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
         ctx->grammar = grammar;
     }
 
+    ctx->token_healing_prefix.clear();
+
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
     ctx->n_valid = 0;
@@ -130,7 +238,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
 }
 
 std::string llama_sampling_order_print(const llama_sampling_params & params) {
-    std::string result = "CFG -> Penalties ";
+    std::string result = "(Token healing) -> CFG -> Penalties ";
     if (params.mirostat == 0) {
         for (auto sampler_type : params.samplers_sequence) {
             const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
@@ -393,8 +501,27 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     cur.resize(n_vocab);
 
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+    // Constrain tokens based on the remaining token healing prefix (if any)
+    const auto & th_type = params.token_healing_type;
+    const auto & th_prefix = ctx_sampling->token_healing_prefix;
+    if (params.token_healing_enabled && !th_prefix.empty()) {
+        const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                                   th_type == llama_token_healing_type::DYNAMIC_MULTI;
+        std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
+
+        LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
+        for (const llama_token token_id : th_candidates) {
+            LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
+        }
+
+        // N.B. We could also set token constraints by setting rejected tokens' logits to -inf
+        for (const llama_token token_id : th_candidates) {
+            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+        }
+    } else {
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+        }
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
@@ -457,4 +584,19 @@ void llama_sampling_accept(
     if (ctx_sampling->grammar != NULL && apply_grammar) {
         llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
    }
+
+    if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
+        std::string & th_prefix = ctx_sampling->token_healing_prefix;
+        if (!th_prefix.empty()) {
+            const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
+            if (new_token_piece.size() < th_prefix.size()) {
+                // Shift prefix constraint (for multi step token healing)
+                th_prefix = th_prefix.substr(new_token_piece.size());
+            } else {
+                // Prefix has been generated => no more constrained generation
+                th_prefix.clear();
+                LOG("token_healing: done\n");
+            }
+        }
+    }
 }
```
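To make the candidate selection above concrete, here is a small self-contained sketch of the matching rule used by `token_healing_find_prefix`, run against a made-up vocabulary of string pieces instead of a real model (the pieces and the two `include_partial_prefix` settings are illustrative only, not part of the commit):

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Same helper as in common/sampling.cpp: does `str` start with `prefix`?
static bool startswith(const std::string & str, const std::string & prefix) {
    return str.rfind(prefix, 0) != std::string::npos;
}

int main() {
    // Toy vocabulary pieces standing in for llama_token_to_piece() output
    const std::vector<std::string> vocab = { " w", " wo", " word", " world", " worldwide", " wood" };
    const std::string prefix = " world"; // text removed from the end of the prompt

    for (const bool include_partial_prefix : { false, true }) {
        printf("include_partial_prefix=%d:", (int) include_partial_prefix);
        for (const std::string & piece : vocab) {
            // A piece qualifies if it extends the prefix, or (for multi-step
            // healing) if it is itself a prefix of the remaining text.
            if (startswith(piece, prefix) || (include_partial_prefix && startswith(prefix, piece))) {
                printf(" '%s'", piece.c_str());
            }
        }
        printf("\n");
    }
    // Prints:
    //   include_partial_prefix=0: ' world' ' worldwide'
    //   include_partial_prefix=1: ' w' ' wo' ' world' ' worldwide'
    return 0;
}
```

The single-step strategies only need pieces that fully cover the removed text, while the multi-step strategies also accept shorter pieces and regenerate the remainder over several constrained steps.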

common/sampling.h

Lines changed: 28 additions & 0 deletions
```diff
@@ -19,6 +19,13 @@ enum class llama_sampler_type : char {
     TEMPERATURE = 't'
 };
 
+enum class llama_token_healing_type : uint8_t {
+    ROLLBACK_LAST,  // roll back last token with a single constrained decoding step
+    ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
+    DYNAMIC_ONCE,   // dynamic roll back, single constrained decoding step
+    DYNAMIC_MULTI   // dynamic roll back, multiple constrained decoding steps
+};
+
 // sampling parameters
 typedef struct llama_sampling_params {
     int32_t n_prev = 64; // number of previous tokens to remember
@@ -62,6 +69,10 @@ typedef struct llama_sampling_params {
 
     std::vector<llama_token> penalty_prompt_tokens;
     bool use_penalty_prompt_tokens = false;
+
+    llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
+    bool token_healing_enabled = false;
+    int token_healing_n_rollback = -1; // number of tokens to roll back
 } llama_sampling_params;
 
 // general sampler context
@@ -78,6 +89,8 @@ struct llama_sampling_context {
     // internal
     grammar_parser::parse_state parsed_grammar;
 
+    std::string token_healing_prefix; // remaining prefix to constrain sampling
+
     // TODO: replace with ring-buffer
     std::vector<llama_token> prev;
     std::vector<llama_token_data> cur;
@@ -158,3 +171,18 @@ void llama_sampling_accept(
         struct llama_context * ctx_main,
         llama_token id,
         bool apply_grammar);
+
+//
+// Token healing
+//
+
+// Roll back `tokens` for constrained generation according to the token healing
+// strategy. Returns the prefix for constrained generation.
+std::string llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove = -1,
+        int * n_removed = nullptr);
+
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
```
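Taken together with `llama_sampling_accept`, the intended call pattern (mirroring `examples/main/main.cpp` below) is: roll the prompt back once before it is evaluated, hand the returned prefix to the sampling context, and let sampling stay constrained until the prefix has been regenerated. A condensed sketch, assuming an already-created `ctx`, sampling params `sparams`, and a tokenized prompt `embd_inp`; error handling and the decode loop are omitted:

```cpp
// Condensed usage sketch -- not part of the commit
int         token_healing_n_removed = 0;
std::string token_healing_prefix;
if (sparams.token_healing_enabled) {
    // Drops up to token_healing_n_rollback tokens from the end of embd_inp
    // and returns the text those tokens covered.
    token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
                                                        sparams.token_healing_n_rollback,
                                                        &token_healing_n_removed);
}

struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);

// ... evaluate embd_inp and sample as usual: while the stored prefix is
// non-empty, the prepare step shown above (llama_sampling_prepare_impl) only
// keeps candidate tokens compatible with it, and llama_sampling_accept trims
// the prefix as the corresponding pieces are generated.
```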

examples/main/README.md

Lines changed: 13 additions & 0 deletions
```diff
@@ -251,6 +251,19 @@ A more practical use case might be to prevent the generation of `\code{begin}` a
 
 Example usage: `--logit-bias 29905-inf`
 
+### Token healing
+
+- `-th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}`: Set the token healing strategy (default: 0, 0 = disabled).
+
+Token healing (a.k.a. token alignment) alleviates tokenization artifacts for text completion.
+
+- `-th 1`: Roll back the last token and constrain the bytes of the next token to start with the chopped off last token [0, 2].
+- `-th d1`: Roll back multiple tokens until there doesn't exist a token which can cover the prompt's suffix and do a single constrained decoding step [2].
+- `-th d`: Like `d1` but allow multiple decoding steps until the removed suffix is generated.
+- `-th r{N}`: Like `d` but roll back `N` tokens, where `-th r3` is recommended [1].
+
+Sources: [0](https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb), [1](https://arxiv.org/abs/2403.08688), [2](https://arxiv.org/abs/2402.01035).
+
 ### RNG Seed
 
 - `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).
```
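As a concrete (hypothetical) illustration of the strategies: suppose the prompt ends in `The sky is blu` and the tokenizer ends the prompt with a token whose piece is `blu`. Plain completion can only append tokens after that `blu` boundary, which is an unusual split the model rarely saw during training. With `-th 1` the trailing `blu` token is rolled back and the next sampled token is constrained to start with `blu`, so the model can produce `blue` (or `bluish`, etc.) as a single piece; the dynamic variants (`d1`, `d`) keep rolling back while the removed suffix still matches the start of some vocabulary entry, and `r{N}` rolls back a fixed `N` tokens before constrained decoding.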

examples/main/main.cpp

Lines changed: 40 additions & 3 deletions
```diff
@@ -291,6 +291,17 @@ int main(int argc, char ** argv) {
         LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
     }
 
+    if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) {
+        sparams.token_healing_enabled = false;
+        LOG("token_healing: disabled due to custom suffix/conversation mode");
+    }
+    std::string token_healing_prefix;
+    int token_healing_n_removed = 0;
+    if (!params.interactive_first && sparams.token_healing_enabled) {
+        token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
+                                                            sparams.token_healing_n_rollback, &token_healing_n_removed);
+    }
+
     // Should not run without any tokens
     if (embd_inp.empty()) {
         if (add_bos) {
@@ -315,7 +326,7 @@
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
-        original_prompt_len = original_inp.size();
+        original_prompt_len = original_inp.size() - token_healing_n_removed;
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
         LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
         LOG("guidance_offset: %s", log_tostr(guidance_offset));
@@ -510,6 +521,7 @@
     int n_consumed = 0;
     int n_session_consumed = 0;
     int n_past_guidance = 0;
+    int n_bytes_to_skip = 0; // to skip printing when generating token healing prefix
 
     std::vector<int> input_tokens;  g_input_tokens  = &input_tokens;
     std::vector<int> output_tokens; g_output_tokens = &output_tokens;
@@ -536,6 +548,7 @@
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
     }
+    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
 
     if (llama_model_has_encoder(model)) {
         int enc_input_size = embd_inp.size();
@@ -770,7 +783,15 @@
             const std::string token_str = llama_token_to_piece(ctx, id, params.special);
 
             // Console/Stream Output
-            fprintf(stdout, "%s", token_str.c_str());
+            // Suppress printing while generating token healing prefix
+            if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int)token_str.size()) {
+                fprintf(stdout, "%s", token_str.substr(n_bytes_to_skip).c_str());
+                n_bytes_to_skip = 0;
+            } else if (n_bytes_to_skip > 0) {
+                n_bytes_to_skip -= token_str.size();
+            } else {
+                fprintf(stdout, "%s", token_str.c_str());
+            }
 
             // Record Displayed Tokens To Log
             // Note: Generated tokens are created one by one hence this check
@@ -862,6 +883,7 @@
                 assistant_ss << llama_token_to_piece(ctx, id, false);
             }
 
+            token_healing_n_removed = 0;
             if (n_past > 0 && is_interacting) {
                 LOG("waiting for user input\n");
 
@@ -934,6 +956,17 @@
                     embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
                     embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
 
+                    if (sparams.token_healing_enabled) {
+                        // Limit token healing rollback to new tokens only (otherwise would need to shift everything)
+                        const int n_new_tokens = embd_inp.size() - original_size;
+                        const int max_to_remove = sparams.token_healing_n_rollback < 0
+                                                  ? n_new_tokens
+                                                  : std::min(sparams.token_healing_n_rollback, n_new_tokens);
+                        token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
+                                                                            max_to_remove, &token_healing_n_removed);
+                        n_bytes_to_skip = token_healing_prefix.size();
+                    }
+
                     for (size_t i = original_size; i < embd_inp.size(); ++i) {
                         const llama_token token = embd_inp[i];
                         output_tokens.push_back(token);
@@ -943,7 +976,7 @@
                     // reset assistant message
                     assistant_ss.str("");
 
-                    n_remain -= line_inp.size();
+                    n_remain -= line_inp.size() + token_healing_n_removed;
                     LOG("n_remain: %d\n", n_remain);
                 } else {
                     LOG("empty line, passing control back\n");
@@ -955,6 +988,10 @@
             if (n_past > 0) {
                 if (is_interacting) {
                     llama_sampling_reset(ctx_sampling);
+                    if (token_healing_n_removed > 0) {
+                        // Set new prefix after an interaction
+                        llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
+                    }
                 }
                 is_interacting = false;
             }
```
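The `n_bytes_to_skip` bookkeeping exists because the healed prefix is text the user already typed: constrained generation re-emits it, and printing it again would duplicate it on screen. A minimal sketch of the skip logic above with made-up values (a healed prefix of 3 bytes, `wor`, and the model emitting `world` then `!`):

```cpp
#include <cstdio>
#include <string>

int main() {
    int n_bytes_to_skip = 3; // token_healing_prefix.size() for a healed prefix "wor"
    for (const std::string & token_str : { std::string("world"), std::string("!") }) {
        if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int) token_str.size()) {
            printf("%s", token_str.substr(n_bytes_to_skip).c_str()); // prints "ld"
            n_bytes_to_skip = 0;
        } else if (n_bytes_to_skip > 0) {
            n_bytes_to_skip -= token_str.size();
        } else {
            printf("%s", token_str.c_str()); // prints "!"
        }
    }
    printf("\n"); // total output: "ld!"
    return 0;
}
```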
