
Commit 64caa4f

Implementation of a sequence repetition penalty
1 parent bd33e5a commit 64caa4f

File tree: 6 files changed, +569 -0 lines changed

common/common.cpp

Lines changed: 163 additions & 0 deletions
@@ -102,6 +102,144 @@ void process_escapes(std::string& input) {
     input.resize(output_idx);
 }
 
+void seqrep_sampler_params_init(llama_sampler_seqrep_params * params) {
+    assert(params != NULL);
+    memset(params, 0, sizeof(llama_sampler_seqrep_params));
+    params->last_n = 256;
+    params->mid_word_scale = 0.1f;
+    params->tolerance_half_step_cost = 1.0f;
+}
+
+void seqrep_sampler_params_dump(const llama_sampler_seqrep_params * params) {
+    assert(params != NULL);
+    LOG_TEE("seqrep(last_n = %d, min_length = %zd, start_offset = %zd, presence_penalty = %.4f, length_penalty = %.4f, tolerance = %.4f, mid_word_scale = %.4f, tolerance_match_credit = %.4f, tolerance_half_step_cost = %.4f, flags = %d)\n",
+        params->last_n, params->min_length, params->start_offset, params->presence_penalty,
+        params->length_penalty, params->tolerance, params->mid_word_scale, params->tolerance_match_credit,
+        params->tolerance_half_step_cost, params->flags);
+}
+
+void seqrep_sampler_help() {
+    llama_sampler_seqrep_params p;
+    seqrep_sampler_params_init(&p);
+    fprintf(stdout, "==== Sequence Repetition Sampler Help ====\n\n");
+    fprintf(stdout, " The sequence repetition sampler takes a configuration string in the format:\n");
+    fprintf(stdout, " arg1:arg2:argN\n");
+    fprintf(stdout, " A colon separated argument can be a key=value pair like xyz=1 or a flag like xyz\n");
+    fprintf(stdout, "\n- Available key/value arguments\n");
+    fprintf(stdout, " * repetition_mode=REPEAT_PENALTY\n emulates the repetition penalty sampler. warning: 1.0 disables penalties since this preset enables flag_divide_by_penalty. using 0.0 is probably not what you want\n");
+    fprintf(stdout, " * presence_mode=PRESENCE_PENALTY\n emulates the presence penalty sampler\n");
+    fprintf(stdout, " * frequency_mode=FREQUENCY_PENALTY\n emulates the frequency penalty sampler\n");
+    fprintf(stdout, " * last_n\n last n tokens to consider for sequence penalizing (default: %d, 0 = disabled, -1 = ctx_size)\n", p.last_n);
+    fprintf(stdout, " * min_length\n minimum matching sequence length (default: %zd, < 2 = disabled)\n", p.min_length);
+    fprintf(stdout, " * presence_penalty\n presence penalty for tokens that can continue a sequence (default: %f, 0.0 = disabled)\n", p.presence_penalty);
+    fprintf(stdout, " * length_penalty\n penalty for tokens that can continue a sequence, multiplied by length (default: %f, 0.0 = disabled)\n", p.length_penalty);
+    fprintf(stdout, " * tolerance\n tolerance for fuzzy matching sequences (default: %f, 0 = disabled)\n", p.tolerance);
+    fprintf(stdout, " * mid_word_scale\n scale the penalty for mid-word tokens. 1.0 means apply the full penalty (default: %f, 1.0 = disabled)\n", p.mid_word_scale);
+    fprintf(stdout, " * tolerance_match_credit\n credit tolerance on matched tokens (default: %f, 0.0 = disabled)\n", p.tolerance_match_credit);
+    fprintf(stdout, " * tolerance_half_step_cost\n advanced option to adjust tolerance cost for failed matches within a half step of a match (default: %f, 1.0 = normal)\n", p.tolerance_half_step_cost);
+    fprintf(stdout, "\n- Available flag arguments (currently all default to disabled)\n");
+    fprintf(stdout, " * flag_immediate_wildcard\n count consumed tolerance as a match immediately (by default it doesn't count as a match until a real match is found)\n");
+    fprintf(stdout, " * flag_tolerance_no_consecutive\n do not allow using tolerance consecutively\n");
+    fprintf(stdout, " * flag_tolerance_no_first\n do not allow using tolerance before the first match\n");
+    fprintf(stdout, " * flag_tolerance_cap_initial\n only meaningful with match credit, prevents match credit raising tolerance above the initial value\n");
+    fprintf(stdout, " * flag_penalize_length_max_seen\n when applying length_penalty, use the maximum seen sequence length rather than the total length of seen sequences\n");
+    fprintf(stdout, " * flag_divide_by_penalty\n divide the logit when applying a penalty rather than subtracting it. warning: when this flag is enabled, 1.0 disables penalties, not 0.0. 0.0 is probably not what you want\n");
+    fprintf(stdout, "\n- Examples:\n");
+    fprintf(stdout, " * repetition_mode=1.2:last_n=32\n same as --repeat-last-n 32 --repeat-penalty 1.2\n");
+    fprintf(stdout, " * presence_mode=.2:last_n=32\n same as --repeat-last-n 32 --presence-penalty .2\n");
+    fprintf(stdout, " * frequency_mode=.2:last_n=32\n same as --repeat-last-n 32 --frequency-penalty .2\n");
+    fprintf(stdout, " * min_length=3:tolerance=1:length_penalty=.2:last_n=-1\n match repeated sequences of at least 3 tokens within the entire context and apply a penalty of 0.2*total_length to the token that would continue the sequence. allow one non-matching token in matched sequences.\n");
+}
+
+bool seqrep_sampler_params_parse(char * s, llama_sampler_seqrep_params * params) {
+    assert(params != NULL);
+    assert(s != NULL);
+    size_t offset = 0;
+    std::string sparams = s;
+    size_t slen = sparams.size();
+
+    while (offset < slen) {
+        // printf("SR OFFS: %lu\n", offset);
+        size_t argsep = sparams.find_first_of(':', offset);
+        std::string argchunk;
+        if (argsep == std::string::npos) {
+            argchunk = sparams.substr(offset);
+        } else if (argsep > offset) {
+            argchunk = sparams.substr(offset, argsep - offset);
+        }
+        std::string argval;
+        size_t valsep = argchunk.find_first_of('=');
+        if (valsep != std::string::npos && valsep < argchunk.size()) {
+            argval = argchunk.substr(valsep + 1);
+            argchunk.resize(valsep);
+        }
+        // printf("SR: k[%s] = v[%s]\n", argchunk.c_str(), argval.c_str());
+        if (argchunk.empty() && argval.empty()) {
+            // pass
+        } else if (argchunk == "repetition_mode") {
+            params->last_n = 64;
+            params->min_length = 1;
+            params->mid_word_scale = 1.0f;
+            params->flags = LLAMA_SEQREP_DIVIDE_BY_PENALTY;
+            params->length_penalty = 1.0f;
+            params->presence_penalty = argval.empty() ? 1.1f : std::atof(argval.c_str());
+        } else if (argchunk == "presence_mode") {
+            params->last_n = 64;
+            params->min_length = 1;
+            params->mid_word_scale = 1.0f;
+            params->flags = 0;
+            params->length_penalty = 0.0f;
+            params->presence_penalty = std::atof(argval.c_str());
+        } else if (argchunk == "frequency_mode") {
+            params->last_n = 64;
+            params->min_length = 1;
+            params->mid_word_scale = 1.0f;
+            params->flags = 0;
+            params->length_penalty = std::atof(argval.c_str());
+            params->presence_penalty = 0.0f;
+        } else if (argchunk == "flag_immediate_wildcard") {
+            params->flags |= LLAMA_SEQREP_IMMEDIATE_WILDCARD;
+        } else if (argchunk == "flag_tolerance_no_consecutive") {
+            params->flags |= LLAMA_SEQREP_TOLERANCE_NO_CONSECUTIVE;
+        } else if (argchunk == "flag_tolerance_no_first") {
+            params->flags |= LLAMA_SEQREP_TOLERANCE_NO_FIRST;
+        } else if (argchunk == "flag_tolerance_cap_initial") {
+            params->flags |= LLAMA_SEQREP_TOLERANCE_CAP_INITIAL;
+        } else if (argchunk == "flag_penalize_length_max_seen") {
+            params->flags |= LLAMA_SEQREP_PENALIZE_LENGTH_MAX_SEEN;
+        } else if (argchunk == "flag_divide_by_penalty") {
+            params->flags |= LLAMA_SEQREP_DIVIDE_BY_PENALTY;
+        } else if (argchunk == "min_length") {
+            params->min_length = std::atoi(argval.c_str());
+        } else if (argchunk == "start_offset") {
+            params->start_offset = std::atoi(argval.c_str());
+        } else if (argchunk == "last_n") {
+            params->last_n = std::atoi(argval.c_str());
+        } else if (argchunk == "tolerance") {
+            params->tolerance = std::atof(argval.c_str());
+        } else if (argchunk == "presence_penalty") {
+            params->presence_penalty = std::atof(argval.c_str());
+        } else if (argchunk == "length_penalty") {
+            params->length_penalty = std::atof(argval.c_str());
+        } else if (argchunk == "mid_word_scale") {
+            params->mid_word_scale = std::atof(argval.c_str());
+        } else if (argchunk == "tolerance_match_credit") {
+            params->tolerance_match_credit = std::atof(argval.c_str());
+        } else if (argchunk == "tolerance_half_step_cost") {
+            params->tolerance_half_step_cost = std::atof(argval.c_str());
+        } else {
+            fprintf(stderr, "seqrep: Bad argument [%s]=[%s]!\n", argchunk.c_str(), argval.c_str());
+            return false;
+        }
+        if (argsep != std::string::npos) {
+            offset = argsep + 1;
+        } else {
+            break;
+        }
+    }
+    return true;
+}
+
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     bool invalid_param = false;
     std::string arg;
@@ -248,6 +386,25 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.presence_penalty = std::stof(argv[i]);
+        } else if (arg == "-seqrep" || arg == "--seqrep-penalty") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            if (std::strcmp(argv[i], "help") == 0) {
+                seqrep_sampler_help();
+                exit(0);
+            }
+            llama_sampler_seqrep_params sr_params;
+            seqrep_sampler_params_init(&sr_params);
+            if (!seqrep_sampler_params_parse(argv[i], &sr_params)) {
+                seqrep_sampler_help();
+                exit(1);
+            }
+            if (sr_params.last_n != 0 && sr_params.min_length > 0
+                && (sr_params.presence_penalty != 0.0f || sr_params.length_penalty != 0.0f)) {
+                params.seqrep_params.push_back(sr_params);
+            }
         } else if (arg == "--mirostat") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -622,6 +779,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, " --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
     fprintf(stdout, " --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
     fprintf(stdout, " --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
+    fprintf(stdout, " -seqrep CFG, --seqrep-penalty CFG\n");
+    fprintf(stdout, " add a copy of the sequence repetition penalty sampler. may be specified multiple times. for help: -seqrep help\n");
     fprintf(stdout, " --mirostat N use Mirostat sampling.\n");
     fprintf(stdout, " Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
     fprintf(stdout, " (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);
@@ -908,6 +1067,10 @@ llama_token llama_sample_token(
                 last_tokens.data() + last_tokens.size() - last_n_repeat,
                 last_n_repeat, alpha_frequency, alpha_presence);
 
+            for (auto & sr_params : params.seqrep_params) {
+                llama_sample_seqrep_penalty(ctx, &cur_p, last_tokens.data(), last_tokens.size(), &sr_params);
+            }
+
             if (!penalize_nl) {
                 for (size_t idx = 0; idx < cur_p.size; idx++) {
                     if (cur_p.data[idx].id == llama_token_nl(ctx)) {
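
The pieces above compose in a straightforward way: a configuration string is parsed into a llama_sampler_seqrep_params, and that struct is then handed to llama_sample_seqrep_penalty together with the candidate logits and the recent tokens. Below is a minimal sketch of that flow outside of gpt_params_parse / llama_sample_token. It assumes llama_sample_seqrep_penalty and the llama_sampler_seqrep_params struct are declared by the llama.h / llama.cpp part of this commit (not shown in this excerpt); the helper name apply_seqrep_from_config is made up for illustration.

#include "common.h"
#include "llama.h"

// Parse one seqrep configuration string and apply it to the current
// candidates. ctx, candidates and last_tokens are the same objects that
// llama_sample_token already works with.
static bool apply_seqrep_from_config(
        llama_context * ctx,
        llama_token_data_array * candidates,
        const std::vector<llama_token> & last_tokens,
        char * config) {
    llama_sampler_seqrep_params sr_params;
    seqrep_sampler_params_init(&sr_params);   // defaults: last_n = 256, mid_word_scale = 0.1, ...
    if (!seqrep_sampler_params_parse(config, &sr_params)) {
        return false;                         // bad argument; the CLI falls back to seqrep_sampler_help()
    }
    seqrep_sampler_params_dump(&sr_params);   // log the effective settings
    // Same call the commit adds to llama_sample_token:
    llama_sample_seqrep_penalty(ctx, candidates, last_tokens.data(), last_tokens.size(), &sr_params);
    return true;
}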

common/common.h

Lines changed: 5 additions & 0 deletions
@@ -52,6 +52,7 @@ struct gpt_params {
     int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
     float frequency_penalty = 0.00f; // 0.0 = disabled
     float presence_penalty = 0.00f; // 0.0 = disabled
+    std::vector<llama_sampler_seqrep_params> seqrep_params;
     int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float mirostat_tau = 5.00f; // target entropy
     float mirostat_eta = 0.10f; // learning rate
@@ -201,3 +202,7 @@ std::string get_sortable_timestamp();
 void dump_non_result_info_yaml(
     FILE * stream, const gpt_params & params, const llama_context * lctx,
     const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
+
+void seqrep_sampler_params_init(llama_sampler_seqrep_params * params);
+void seqrep_sampler_params_dump(const llama_sampler_seqrep_params * params);
+bool seqrep_sampler_params_parse(char * s, llama_sampler_seqrep_params * params);
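
Since gpt_params now carries a std::vector of llama_sampler_seqrep_params, several -seqrep configurations can be active at once; the loop added to llama_sample_token applies them one after another. A rough sketch of stacking two configurations programmatically rather than via the command line follows. The helper name is illustrative, and the second configuration's values are taken from the built-in help's example rather than being defaults of the commit.

#include "common.h"

static void add_example_seqrep_samplers(gpt_params & params) {
    // 1) preset the help describes as equivalent to
    //    --repeat-last-n 32 --repeat-penalty 1.2
    llama_sampler_seqrep_params rep;
    seqrep_sampler_params_init(&rep);
    char rep_cfg[] = "repetition_mode=1.2:last_n=32";
    seqrep_sampler_params_parse(rep_cfg, &rep);
    params.seqrep_params.push_back(rep);

    // 2) a longer-range sequence penalty configured field by field
    llama_sampler_seqrep_params seq;
    seqrep_sampler_params_init(&seq);
    seq.last_n         = -1;    // -1 = whole context
    seq.min_length     = 3;     // only penalize sequences of at least 3 tokens
    seq.length_penalty = 0.2f;  // penalty grows with the matched length
    params.seqrep_params.push_back(seq);
}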

examples/main/main.cpp

Lines changed: 3 additions & 0 deletions
@@ -422,6 +422,9 @@ int main(int argc, char ** argv) {
     }
     LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
         params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
+    for (auto & sr_params : params.seqrep_params) {
+        seqrep_sampler_params_dump(&sr_params);
+    }
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");

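On the command line, the sampler is enabled with -seqrep CFG (or --seqrep-penalty CFG) and may be given multiple times; -seqrep help prints the option list from seqrep_sampler_help() above, including the example configuration min_length=3:tolerance=1:length_penalty=.2:last_n=-1. A small sketch of checking a configuration string in isolation, mirroring the startup dump that examples/main/main.cpp now performs (a standalone test program, not part of the commit):

#include <cstdio>
#include "common.h"

int main() {
    llama_sampler_seqrep_params sr_params;
    seqrep_sampler_params_init(&sr_params);

    // Example string from the built-in help.
    char cfg[] = "min_length=3:tolerance=1:length_penalty=.2:last_n=-1";
    if (!seqrep_sampler_params_parse(cfg, &sr_params)) {
        fprintf(stderr, "invalid seqrep configuration\n");
        return 1;
    }
    // Logs the effective settings, just like main.cpp does for each -seqrep argument.
    seqrep_sampler_params_dump(&sr_params);
    return 0;
}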