diff --git a/common/common.cpp b/common/common.cpp
index 16ef4d7f74dd9..0ea4df662752a 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -376,6 +376,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             sparams.min_p = std::stof(argv[i]);
+        } else if (arg == "--top-a") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.top_a = std::stof(argv[i]);
         } else if (arg == "--temp") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -984,6 +990,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
     printf("  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
     printf("  --min-p N             min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
+    printf("  --top-a N             top-a sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.top_a);
     printf("  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
     printf("  --typical N           locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
     printf("  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
@@ -1157,6 +1164,7 @@ std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std:
@@ std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std:
@@ std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & nam
         {'p', llama_sampler_type::TOP_P},
         {'y', llama_sampler_type::TYPICAL_P},
         {'m', llama_sampler_type::MIN_P},
+        {'a', llama_sampler_type::TOP_A},
         {'f', llama_sampler_type::TFS_Z},
         {'t', llama_sampler_type::TEMPERATURE}
     };
@@ -1227,6 +1237,7 @@ std::string sampler_type_to_name_string(llama_sampler_type sampler_type) {
         case llama_sampler_type::TYPICAL_P:   return "typical_p";
         case llama_sampler_type::TOP_P:       return "top_p";
         case llama_sampler_type::MIN_P:       return "min_p";
+        case llama_sampler_type::TOP_A:       return "top_a";
         case llama_sampler_type::TEMPERATURE: return "temperature";
         default : return "";
     }
@@ -1773,6 +1784,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
     fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
     fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
+    fprintf(stream, "top_a: %f # default: 0.0\n", sparams.top_a);
     fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
     fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
     fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
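Reader's note, not part of the patch: once `--top-a` is parsed into `sparams.top_a` as above, the same knob can also be set programmatically. A minimal sketch, assuming the `llama_sampling_params` struct and defaults from `common/sampling.h` as modified further down; the 0.2f value is only an example:

```cpp
// Illustrative only: configure top-a without going through the CLI.
// Field and type names are taken from common/sampling.h after this patch.
#include "sampling.h"

llama_sampling_params make_params_with_top_a() {
    llama_sampling_params sparams;
    sparams.top_a = 0.2f; // 0.0f (the default) leaves the sampler disabled
    // the default samplers_sequence already contains llama_sampler_type::TOP_A after this patch
    return sparams;
}
```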
"true" : "false"); diff --git a/common/sampling.cpp b/common/sampling.cpp index 823031febc7e2..36c8ee8b26e04 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -91,10 +91,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) { snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" - "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n" + "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, top_a = %.3f, typical_p = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, - params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp, + params.top_k, params.tfs_z, params.top_p, params.min_p, params.top_a, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); return std::string(result); @@ -128,6 +128,7 @@ static void sampler_queue( const int32_t top_k = params.top_k; const float top_p = params.top_p; const float min_p = params.min_p; + const float top_a = params.top_a; const float tfs_z = params.tfs_z; const float typical_p = params.typical_p; const std::vector & samplers_sequence = params.samplers_sequence; @@ -139,6 +140,7 @@ static void sampler_queue( case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; + case llama_sampler_type::TOP_A : llama_sample_top_a (ctx_main, &cur_p, top_a, min_keep); break; case llama_sampler_type::TEMPERATURE: if (dynatemp_range > 0) { float dynatemp_min = std::max(0.0f, temp - dynatemp_range); diff --git a/common/sampling.h b/common/sampling.h index 48b2459d1f944..fa6b751ac8fae 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -13,6 +13,7 @@ enum class llama_sampler_type : char { TOP_K = 'k', TOP_P = 'p', MIN_P = 'm', + TOP_A = 'a', TFS_Z = 'f', TYPICAL_P = 'y', TEMPERATURE = 't' @@ -26,6 +27,7 @@ typedef struct llama_sampling_params { int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled float min_p = 0.05f; // 0.0 = disabled + float top_a = 0.00f; // 0.0 = disabled float tfs_z = 1.00f; // 1.0 = disabled float typical_p = 1.00f; // 1.0 = disabled float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities @@ -46,6 +48,7 @@ typedef struct llama_sampling_params { llama_sampler_type::TYPICAL_P, llama_sampler_type::TOP_P, llama_sampler_type::MIN_P, + llama_sampler_type::TOP_A, llama_sampler_type::TEMPERATURE }; diff --git a/examples/server/README.md b/examples/server/README.md index 23606b32a2c81..dd17a03b805a0 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -213,6 +213,8 @@ node index.js `min_p`: The minimum probability for a token to be considered, relative to the probability of the most likely token (default: 0.05). + `top_a`: Limit the next token selection to a subset of tokens with a probability above a*P^2, where P is the most probable token (default: 0.0, 0.0 = disabled). + `n_predict`: Set the maximum number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 796f3499c9877..4b7979fdc8822 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -816,6 +816,7 @@ struct server_context {
         slot.sparams.top_k             = json_value(data, "top_k",             default_sparams.top_k);
         slot.sparams.top_p             = json_value(data, "top_p",             default_sparams.top_p);
         slot.sparams.min_p             = json_value(data, "min_p",             default_sparams.min_p);
+        slot.sparams.top_a             = json_value(data, "top_a",             default_sparams.top_a);
         slot.sparams.tfs_z             = json_value(data, "tfs_z",             default_sparams.tfs_z);
         slot.sparams.typical_p         = json_value(data, "typical_p",         default_sparams.typical_p);
         slot.sparams.temp              = json_value(data, "temperature",       default_sparams.temp);
@@ -1194,6 +1195,7 @@ struct server_context {
             {"top_k",             slot.sparams.top_k},
             {"top_p",             slot.sparams.top_p},
             {"min_p",             slot.sparams.min_p},
+            {"top_a",             slot.sparams.top_a},
             {"tfs_z",             slot.sparams.tfs_z},
             {"typical_p",         slot.sparams.typical_p},
             {"repeat_last_n",     slot.sparams.penalty_last_n},
diff --git a/llama.cpp b/llama.cpp
index c58a029f74faf..9b89f32795809 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -10725,7 +10725,7 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
     }
 }
 
-void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
+static void llama_sample_min_p_pow(struct llama_context * ctx, llama_token_data_array * candidates, float p, float pow, size_t min_keep) {
     if (p <= 0.0f || !candidates->size) {
         return;
     }
@@ -10742,7 +10742,7 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
         for (size_t i = 0; i < candidates->size; ++i) {
             max_logit = std::max(max_logit, candidates->data[i].logit);
         }
-        const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
+        const float min_logit = max_logit + logf(p) * pow; // min logit for p_i >= p * p_max^pow
 
         for (size_t i = 0; i < candidates->size; ++i) {
             if (candidates->data[i].logit >= min_logit) {
@@ -10768,7 +10768,7 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
             candidates->sorted = true;
         }
 
-        const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
+        const float min_logit = candidates->data[0].logit + logf(p) * pow; // min logit for p_i >= p * p_max^pow
         size_t i = 1; // first token always matches
 
         for (; i < candidates->size; ++i) {
@@ -10786,6 +10786,14 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
     }
 }
 
+void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
+    llama_sample_min_p_pow(ctx, candidates, p, 1.f, min_keep);
+}
+
+void llama_sample_top_a(struct llama_context * ctx, llama_token_data_array * candidates, float a, size_t min_keep) {
+    llama_sample_min_p_pow(ctx, candidates, a, 2.f, min_keep);
+}
+
 void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
     if (z >= 1.0f || candidates->size <= 2) {
         return;
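Aside on the refactor above, not part of the patch: the shared helper keeps every candidate whose raw logit is at least `max_logit + pow * logf(p)`, so `pow = 1.f` reproduces the existing min-p cutoff and `llama_sample_top_a` reuses the same code path with `pow = 2.f`. A standalone sketch of that cutoff on made-up logits (values and function name are illustrative only):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Mirrors the threshold applied by llama_sample_min_p_pow above, on raw (unnormalized) logits.
static void count_survivors(const std::vector<float> & logits, float p, float pow_) {
    float max_logit = logits[0];
    for (float l : logits) max_logit = std::max(max_logit, l);
    const float min_logit = max_logit + std::log(p) * pow_;
    int kept = 0;
    for (float l : logits) kept += (l >= min_logit);
    std::printf("p = %.2f, pow = %.0f -> %d of %zu candidates kept\n", p, pow_, kept, logits.size());
}

int main() {
    const std::vector<float> logits = {2.0f, 1.2f, 0.5f, -0.3f, -1.5f};
    count_survivors(logits, 0.05f, 1.0f); // the existing min-p style cutoff
    count_survivors(logits, 0.20f, 2.0f); // the values top_a = 0.2 would pass in
    return 0;
}
```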
diff --git a/llama.h b/llama.h
index 7a107c7f335d5..130eed484e3cb 100644
--- a/llama.h
+++ b/llama.h
@@ -813,6 +813,13 @@ extern "C" {
                            float   p,
                           size_t   min_keep);
 
+    /// @details Top-A sampling as described in https://github.com/BlinkDL/RWKV-LM/tree/4cb363e5aa31978d801a47bc89d28e927ab6912e#the-top-a-sampling-method
+    LLAMA_API void llama_sample_top_a(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+                           float   a,
+                          size_t   min_keep);
+
     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
     LLAMA_API void llama_sample_tail_free(
             struct llama_context * ctx,
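Finally, a minimal caller-side sketch (not part of the patch) of how the new public entry point slots into the usual manual sampling flow. Everything other than `llama_sample_top_a` is the pre-existing llama.h API; `ctx`, `logits`, `n_vocab` and the 0.2f / 0.8f values are assumptions for illustration:

```cpp
#include "llama.h"
#include <vector>

// Sketch: apply the new sampler directly, outside of common/sampling.cpp.
// logits and n_vocab are assumed to come from a normal llama_decode() / llama_get_logits() call.
llama_token sample_with_top_a(llama_context * ctx, const float * logits, int n_vocab) {
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token id = 0; id < n_vocab; ++id) {
        candidates.push_back({ id, logits[id], 0.0f });
    }
    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };

    llama_sample_top_a(ctx, &cur_p, 0.2f, 1); // a = 0.2, keep at least one token
    llama_sample_temp (ctx, &cur_p, 0.8f);
    return llama_sample_token(ctx, &cur_p);
}
```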