Skip to content

Commit e199bfe

Browse files
committed
Initial implementation of a sequence repetition penalty
1 parent b532a69 commit e199bfe

File tree

5 files changed

+486
-0
lines changed

5 files changed

+486
-0
lines changed

common/common.cpp

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,147 @@ void process_escapes(std::string& input) {
100100
input.resize(output_idx);
101101
}
102102

103+
// Initialize *params to the sampler's default configuration.
// All fields are first zeroed, then the handful of non-zero defaults are set.
void seqrep_sampler_params_init(llama_sampler_seqrep_params * params) {
    assert(params != NULL);

    // Start from an all-zero state so any field not listed below defaults to 0/0.0f.
    memset(params, 0, sizeof(*params));

    // Non-zero defaults.
    params->last_n                   = 256;   // consider the last 256 tokens
    params->mid_word_scale           = 0.1f;  // greatly reduce penalty for mid-word tokens
    params->tolerance_half_step_cost = 1.0f;  // normal cost for half-step tolerance use
}
110+
111+
// Write a one-line human-readable dump of *params to fp.
// A NULL fp is tolerated (no-op); params must be non-NULL.
void seqrep_sampler_params_dump(FILE * fp, llama_sampler_seqrep_params * params) {
    if (fp == NULL) {
        return;
    }
    assert(params != NULL);
    fprintf(fp,
        "seqrep(last_n = %d, min_length = %zd, start_offset = %zd, presence_penalty = %.4f, length_penalty = %.4f, tolerance = %.4f, mid_word_scale = %.4f, tolerance_match_credit = %.4f, tolerance_half_step_cost = %.4f, flags = %d)",
        params->last_n,
        params->min_length,
        params->start_offset,
        params->presence_penalty,
        params->length_penalty,
        params->tolerance,
        params->mid_word_scale,
        params->tolerance_match_credit,
        params->tolerance_half_step_cost,
        params->flags);
}
121+
122+
// Print usage help for the sequence repetition sampler's -seqrep configuration
// string to stdout. Defaults shown in the text come from a freshly initialized
// llama_sampler_seqrep_params.
//
// Fix: the frequency_mode help line previously said it "Emulates the repetition
// penalty sampler" — a copy-paste error; it emulates the frequency penalty
// sampler (also normalized its capitalization to match the sibling lines).
void seqrep_sampler_help() {
    // Initialize a params struct only to report the default values below.
    llama_sampler_seqrep_params p;
    seqrep_sampler_params_init(&p);
    fprintf(stdout, "==== Sequence Repetition Sampler Help ====\n\n");
    fprintf(stdout, " The sequence repetition sampler takes a configuration string in the format:\n");
    fprintf(stdout, " arg1:arg2:argN\n");
    fprintf(stdout, " A colon separated argument can be a key value pair like xyz=1 or flag like xyz\n");
    fprintf(stdout, "\n- Available key/value arguments\n");
    fprintf(stdout, " * repetition_mode=REPEAT_PENALTY\n emulates the repetition penalty sampler. warning: 1.0 disables penalties since this preset enables flag_divide_by_penalty. using 0.0 is probably not what you want\n");
    fprintf(stdout, " * presence_mode=PRESENCE_PENALTY\n emulates the presence penalty sampler\n");
    fprintf(stdout, " * frequency_mode=FREQUENCY_PENALTY\n emulates the frequency penalty sampler\n");
    fprintf(stdout, " * last_n\n last n tokens to consider for sequence penalizing (default: %d, 0 = disabled, -1 = ctx_size)\n", p.last_n);
    fprintf(stdout, " * min_length\n minimum matching sequence length (default: %zd, < 2 = disabled)\n", p.min_length);
    fprintf(stdout, " * presence_penalty\n presence penalty for tokens that can continue a sequence (default: %f, 0.0 = disabled)\n", p.presence_penalty);
    fprintf(stdout, " * length_penalty\n penalty for tokens that can continue a sequence, multiplied by length (default: %f, 0.0 = disabled)\n", p.length_penalty);
    fprintf(stdout, " * tolerance\n tolerance for fuzzy matching sequences (default: %f, 0 = disabled)\n", p.tolerance);
    fprintf(stdout, " * mid_word_scale\n scale penalty when for mid-word tokens. 1.0 would mean apply the full penalty (default: %f, 1.0 = disabled)\n", p.mid_word_scale);
    fprintf(stdout, " * tolerance_match_credit\n credit tolerance on matched tokens (default: %f, 0.0 = disabled)\n", p.tolerance_match_credit);
    fprintf(stdout, " * tolerance_half_step_cost\n advanced option to adjust tolerance cost for failed matches within a half step of a match (default: %f, 1.0 = normal)\n", p.tolerance_half_step_cost);
    fprintf(stdout, "\n- Available flags arguments (currently all default to disabled)\n");
    fprintf(stdout, " * flag_immediate_wildcard\n when tolerance is consumed, by default it doesn't count as a match until a real match is found\n");
    fprintf(stdout, " * flag_tolerance_no_consecutive\n do not allow using tolerance consecutively\n");
    fprintf(stdout, " * flag_tolerance_no_first\n do not allow using tolerance before the first match\n");
    fprintf(stdout, " * flag_tolerance_cap_initial\n only meaningful with match credit, prevents match credit adjusting tolerance higher than the initial value\n");
    fprintf(stdout, " * flag_penalize_length_max_seen\n when applying length_penalty, use the maximum seen sequence length rather than the total length of seen sequences\n");
    fprintf(stdout, " * flag_divide_by_penalty\n divide the logit when applying a penalty rather than subtracting it. warning: when this flag is enabled, 1.0 disables penalties not 0.0. 0.0 is probably not what you want\n");
    fprintf(stdout, "\n- Examples:\n");
    fprintf(stdout, " * repetition_mode=1.2:last_n=32\n same as --repeat-last-n 32 --repeat-penalty 1.2\n");
    fprintf(stdout, " * presence_mode=.2:last_n=32\n same as --repeat-last-n 32 --presence-penalty .2\n");
    fprintf(stdout, " * frequency_mode=.2:last_n=32\n same as --repeat-last-n 32 --frequency-penalty .2\n");
    fprintf(stdout, " * min_length=3:tolerance=1:length_penalty=.2:last_n=-1\n match repeated sequences of at least 3 tokens within the entire context and apply a penalty of 0.2*total_length to the token that would continue the sequence. allow one non-matching token in matched sequences.\n");
}
154+
155+
bool seqrep_sampler_params_parse(char * s, llama_sampler_seqrep_params * params) {
156+
assert(params != NULL);
157+
assert(s != NULL);
158+
size_t offset = 0;
159+
std::string sparams = s;
160+
size_t slen = sparams.size();
161+
162+
while (offset < slen) {
163+
// printf("SR OFFS: %lu\n", offset);
164+
size_t argsep = sparams.find_first_of(':', offset);
165+
std::string argchunk;
166+
if (argsep == std::string::npos) {
167+
argchunk = sparams.substr(offset);
168+
} else if (argsep > offset) {
169+
argchunk = sparams.substr(offset, argsep - offset);
170+
}
171+
std::string argval;
172+
size_t valsep = argchunk.find_first_of('=');
173+
if (valsep != std::string::npos && valsep < argchunk.size()) {
174+
argval = argchunk.substr(valsep + 1);
175+
argchunk.resize(valsep);
176+
}
177+
// printf("SR: k[%s] = v[%s]\n", argchunk.c_str(), argval.c_str());
178+
if (argchunk.empty() && argval.empty()) {
179+
// pass
180+
} else if (argchunk == "repetition_mode") {
181+
params->last_n = 64;
182+
params->min_length = 1;
183+
params->mid_word_scale = 1.0f;
184+
params->flags = LLAMA_SEQREP_DIVIDE_BY_PENALTY;
185+
params->length_penalty = 1.0f;
186+
params->presence_penalty = argval.empty() ? 1.1f : std::atof(argval.c_str());
187+
} else if (argchunk == "presence_mode") {
188+
params->last_n = 64;
189+
params->min_length = 1;
190+
params->mid_word_scale = 1.0f;
191+
params->flags = 0;
192+
params->length_penalty = 0.0f;
193+
params->presence_penalty = std::atof(argval.c_str());
194+
} else if (argchunk == "frequency_mode") {
195+
params->last_n = 64;
196+
params->min_length = 1;
197+
params->mid_word_scale = 1.0f;
198+
params->flags = 0;
199+
params->length_penalty = std::atof(argval.c_str());
200+
params->presence_penalty = 0.0f;
201+
} else if (argchunk == "flag_immediate_wildcard") {
202+
params->flags |= LLAMA_SEQREP_IMMEDIATE_WILDCARD;
203+
} else if (argchunk == "flag_tolerance_no_consecutive") {
204+
params->flags |= LLAMA_SEQREP_TOLERANCE_NO_CONSECUTIVE;
205+
} else if (argchunk == "flag_tolerance_no_first") {
206+
params->flags |= LLAMA_SEQREP_TOLERANCE_NO_FIRST;
207+
} else if (argchunk == "flag_tolerance_cap_initial") {
208+
params->flags |= LLAMA_SEQREP_TOLERANCE_CAP_INITIAL;
209+
} else if (argchunk == "flag_penalize_length_max_seen") {
210+
params->flags |= LLAMA_SEQREP_PENALIZE_LENGTH_MAX_SEEN;
211+
} else if (argchunk == "flag_divide_by_penalty") {
212+
params->flags |= LLAMA_SEQREP_DIVIDE_BY_PENALTY;
213+
} else if (argchunk == "min_length") {
214+
params->min_length = std::atoi(argval.c_str());
215+
} else if (argchunk == "start_offset") {
216+
params->start_offset = std::atoi(argval.c_str());
217+
} else if (argchunk == "last_n") {
218+
params->last_n = std::atoi(argval.c_str());
219+
} else if (argchunk == "tolerance") {
220+
params->tolerance = std::atof(argval.c_str());
221+
} else if (argchunk == "presence_penalty") {
222+
params->presence_penalty = std::atof(argval.c_str());
223+
} else if (argchunk == "length_penalty") {
224+
params->length_penalty = std::atof(argval.c_str());
225+
} else if (argchunk == "mid_word_scale") {
226+
params->mid_word_scale = std::atof(argval.c_str());
227+
} else if (argchunk == "tolerance_match_credit") {
228+
params->tolerance_match_credit = std::atof(argval.c_str());
229+
} else if (argchunk == "tolerance_half_step_cost") {
230+
params->tolerance_half_step_cost = std::atof(argval.c_str());
231+
} else {
232+
fprintf(stderr, "seqrep: Bad argument [%s]=[%s]!\n", argchunk.c_str(), argval.c_str());
233+
return false;
234+
}
235+
if (argsep != std::string::npos) {
236+
offset = argsep + 1;
237+
} else {
238+
break;
239+
}
240+
}
241+
return true;
242+
}
243+
103244
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
104245
bool invalid_param = false;
105246
std::string arg;
@@ -246,6 +387,25 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
246387
break;
247388
}
248389
params.presence_penalty = std::stof(argv[i]);
390+
} else if (arg == "-seqrep" || arg == "--seqrep-penalty") {
391+
if (++i >= argc) {
392+
invalid_param = true;
393+
break;
394+
}
395+
if (std::strcmp(argv[i], "help") == 0) {
396+
seqrep_sampler_help();
397+
exit(0);
398+
}
399+
llama_sampler_seqrep_params sr_params;
400+
seqrep_sampler_params_init(&sr_params);
401+
if (!seqrep_sampler_params_parse(argv[i], &sr_params)) {
402+
seqrep_sampler_help();
403+
exit(1);
404+
}
405+
if (sr_params.last_n != 0 && sr_params.min_length > 0
406+
&& (sr_params.presence_penalty != 0.0f || sr_params.length_penalty != 0.0f)) {
407+
params.seqrep_params.push_back(sr_params);
408+
}
249409
} else if (arg == "--mirostat") {
250410
if (++i >= argc) {
251411
invalid_param = true;
@@ -608,6 +768,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
608768
fprintf(stdout, " --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
609769
fprintf(stdout, " --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
610770
fprintf(stdout, " --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
771+
fprintf(stdout, " -seqrep CFG, --seqrep-penalty CFG\n");
772+
fprintf(stdout, " add a copy of the sequence repetition penalty sampler. may be specified multiple times. for help: -seqrep help\n");
611773
fprintf(stdout, " --mirostat N use Mirostat sampling.\n");
612774
fprintf(stdout, " Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
613775
fprintf(stdout, " (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);

common/common.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ struct gpt_params {
5151
int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
5252
float frequency_penalty = 0.00f; // 0.0 = disabled
5353
float presence_penalty = 0.00f; // 0.0 = disabled
54+
std::vector<llama_sampler_seqrep_params> seqrep_params;
5455
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
5556
float mirostat_tau = 5.00f; // target entropy
5657
float mirostat_eta = 0.10f; // learning rate
@@ -165,3 +166,7 @@ std::string get_sortable_timestamp();
165166
void dump_non_result_info_yaml(
166167
FILE * stream, const gpt_params & params, const llama_context * lctx,
167168
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
169+
170+
void seqrep_sampler_params_init(llama_sampler_seqrep_params * params);
171+
void seqrep_sampler_params_dump(FILE * fp, llama_sampler_seqrep_params * params);
172+
bool seqrep_sampler_params_parse(char * s, llama_sampler_seqrep_params * params);

examples/main/main.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,9 @@ int main(int argc, char ** argv) {
423423
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
424424
params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
425425
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
426+
for (auto & sr_params : params.seqrep_params) {
427+
seqrep_sampler_params_dump(stderr, &sr_params);
428+
}
426429
LOG_TEE("\n\n");
427430

428431
grammar_parser::parse_state parsed_grammar;
@@ -691,6 +694,11 @@ int main(int argc, char ** argv) {
691694
llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
692695
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
693696
last_n_repeat, alpha_frequency, alpha_presence);
697+
698+
for (auto & sr_params : params.seqrep_params) {
699+
llama_sample_seqrep_penalty(ctx, &cur_p, last_n_tokens.data(), last_n_tokens.size(), &sr_params);
700+
}
701+
694702
if (!penalize_nl) {
695703
for (size_t idx = 0; idx < cur_p.size; idx++) {
696704
if (cur_p.data[idx].id == llama_token_nl(ctx)) {

0 commit comments

Comments
 (0)