Commit 5ee99c3

common, server : surface min_keep as its own parameter (#5567)
* Feature - surface min_keep as its own parameter
* Updated README with min_keep param

1 parent c145f8a commit 5ee99c3

File tree: 6 files changed, +14 -1 lines changed


common/common.cpp
Lines changed: 1 addition & 0 deletions

```diff
@@ -1704,6 +1704,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     }
     fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
     fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
+    fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep);
     fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
     fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
```

common/sampling.cpp
Lines changed: 4 additions & 1 deletion

```diff
@@ -248,7 +248,10 @@ static llama_token llama_sampling_sample_impl(
             llama_sample_temp(ctx_main, &cur_p, temp);
             id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
         } else {
-            sampler_queue(ctx_main, params, cur_p, 1);
+            // temperature sampling
+            size_t min_keep = std::max(1, params.min_keep);
+
+            sampler_queue(ctx_main, params, cur_p, min_keep);
 
             id = llama_sample_token(ctx_main, &cur_p);
 
```
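Note the clamp: min_keep defaults to 0, meaning disabled, so the temperature-sampling branch raises it to at least 1 before calling sampler_queue, and every sampler in the queue is therefore guaranteed to leave at least one candidate token. A minimal standalone sketch of that clamp (hypothetical demo code, not part of the commit):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    int32_t min_keep_param = 0;                      // default: 0 = disabled
    size_t  min_keep = std::max(1, min_keep_param);  // clamp to >= 1, as in llama_sampling_sample_impl
    printf("effective min_keep: %zu\n", min_keep);   // prints 1

    min_keep_param = 5;                              // caller asked for at least 5 candidates
    min_keep = std::max(1, min_keep_param);
    printf("effective min_keep: %zu\n", min_keep);   // prints 5
    return 0;
}
```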

common/sampling.h
Lines changed: 1 addition & 0 deletions

```diff
@@ -22,6 +22,7 @@ enum class llama_sampler_type : char {
 typedef struct llama_sampling_params {
     int32_t n_prev   = 64;    // number of previous tokens to remember
     int32_t n_probs  = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t min_keep = 0;     // 0 = disabled, otherwise samplers should return at least min_keep tokens
     int32_t top_k    = 40;    // <= 0 to use vocab size
     float   top_p    = 0.95f; // 1.0 = disabled
     float   min_p    = 0.05f; // 0.0 = disabled
```
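As a quick illustration, calling code opts in by setting the new field next to the existing ones. A hedged sketch, assuming llama.cpp's common/ directory is on the include path:

```cpp
#include "sampling.h"   // llama.cpp's common/sampling.h

int main() {
    llama_sampling_params sparams;   // all fields start at the defaults shown above
    sparams.top_k    = 40;
    sparams.min_p    = 0.05f;
    sparams.min_keep = 5;            // ask each sampler in the queue to keep >= 5 candidates

    // The params are then used as usual, e.g.:
    // llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
    return 0;
}
```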

examples/server/README.md
Lines changed: 2 additions & 0 deletions

```diff
@@ -199,6 +199,8 @@ node index.js
 
 `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token (default: 0)
 
+`min_keep`: If greater than 0, force samplers to return N possible tokens at minimum (default: 0)
+
 `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `prompt`. You can determine the place of the image in the prompt as in the following: `USER:[img-12]Describe the image in detail.\nASSISTANT:`. In this case, `[img-12]` will be replaced by the embeddings of the image with id `12` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 12}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
 
 `slot_id`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot (default: -1)
```
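End to end, a client opts in by adding the field to the JSON body of a /completion request. A sketch of building such a body with nlohmann::json (the JSON library the server itself uses); the prompt and n_predict values are borrowed from the README's quick-start curl example:

```cpp
#include <iostream>
#include <nlohmann/json.hpp>

int main() {
    nlohmann::json body = {
        {"prompt",    "Building a website can be done in 10 simple steps:"},
        {"n_predict", 128},
        {"min_keep",  5}   // force every sampler to return at least 5 candidate tokens
    };
    // POST the body to the server's /completion endpoint, e.g.:
    //   curl --request POST --url http://localhost:8080/completion \
    //        --header "Content-Type: application/json" --data '<body>'
    std::cout << body.dump(2) << std::endl;
    return 0;
}
```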

examples/server/public/index.html
Lines changed: 4 additions & 0 deletions

```diff
@@ -234,6 +234,7 @@
       mirostat_eta: 0.1, // learning rate
       grammar: '',
       n_probs: 0, // no completion_probabilities,
+      min_keep: 0, // min probs from each sampler,
       image_data: [],
       cache_prompt: true,
       api_key: ''
@@ -791,6 +792,9 @@
             <fieldset>
             ${IntField({ label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs })}
             </fieldset>
+            <fieldset>
+            ${IntField({ label: "Min Probabilities from each Sampler", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
+            </fieldset>
             <fieldset>
             <label for="api_key">API Key</label>
             <input type="text" name="api_key" value="${params.value.api_key}" placeholder="Enter API key" oninput=${updateParams} />
```

examples/server/server.cpp
Lines changed: 2 additions & 0 deletions

```diff
@@ -548,6 +548,7 @@ struct llama_server_context
         slot->params.seed     = json_value(data, "seed", default_params.seed);
         slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
         slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
+        slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
 
         if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) {
             // Might be better to reject the request with a 400 ?
@@ -1093,6 +1094,7 @@ struct llama_server_context
             {"stream",     slot.params.stream},
             {"logit_bias", slot.sparams.logit_bias},
             {"n_probs",    slot.sparams.n_probs},
+            {"min_keep",   slot.sparams.min_keep},
             {"grammar",    slot.sparams.grammar},
             {"samplers",   samplers_sequence}
         };
```
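On the receiving side, json_value is the server's lookup-with-default helper over nlohmann::json. A self-contained sketch of the same pattern applied to the new field (the helper below mirrors the shape of the one in server.cpp rather than quoting it):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Return the field from the request body if present and non-null,
// otherwise fall back to the supplied default.
template <typename T>
static T json_value(const json & body, const std::string & key, const T & default_value) {
    return body.contains(key) && !body.at(key).is_null()
        ? body.value(key, default_value)
        : default_value;
}

int main() {
    json data = json::parse(R"({"min_keep": 5})");
    std::cout << json_value(data, "min_keep", int32_t(0)) << "\n";  // prints 5

    json empty = json::parse("{}");
    std::cout << json_value(empty, "min_keep", int32_t(0)) << "\n"; // falls back to 0
    return 0;
}
```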
