Skip to content

Commit ea4abc9

Browse files
committed
token healing: refactor argument parsing
Unify `main` and `server` token healing argument handling.
1 parent 3ba5c55 commit ea4abc9

File tree

5 files changed

+61
-60
lines changed

5 files changed

+61
-60
lines changed

common/common.cpp

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,21 +1095,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
10951095
}
10961096
if (arg == "-th" || arg == "--token-healing") {
10971097
CHECK_ARG
1098-
sparams.token_healing_enabled = true;
1099-
auto & th_type = sparams.token_healing_type;
1100-
auto & th_n_rollback = sparams.token_healing_n_rollback;
11011098
std::string value(argv[i]);
1102-
/**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
1103-
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
1104-
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
1105-
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
1106-
else if (value[0] == 'r' ) {
1107-
th_type = llama_token_healing_type::ROLLBACK_MULTI;
1108-
th_n_rollback = std::stoi(value.substr(1));
1109-
if (th_n_rollback <= 0) {
1110-
sparams.token_healing_enabled = false;
1111-
}
1112-
} else { invalid_param = true; }
1099+
invalid_param = !llama_token_healing_parse_params(value, sparams.token_healing);
11131100
return true;
11141101
}
11151102
if (arg == "--override-kv") {

common/sampling.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,25 @@ void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const
154154
ctx_sampling->token_healing_prefix = prefix;
155155
}
156156

157+
bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params) {
158+
th_params.enabled = true;
159+
th_params.n_rollback = -1;
160+
/**/ if (params == "0" ) { th_params.enabled = false; }
161+
else if (params == "1" ) { th_params.type = llama_token_healing_type::ROLLBACK_LAST; }
162+
else if (params == "d1") { th_params.type = llama_token_healing_type::DYNAMIC_ONCE; }
163+
else if (params == "d" ) { th_params.type = llama_token_healing_type::DYNAMIC_MULTI; }
164+
else if (params[0] == 'r' ) {
165+
th_params.type = llama_token_healing_type::ROLLBACK_MULTI;
166+
th_params.n_rollback = std::stoi(params.substr(1));
167+
if (th_params.n_rollback <= 0) {
168+
return false;
169+
}
170+
} else {
171+
return false;
172+
}
173+
return true;
174+
}
175+
157176
//
158177
// Sampling
159178
//
@@ -552,11 +571,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
552571
cur.resize(n_vocab);
553572

554573
// Constrain tokens based on the remaining token healing prefix (if any)
555-
const auto & th_type = params.token_healing_type;
556574
const auto & th_prefix = ctx_sampling->token_healing_prefix;
557-
if (params.token_healing_enabled && !th_prefix.empty()) {
558-
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
559-
th_type == llama_token_healing_type::DYNAMIC_MULTI;
575+
if (params.token_healing.enabled && !th_prefix.empty()) {
576+
const bool is_multi_step = params.token_healing.type == llama_token_healing_type::ROLLBACK_MULTI ||
577+
params.token_healing.type == llama_token_healing_type::DYNAMIC_MULTI;
560578
std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);
561579

562580
LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
@@ -635,7 +653,7 @@ void llama_sampling_accept(
635653
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
636654
}
637655

638-
if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
656+
if (ctx_sampling->params.token_healing.enabled && apply_grammar) {
639657
std::string & th_prefix = ctx_sampling->token_healing_prefix;
640658
if (!th_prefix.empty()) {
641659
const std::string new_token_piece = llama_token_to_piece(ctx_main, id);

common/sampling.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ enum class llama_token_healing_type : uint8_t {
2626
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
2727
};
2828

29+
struct llama_token_healing_params {
30+
bool enabled = false;
31+
llama_token_healing_type type = llama_token_healing_type::DYNAMIC_MULTI;
32+
int n_rollback = -1; // number of tokens to roll back
33+
};
34+
2935
// sampling parameters
3036
typedef struct llama_sampling_params {
3137
int32_t n_prev = 64; // number of previous tokens to remember
@@ -70,9 +76,7 @@ typedef struct llama_sampling_params {
7076
std::vector<llama_token> penalty_prompt_tokens;
7177
bool use_penalty_prompt_tokens = false;
7278

73-
llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
74-
bool token_healing_enabled = false;
75-
int token_healing_n_rollback = -1; // number of tokens to roll back
79+
llama_token_healing_params token_healing;
7680
} llama_sampling_params;
7781

7882
// general sampler context
@@ -190,3 +194,6 @@ llama_token_healing_output llama_token_healing_rollback(
190194
int max_to_remove = -1);
191195

192196
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
197+
198+
// Helper for parsing token healing params from a string.
199+
bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params);

examples/main/main.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -291,14 +291,14 @@ int main(int argc, char ** argv) {
291291
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
292292
}
293293

294-
if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) {
295-
sparams.token_healing_enabled = false;
294+
if (sparams.token_healing.enabled && (params.conversation || !params.input_suffix.empty())) {
295+
sparams.token_healing.enabled = false;
296296
LOG("token healing: disabled due to custom suffix/conversation mode");
297297
}
298298
llama_token_healing_output token_healing_out{};
299-
if (!params.interactive_first && sparams.token_healing_enabled) {
300-
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
301-
sparams.token_healing_n_rollback);
299+
if (!params.interactive_first && sparams.token_healing.enabled) {
300+
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp,
301+
sparams.token_healing.n_rollback);
302302
}
303303

304304
// Should not run without any tokens
@@ -956,13 +956,13 @@ int main(int argc, char ** argv) {
956956
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
957957
embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
958958

959-
if (sparams.token_healing_enabled) {
959+
if (sparams.token_healing.enabled) {
960960
// Limit token healing rollback to new tokens only (otherwise would need to shift everything)
961961
const int n_new_tokens = embd_inp.size() - original_size;
962-
const int max_to_remove = sparams.token_healing_n_rollback < 0
962+
const int max_to_remove = sparams.token_healing.n_rollback < 0
963963
? n_new_tokens
964-
: std::min(sparams.token_healing_n_rollback, n_new_tokens);
965-
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, max_to_remove);
964+
: std::min(sparams.token_healing.n_rollback, n_new_tokens);
965+
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp, max_to_remove);
966966
n_bytes_to_skip = token_healing_out.prefix.size();
967967
}
968968

examples/server/server.cpp

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,31 +1098,20 @@ struct server_context {
10981098

10991099
{
11001100
const auto & token_healing_str = data.find("token_healing");
1101-
auto & th_enabled = slot.sparams.token_healing_enabled;
1102-
th_enabled = default_sparams.token_healing_enabled;
11031101
if (token_healing_str != data.end() && token_healing_str->is_string()) {
11041102
const auto value = token_healing_str->get<std::string>();
1105-
auto & th_type = slot.sparams.token_healing_type;
1106-
auto & th_n_rollback = slot.sparams.token_healing_n_rollback;
1107-
th_enabled = true;
1108-
/**/ if (value == "0" ) { th_enabled = false; }
1109-
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
1110-
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
1111-
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
1112-
else if (value[0] == 'r' ) {
1113-
th_type = llama_token_healing_type::ROLLBACK_MULTI;
1114-
th_n_rollback = std::stoi(value.substr(1));
1115-
if (th_n_rollback <= 0) {
1116-
th_enabled = false;
1117-
}
1118-
} else { th_enabled = false; }
1119-
1103+
if (!llama_token_healing_parse_params(value, slot.sparams.token_healing)) {
1104+
send_error(task, "\"token_healing\" parse error", ERROR_TYPE_INVALID_REQUEST);
1105+
return false;
1106+
}
11201107
LOG_VERBOSE("token healing", {
11211108
{"id_slot", slot.id},
1122-
{"enabled", th_enabled},
1123-
{"type", th_type},
1124-
{"n_rollback", th_n_rollback}
1109+
{"enabled", slot.sparams.token_healing.enabled},
1110+
{"type", slot.sparams.token_healing.type},
1111+
{"n_rollback", slot.sparams.token_healing.n_rollback}
11251112
});
1113+
} else {
1114+
slot.sparams.token_healing = default_sparams.token_healing;
11261115
}
11271116
}
11281117

@@ -1406,7 +1395,7 @@ struct server_context {
14061395
{"min_keep", slot.sparams.min_keep},
14071396
{"grammar", slot.sparams.grammar},
14081397
{"samplers", samplers_sequence},
1409-
{"token_healing_enabled", slot.sparams.token_healing_enabled}
1398+
{"token_healing_enabled", slot.sparams.token_healing.enabled}
14101399
};
14111400
}
14121401

@@ -2109,10 +2098,10 @@ struct server_context {
21092098
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
21102099
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
21112100

2112-
if (slot.sparams.token_healing_enabled) {
2101+
if (slot.sparams.token_healing.enabled) {
21132102
// For FIM roll back only the prefix part (i.e. cursor location)
2114-
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
2115-
prefix_tokens, slot.sparams.token_healing_n_rollback);
2103+
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type,
2104+
prefix_tokens, slot.sparams.token_healing.n_rollback);
21162105
}
21172106

21182107
auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
@@ -2131,9 +2120,9 @@ struct server_context {
21312120
} else {
21322121
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
21332122

2134-
if (slot.sparams.token_healing_enabled) {
2135-
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
2136-
prompt_tokens, slot.sparams.token_healing_n_rollback);
2123+
if (slot.sparams.token_healing.enabled) {
2124+
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type,
2125+
prompt_tokens, slot.sparams.token_healing.n_rollback);
21372126
}
21382127
}
21392128

@@ -2149,7 +2138,7 @@ struct server_context {
21492138
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
21502139
});
21512140

2152-
if (slot.sparams.token_healing_enabled) {
2141+
if (slot.sparams.token_healing.enabled) {
21532142
slot.n_th_prefix = token_healing_out.prefix.size();
21542143
LOG_VERBOSE("token healing prompt", {
21552144
{"id_slot", slot.id},
@@ -2224,7 +2213,7 @@ struct server_context {
22242213
}
22252214

22262215
llama_sampling_reset(slot.ctx_sampling);
2227-
if (slot.sparams.token_healing_enabled) {
2216+
if (slot.sparams.token_healing.enabled) {
22282217
llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
22292218
}
22302219

0 commit comments

Comments (0)