
Commit e8dbe04

llama : introduce llama_sampling_params
1 parent ab5a99e commit e8dbe04

File tree

23 files changed: 297 additions, 183 deletions

common/common.cpp

Lines changed: 13 additions & 17 deletions
@@ -252,7 +252,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
     bool invalid_param = false;
     std::string arg;
     const std::string arg_prefix = "--";
-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;
 
     for (int i = 1; i < argc; i++) {
         arg = argv[i];
@@ -320,7 +320,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
 bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
     const char split_delim = ',';
 
-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;
 
     if (arg == "-s" || arg == "--seed") {
         CHECK_ARG
@@ -1039,7 +1039,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--ignore-eos") {
-        params.ignore_eos = true;
+        sparams.ignore_eos = true;
         return true;
     }
     if (arg == "--penalize-nl") {
@@ -1054,7 +1054,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         std::string value_str;
         try {
             if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
-                sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+                const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+                sparams.logit_bias.push_back({key, bias});
             }
             else {
                 throw std::exception();
@@ -1401,7 +1402,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
 #endif
 
 void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
-    const llama_sampling_params & sparams = params.sparams;
+    const auto & sparams = params.sparams;
 
     std::string sampler_type_chars;
     std::string sampler_type_names;
@@ -2165,8 +2166,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         llama_lora_adapters_apply(lctx, iparams.lora_adapters);
     }
 
-    if (params.ignore_eos) {
-        params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+    if (params.sparams.ignore_eos && llama_token_eos(model) == -1) {
+        fprintf(stderr, "%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
+        params.sparams.ignore_eos = false;
     }
 
     if (params.warmup) {
@@ -3142,7 +3144,7 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha
 
 void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
                                const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
-    const llama_sampling_params & sparams = params.sparams;
+    const auto & sparams = params.sparams;
 
     fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT);
     fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER);
@@ -3205,10 +3207,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
     fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
     fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
-
-    const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
-    const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
-    fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
+    fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false");
 
     yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
     fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false");
@@ -3219,11 +3218,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str());
 
     fprintf(stream, "logit_bias:\n");
-    for (std::pair<llama_token, float> lb : sparams.logit_bias) {
-        if (ignore_eos && lb.first == logit_bias_eos->first) {
-            continue;
-        }
-        fprintf(stream, " %d: %f", lb.first, lb.second);
+    for (const auto & logit_bias : sparams.logit_bias) {
+        fprintf(stream, " %d: %f", logit_bias.token, logit_bias.bias);
     }
 
     fprintf(stream, "lora:\n");

common/common.h

Lines changed: 1 addition & 3 deletions
@@ -108,8 +108,7 @@ struct gpt_params {
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
 
-    // // sampling parameters
-    struct llama_sampling_params sparams;
+    struct gpt_sampling_params sparams;
 
     std::string model = ""; // model path
     std::string model_draft = ""; // draft model for speculative decoding
@@ -173,7 +172,6 @@ struct gpt_params {
     bool flash_attn = false; // flash attention
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
-    bool ignore_eos = false; // ignore generated EOS tokens
     bool logits_all = false; // return logits for all tokens in the batch
     bool use_mmap = true; // use mmap for faster loads
     bool use_mlock = false; // use mlock to keep model in memory

common/sampling.cpp

Lines changed: 49 additions & 15 deletions
@@ -1,19 +1,49 @@
 #include "sampling.h"
 
-#include <random>
+#include "common.h"
 
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model) {
+struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_params & params, const struct llama_model * model) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
     result->params = params;
-    result->smpl = llama_sampling_init(model, params.grammar.c_str(), "root");
+
+    {
+        auto lp = llama_sampling_default_params();
+
+        lp.seed              = params.seed;
+        lp.n_prev            = params.n_prev;
+        lp.n_probs           = params.n_probs;
+        lp.min_keep          = params.min_keep;
+        lp.top_k             = params.top_k;
+        lp.top_p             = params.top_p;
+        lp.min_p             = params.min_p;
+        lp.tfs_z             = params.tfs_z;
+        lp.typical_p         = params.typical_p;
+        lp.temp              = params.temp;
+        lp.dynatemp_range    = params.dynatemp_range;
+        lp.dynatemp_exponent = params.dynatemp_exponent;
+        lp.penalty_last_n    = params.penalty_last_n;
+        lp.penalty_repeat    = params.penalty_repeat;
+        lp.penalty_freq      = params.penalty_freq;
+        lp.penalty_present   = params.penalty_present;
+        lp.mirostat          = params.mirostat;
+        lp.mirostat_tau      = params.mirostat_tau;
+        lp.mirostat_eta      = params.mirostat_eta;
+        lp.penalize_nl       = params.penalize_nl;
+        lp.ignore_eos        = params.ignore_eos;
+
+        result->smpl = llama_sampling_init(model, lp);
+
+        llama_sampling_set_rng_seed  (result->smpl, params.seed);
+        llama_sampling_set_grammar   (result->smpl, params.grammar.c_str(), "root");
+        llama_sampling_set_cfg       (result->smpl, params.cfg_negative_prompt.c_str(), params.cfg_scale);
+        llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data());
+    }
 
     result->prev.resize(params.n_prev);
 
     result->n_valid = 0;
 
-    llama_sampling_set_rng_seed(result->smpl, params.seed);
-
     return result;
 }
 
@@ -24,7 +54,7 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
 }
 
 void llama_sampling_reset(llama_sampling_context * ctx) {
-    llama_sampling_reset(ctx->smpl, ctx->params.grammar.c_str(), "root");
+    llama_sampling_reset(ctx->smpl);
 
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
@@ -58,7 +88,7 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama
     return result;
 }
 
-std::string llama_sampling_print(const llama_sampling_params & params) {
+std::string llama_sampling_print(const gpt_sampling_params & params) {
     char result[1024];
 
     snprintf(result, sizeof(result),
@@ -72,7 +102,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
     return std::string(result);
 }
 
-std::string llama_sampling_order_print(const llama_sampling_params & params) {
+std::string llama_sampling_order_print(const gpt_sampling_params & params) {
     std::string result = "CFG -> Penalties ";
     if (params.mirostat == 0) {
         for (auto sampler_type : params.samplers_sequence) {
@@ -176,7 +206,7 @@ static void sampler_queue(
                           size_t min_keep) {
     llama_sampling * smpl = ctx_sampling->smpl;
 
-    const llama_sampling_params & params = ctx_sampling->params;
+    const gpt_sampling_params & params = ctx_sampling->params;
 
     const float temp = params.temp;
     const float dynatemp_range = params.dynatemp_range;
@@ -217,7 +247,7 @@ static llama_token llama_sampling_sample_impl(
                   bool is_resampling) {
     llama_sampling * smpl = ctx_sampling->smpl;
 
-    const llama_sampling_params & params = ctx_sampling->params;
+    const gpt_sampling_params & params = ctx_sampling->params;
 
     const float temp = params.temp;
     const int mirostat = params.mirostat;
@@ -308,7 +338,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
                   std::vector<float> * original_logits) {
     llama_sampling * smpl = ctx_sampling->smpl;
 
-    const llama_sampling_params & params = ctx_sampling->params;
+    const gpt_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
 
@@ -332,13 +362,17 @@ static llama_token_data_array llama_sampling_prepare_impl(
     }
 
     // apply params.logit_bias map
-    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-        logits[it->first] += it->second;
+    for (const auto & logit_bias : params.logit_bias) {
+        logits[logit_bias.token] += logit_bias.bias;
+    }
+
+    if (params.ignore_eos) {
+        logits[llama_token_eos(llama_get_model(ctx_main))] = -INFINITY;
     }
 
     if (ctx_cfg) {
         float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
-        llama_sampling_apply_guidance(smpl, logits, logits_guidance, params.cfg_scale);
+        llama_sampling_cfg(smpl, logits, logits_guidance, params.cfg_scale);
    }
 
     cur.resize(n_vocab);
@@ -350,7 +384,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
 
     // apply penalties
-    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+    const auto & penalty_tokens = prev;
     const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
     if (penalty_tokens_used_size) {
         const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
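
For code in common/, the practical effect is that the whole sampling configuration lives in one gpt_sampling_params struct, which llama_sampling_init() copies into the new llama_sampling_params and applies through the setter calls shown above. A rough usage sketch, assuming an already loaded model (the values are arbitrary examples, not defaults from the commit):

// Rough usage sketch of the new common-level API; assumes "model" is a loaded llama_model.
#include "common.h"
#include "sampling.h"

#include <cstdio>

static llama_sampling_context * make_sampling_ctx(const llama_model * model) {
    gpt_sampling_params sparams;

    sparams.top_k      = 40;
    sparams.top_p      = 0.95f;
    sparams.temp       = 0.8f;
    sparams.ignore_eos = true;                     // replaces the old -INFINITY EOS logit bias
    sparams.logit_bias.push_back({ 15043, 1.0f }); // bias an arbitrary token id

    fprintf(stderr, "%s\n", llama_sampling_print(sparams).c_str());

    // copies the values into llama_sampling_params and applies
    // grammar / CFG / logit biases via the setters shown above
    return llama_sampling_init(sparams, model);
}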

common/sampling.h

Lines changed: 9 additions & 15 deletions
@@ -2,9 +2,7 @@
 
 #include "llama.h"
 
-#include <random>
 #include <string>
-#include <unordered_map>
 #include <vector>
 
 // sampler types
@@ -18,7 +16,8 @@ enum class llama_sampler_type : char {
 };
 
 // sampling parameters
-typedef struct llama_sampling_params {
+typedef struct gpt_sampling_params {
+    uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
     int32_t n_prev = 64; // number of previous tokens to remember
     int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
     int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
@@ -38,7 +37,7 @@ typedef struct llama_sampling_params {
     float mirostat_tau = 5.00f; // target entropy
     float mirostat_eta = 0.10f; // learning rate
     bool penalize_nl = false; // consider newlines as a repeatable token
-    uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
+    bool ignore_eos = false;
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
@@ -56,17 +55,14 @@ typedef struct llama_sampling_params {
     std::string cfg_negative_prompt; // string to help guidance
     float cfg_scale = 1.f; // how strong is guidance
 
-    std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
-
-    std::vector<llama_token> penalty_prompt_tokens;
-    bool use_penalty_prompt_tokens = false;
-} llama_sampling_params;
+    std::vector<llama_logit_bias> logit_bias; // logit biases to apply
+} gpt_sampling_params;
 
 // general sampler context
 // TODO: move to llama.h
 struct llama_sampling_context {
     // parameters that will be used for sampling
-    llama_sampling_params params;
+    gpt_sampling_params params;
 
     // mirostat sampler state
     float mirostat_mu;
@@ -80,10 +76,8 @@ struct llama_sampling_context {
     size_t n_valid; // Number of correct top tokens with correct probabilities.
 };
 
-#include "common.h"
-
 // Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model);
+struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_params & params, const struct llama_model * model);
 
 void llama_sampling_free(struct llama_sampling_context * ctx);
 
@@ -102,10 +96,10 @@ llama_token llama_sampling_last(llama_sampling_context * ctx);
 std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
 
 // Print sampling parameters into a string
-std::string llama_sampling_print(const llama_sampling_params & params);
+std::string llama_sampling_print(const gpt_sampling_params & params);
 
 // Print sampling order into a string
-std::string llama_sampling_order_print(const llama_sampling_params & params);
+std::string llama_sampling_order_print(const gpt_sampling_params & params);
 
 std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);

examples/batched.swift/Sources/main.swift

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ defer {
     llama_free(context)
 }
 
-let smpl = llama_sampling_init(model, nil, nil)
+let smpl = llama_sampling_init(model, llama_sampling_default_params())
 guard smpl != nil else {
     print("Failed to initialize sampling")
     exit(1)

examples/batched/batched.cpp

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ int main(int argc, char ** argv) {
     ctx_params.n_batch = std::max(n_predict, n_parallel);
 
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
-    llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);
+    llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
 
     if (ctx == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);

examples/gritlm/gritlm.cpp

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ int main(int argc, char * argv[]) {
     // create generation context
     llama_context * ctx = llama_new_context_with_model(model, cparams);
 
-    llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);
+    llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
 
     // ### Embedding/Representation ###
     // samples taken from: https://github.com/ContextualAI/gritlm#basic

examples/infill/infill.cpp

Lines changed: 2 additions & 1 deletion
@@ -103,14 +103,15 @@ static void sigint_handler(int signo) {
 
 int main(int argc, char ** argv) {
     gpt_params params;
-    llama_sampling_params & sparams = params.sparams;
     g_params = &params;
 
     if (!gpt_params_parse(argc, argv, params)) {
         gpt_params_print_usage(argc, argv, params);
         return 1;
     }
 
+    auto & sparams = params.sparams;
+
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("infill", "log"));
     LOG_TEE("Log start\n");

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ actor LlamaContext {
         self.tokens_list = []
         self.batch = llama_batch_init(512, 0, 1)
         self.temporary_invalid_cchars = []
-        self.sampling = llama_sampling_init(context, nil, nil);
+        self.sampling = llama_sampling_init(context, llama_sampling_default_params())
     }
 
     deinit {

examples/main/main.cpp

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;
 
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("main", "log"));
