Skip to content

Commit 1b23152

Browse files
committed
Add top nsigma to sampler chain config
1 parent 64e1af7 commit 1b23152

File tree

2 files changed: +47 additions, -44 deletions (lines changed)

2 files changed

+47
-44
lines changed

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -96,6 +96,7 @@ enum common_sampler_type {
9696
COMMON_SAMPLER_TYPE_XTC = 8,
9797
COMMON_SAMPLER_TYPE_INFILL = 9,
9898
COMMON_SAMPLER_TYPE_PENALTIES = 10,
99+
COMMON_SAMPLER_TYPE_TOP_NSIGMA = 11,
99100
};
100101

101102
// dimensionality reduction methods, used by cvector-generator

common/sampling.cpp

Lines changed: 46 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -229,51 +229,48 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
229229
params.logit_bias.data()));
230230

231231
if (params.mirostat == 0) {
232-
if (params.top_n_sigma >= 0) {
233-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
234-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp));
235-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
236-
} else {
237-
for (const auto & cnstr : params.samplers) {
238-
switch (cnstr) {
239-
case COMMON_SAMPLER_TYPE_DRY:
240-
{
241-
std::vector<const char *> c_breakers;
242-
c_breakers.reserve(params.dry_sequence_breakers.size());
243-
for (const auto & str : params.dry_sequence_breakers) {
244-
c_breakers.push_back(str.c_str());
245-
}
246-
247-
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
232+
for (const auto & cnstr : params.samplers) {
233+
switch (cnstr) {
234+
case COMMON_SAMPLER_TYPE_DRY:
235+
{
236+
std::vector<const char *> c_breakers;
237+
c_breakers.reserve(params.dry_sequence_breakers.size());
238+
for (const auto & str : params.dry_sequence_breakers) {
239+
c_breakers.push_back(str.c_str());
248240
}
249-
break;
250-
case COMMON_SAMPLER_TYPE_TOP_K:
251-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
252-
break;
253-
case COMMON_SAMPLER_TYPE_TOP_P:
254-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
255-
break;
256-
case COMMON_SAMPLER_TYPE_MIN_P:
257-
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
258-
break;
259-
case COMMON_SAMPLER_TYPE_XTC:
260-
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
261-
break;
262-
case COMMON_SAMPLER_TYPE_TYPICAL_P:
263-
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
264-
break;
265-
case COMMON_SAMPLER_TYPE_TEMPERATURE:
266-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
267-
break;
268-
case COMMON_SAMPLER_TYPE_INFILL:
269-
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
270-
break;
271-
case COMMON_SAMPLER_TYPE_PENALTIES:
272-
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
273-
break;
274-
default:
275-
GGML_ASSERT(false && "unknown sampler type");
276-
}
241+
242+
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
243+
}
244+
break;
245+
case COMMON_SAMPLER_TYPE_TOP_K:
246+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
247+
break;
248+
case COMMON_SAMPLER_TYPE_TOP_P:
249+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
250+
break;
251+
case COMMON_SAMPLER_TYPE_MIN_P:
252+
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
253+
break;
254+
case COMMON_SAMPLER_TYPE_XTC:
255+
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
256+
break;
257+
case COMMON_SAMPLER_TYPE_TYPICAL_P:
258+
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
259+
break;
260+
case COMMON_SAMPLER_TYPE_TEMPERATURE:
261+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
262+
break;
263+
case COMMON_SAMPLER_TYPE_INFILL:
264+
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
265+
break;
266+
case COMMON_SAMPLER_TYPE_PENALTIES:
267+
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
268+
break;
269+
case COMMON_SAMPLER_TYPE_TOP_NSIGMA:
270+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma));
271+
break;
272+
default:
273+
GGML_ASSERT(false && "unknown sampler type");
277274
}
278275
}
279276
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
@@ -480,6 +477,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
480477
case COMMON_SAMPLER_TYPE_XTC: return 'x';
481478
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
482479
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
480+
case COMMON_SAMPLER_TYPE_TOP_NSIGMA: return 's';
483481
default : return '?';
484482
}
485483
}
@@ -495,6 +493,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
495493
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
496494
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
497495
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
496+
case COMMON_SAMPLER_TYPE_TOP_NSIGMA: return "top_n_sigma";
498497
default : return "";
499498
}
500499
}
@@ -510,6 +509,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
510509
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
511510
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
512511
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
512+
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_NSIGMA},
513513
};
514514

515515
// since samplers names are written multiple ways
@@ -524,6 +524,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
524524
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
525525
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
526526
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
527+
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_NSIGMA},
527528
};
528529

529530
std::vector<common_sampler_type> samplers;
@@ -557,6 +558,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
557558
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
558559
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
559560
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
561+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_NSIGMA), COMMON_SAMPLER_TYPE_TOP_NSIGMA},
560562
};
561563

562564
std::vector<common_sampler_type> samplers;

0 commit comments

Comments
 (0)