[RFC] common, server : add top-a sampler #5612

Open · wants to merge 1 commit into `master`

12 changes: 12 additions & 0 deletions common/common.cpp
@@ -376,6 +376,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
sparams.min_p = std::stof(argv[i]);
} else if (arg == "--top-a") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.top_a = std::stof(argv[i]);
} else if (arg == "--temp") {
if (++i >= argc) {
invalid_param = true;
@@ -984,6 +990,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
printf(" --top-a N top-a sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.top_a);
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
@@ -1157,6 +1164,7 @@ std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::
{"top_p", llama_sampler_type::TOP_P},
{"typical_p", llama_sampler_type::TYPICAL_P},
{"min_p", llama_sampler_type::MIN_P},
{"top_a", llama_sampler_type::TOP_A},
{"tfs_z", llama_sampler_type::TFS_Z},
{"temperature", llama_sampler_type::TEMPERATURE}
};
@@ -1170,6 +1178,7 @@ std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::
{"typical-p", llama_sampler_type::TYPICAL_P},
{"typical", llama_sampler_type::TYPICAL_P},
{"min-p", llama_sampler_type::MIN_P},
{"top-a", llama_sampler_type::TOP_A},
{"tfs-z", llama_sampler_type::TFS_Z},
{"tfs", llama_sampler_type::TFS_Z},
{"temp", llama_sampler_type::TEMPERATURE}
@@ -1205,6 +1214,7 @@ std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & nam
{'p', llama_sampler_type::TOP_P},
{'y', llama_sampler_type::TYPICAL_P},
{'m', llama_sampler_type::MIN_P},
{'a', llama_sampler_type::TOP_A},
{'f', llama_sampler_type::TFS_Z},
{'t', llama_sampler_type::TEMPERATURE}
};
@@ -1227,6 +1237,7 @@ std::string sampler_type_to_name_string(llama_sampler_type sampler_type) {
case llama_sampler_type::TYPICAL_P: return "typical_p";
case llama_sampler_type::TOP_P: return "top_p";
case llama_sampler_type::MIN_P: return "min_p";
case llama_sampler_type::TOP_A: return "top_a";
case llama_sampler_type::TEMPERATURE: return "temperature";
default : return "";
}
@@ -1773,6 +1784,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
fprintf(stream, "top_a: %f # default: 0.0\n", sparams.top_a);
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
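Taken together, the three lookup tables above let the new sampler be selected by full name (`top_a` or `top-a`) or by the single character `a` in a sampler-sequence string. A minimal sketch of the char-based lookup, assuming `sampler_types_from_chars` is exposed through `common.h` as in the existing code:

```cpp
#include "common.h"  // assumed to declare sampler_types_from_chars()

// 'k' = top_k, 'f' = tfs_z, 'y' = typical_p, 'p' = top_p,
// 'm' = min_p, 'a' = top_a (new), 't' = temperature
std::vector<llama_sampler_type> seq = sampler_types_from_chars("kfypmat");
// seq: TOP_K, TFS_Z, TYPICAL_P, TOP_P, MIN_P, TOP_A, TEMPERATURE
```
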
6 changes: 4 additions & 2 deletions common/sampling.cpp
@@ -91,10 +91,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) {

snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, top_a = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
params.top_k, params.tfs_z, params.top_p, params.min_p, params.top_a, params.typical_p, params.temp,
params.mirostat, params.mirostat_eta, params.mirostat_tau);

return std::string(result);
@@ -128,6 +128,7 @@ static void sampler_queue(
const int32_t top_k = params.top_k;
const float top_p = params.top_p;
const float min_p = params.min_p;
const float top_a = params.top_a;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
@@ -139,6 +140,7 @@
case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
case llama_sampler_type::TOP_A : llama_sample_top_a (ctx_main, &cur_p, top_a, min_keep); break;
case llama_sampler_type::TEMPERATURE:
if (dynatemp_range > 0) {
float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
3 changes: 3 additions & 0 deletions common/sampling.h
@@ -13,6 +13,7 @@ enum class llama_sampler_type : char {
TOP_K = 'k',
TOP_P = 'p',
MIN_P = 'm',
TOP_A = 'a',
TFS_Z = 'f',
TYPICAL_P = 'y',
TEMPERATURE = 't'
@@ -26,6 +27,7 @@ typedef struct llama_sampling_params {
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float top_a = 0.00f; // 0.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
@@ -46,6 +48,7 @@
llama_sampler_type::TYPICAL_P,
llama_sampler_type::TOP_P,
llama_sampler_type::MIN_P,
llama_sampler_type::TOP_A,
llama_sampler_type::TEMPERATURE
};

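A small sketch of how a caller would opt in, given the defaults declared above (`top_a = 0.0f` leaves the new stage inert even though `TOP_A` now sits in the default `samplers_sequence`); the `0.2f` value is illustrative, not a recommendation from this PR:

```cpp
#include "sampling.h"

llama_sampling_params sparams;  // top_a defaults to 0.0f (disabled)
sparams.top_a = 0.2f;           // hypothetical value: activates the TOP_A
                                // stage between MIN_P and TEMPERATURE
```
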
2 changes: 2 additions & 0 deletions examples/server/README.md
@@ -213,6 +213,8 @@ node index.js

`min_p`: The minimum probability for a token to be considered, relative to the probability of the most likely token (default: 0.05).

`top_a`: Limit the next token selection to a subset of tokens with a probability above a*P^2, where P is the probability of the most probable token (default: 0.0, 0.0 = disabled).

`n_predict`: Set the maximum number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. (default: -1, -1 = infinity).

`n_keep`: Specify the number of tokens from the prompt to retain when the context size is exceeded and tokens need to be discarded.
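For readers of the API docs above, a standalone sketch of the stated criterion (not the PR's implementation, which filters in logit space): after a softmax, keep every token whose probability is at least a*P^2, then renormalize over the survivors.

```cpp
#include <algorithm>
#include <vector>

// Standalone top-a sketch; `probs` are softmax probabilities summing to 1.
std::vector<float> top_a_filter(const std::vector<float> & probs, float a) {
    if (a <= 0.0f || probs.empty()) {
        return probs;  // a == 0.0 disables the filter, matching the default
    }
    const float p_max     = *std::max_element(probs.begin(), probs.end());
    const float threshold = a * p_max * p_max;  // a * P^2
    std::vector<float> kept;
    for (float p : probs) {
        if (p >= threshold) {
            kept.push_back(p);
        }
    }
    return kept;  // caller renormalizes over the surviving tokens
}
```
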
2 changes: 2 additions & 0 deletions examples/server/server.cpp
@@ -816,6 +816,7 @@ struct server_context {
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
slot.sparams.top_a = json_value(data, "top_a", default_sparams.top_a);
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
@@ -1194,6 +1195,7 @@ struct server_context {
{"top_k", slot.sparams.top_k},
{"top_p", slot.sparams.top_p},
{"min_p", slot.sparams.min_p},
{"top_a", slot.sparams.top_a},
{"tfs_z", slot.sparams.tfs_z},
{"typical_p", slot.sparams.typical_p},
{"repeat_last_n", slot.sparams.penalty_last_n},
14 changes: 11 additions & 3 deletions llama.cpp
@@ -10725,7 +10725,7 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
}
}

void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
static void llama_sample_min_p_pow(struct llama_context * ctx, llama_token_data_array * candidates, float p, float pow, size_t min_keep) {
if (p <= 0.0f || !candidates->size) {
return;
}
@@ -10742,7 +10742,7 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
for (size_t i = 0; i < candidates->size; ++i) {
max_logit = std::max(max_logit, candidates->data[i].logit);
}
const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
const float min_logit = max_logit + logf(p) * pow; // min logit for p_i >= p * p_max^pow

for (size_t i = 0; i < candidates->size; ++i) {
if (candidates->data[i].logit >= min_logit) {
@@ -10768,7 +10768,7 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
candidates->sorted = true;
}

const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
const float min_logit = candidates->data[0].logit + logf(p) * pow; // min logit for p_i >= p * p_max^pow
size_t i = 1; // first token always matches

for (; i < candidates->size; ++i) {
@@ -10786,6 +10786,14 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
}
}

void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
llama_sample_min_p_pow(ctx, candidates, p, 1.f, min_keep);
}

void llama_sample_top_a(struct llama_context * ctx, llama_token_data_array * candidates, float a, size_t min_keep) {
llama_sample_min_p_pow(ctx, candidates, a, 2.f, min_keep);
}

void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
if (z >= 1.0f || candidates->size <= 2) {
return;
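The `min_logit` shortcut used above (in the `pow = 1`, i.e. min-p, path) relies on the shift-invariance of softmax, which is what lets the filter run on raw logits without computing the normalizer:

$$
\frac{p_i}{p_{\max}} = \frac{e^{\ell_i}/Z}{e^{\ell_{\max}}/Z} = e^{\ell_i - \ell_{\max}}
\qquad\Longrightarrow\qquad
p_i \ge p \cdot p_{\max} \iff \ell_i \ge \ell_{\max} + \log p .
$$

The generalized path simply scales $\log p$ by `pow` in the same threshold.
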
7 changes: 7 additions & 0 deletions llama.h
@@ -813,6 +813,13 @@ extern "C" {
float p,
size_t min_keep);

/// @details Top-A sampling as described in https://github.com/BlinkDL/RWKV-LM/tree/4cb363e5aa31978d801a47bc89d28e927ab6912e#the-top-a-sampling-method
LLAMA_API void llama_sample_top_a(
struct llama_context * ctx,
llama_token_data_array * candidates,
float a,
size_t min_keep);

/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sample_tail_free(
struct llama_context * ctx,
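A hypothetical call sequence (a sketch, not part of this PR) showing where the new entry point sits next to the existing `llama_sample_*` functions:

```cpp
// `candidates` is a std::vector<llama_token_data> built from the current logits.
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };

llama_sample_top_a(ctx, &cur_p, /*a=*/0.2f, /*min_keep=*/1);  // illustrative value
llama_sample_temp (ctx, &cur_p, /*temp=*/0.8f);
llama_token id = llama_sample_token(ctx, &cur_p);
```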