Commit 5d4c807

minor : clean-up + comments

ggml-ci

1 parent 6420268 commit 5d4c807

File tree

8 files changed: +128 −113 lines changed


common/common.cpp

Lines changed: 1 addition & 0 deletions
@@ -360,6 +360,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
             std::replace(arg.begin(), arg.end(), '_', '-');
         }
+
         bool invalid_param = false;
         if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) {
             throw std::invalid_argument("error: unknown argument: " + arg);
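
Note: the context lines above show the argument normalization that precedes the new blank line — long options have their underscores rewritten to dashes, so `--top_k` and `--top-k` parse identically. A standalone sketch of that normalization (not the project's code, just the same std::replace idiom):

#include <algorithm>
#include <cassert>
#include <string>

int main() {
    std::string arg = "--top_k";
    const std::string arg_prefix = "--";

    // only normalize long options, exactly as the context lines above do
    if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
        std::replace(arg.begin(), arg.end(), '_', '-');
    }

    assert(arg == "--top-k");
}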

common/sampling.cpp

Lines changed: 79 additions & 80 deletions
@@ -31,44 +31,41 @@ std::string gpt_sampling_params::print_samplers() const {

     return result;
 }
+
 struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params) {
-    struct llama_sampling * result = nullptr;
-
-    {
-        auto lparams = llama_sampling_default_params();
-
-        lparams.seed              = params.seed;
-        lparams.n_prev            = params.n_prev;
-        lparams.n_probs           = params.n_probs;
-        lparams.min_keep          = params.min_keep;
-        lparams.top_k             = params.top_k;
-        lparams.top_p             = params.top_p;
-        lparams.min_p             = params.min_p;
-        lparams.tfs_z             = params.tfs_z;
-        lparams.typ_p             = params.typ_p;
-        lparams.temp              = params.temp;
-        lparams.dynatemp_range    = params.dynatemp_range;
-        lparams.dynatemp_exponent = params.dynatemp_exponent;
-        lparams.penalty_last_n    = params.penalty_last_n;
-        lparams.penalty_repeat    = params.penalty_repeat;
-        lparams.penalty_freq      = params.penalty_freq;
-        lparams.penalty_present   = params.penalty_present;
-        lparams.mirostat          = params.mirostat;
-        lparams.mirostat_tau      = params.mirostat_tau;
-        lparams.mirostat_eta      = params.mirostat_eta;
-        lparams.penalize_nl       = params.penalize_nl;
-        lparams.ignore_eos        = params.ignore_eos;
-
-        lparams.n_samplers = params.samplers.size();
-        for (int i = 0; i < lparams.n_samplers; i++) {
-            lparams.samplers[i] = params.samplers[i];
-        }
+    llama_sampling_params lparams = llama_sampling_default_params();
+
+    lparams.seed              = params.seed;
+    lparams.n_prev            = params.n_prev;
+    lparams.n_probs           = params.n_probs;
+    lparams.min_keep          = params.min_keep;
+    lparams.top_k             = params.top_k;
+    lparams.top_p             = params.top_p;
+    lparams.min_p             = params.min_p;
+    lparams.tfs_z             = params.tfs_z;
+    lparams.typ_p             = params.typ_p;
+    lparams.temp              = params.temp;
+    lparams.dynatemp_range    = params.dynatemp_range;
+    lparams.dynatemp_exponent = params.dynatemp_exponent;
+    lparams.penalty_last_n    = params.penalty_last_n;
+    lparams.penalty_repeat    = params.penalty_repeat;
+    lparams.penalty_freq      = params.penalty_freq;
+    lparams.penalty_present   = params.penalty_present;
+    lparams.mirostat          = params.mirostat;
+    lparams.mirostat_tau      = params.mirostat_tau;
+    lparams.mirostat_eta      = params.mirostat_eta;
+    lparams.penalize_nl       = params.penalize_nl;
+    lparams.ignore_eos        = params.ignore_eos;
+
+    lparams.n_samplers = params.samplers.size();
+    for (int i = 0; i < lparams.n_samplers; i++) {
+        lparams.samplers[i] = params.samplers[i];
+    }

-        result = llama_sampling_init(model, lparams);
+    struct llama_sampling * result = llama_sampling_init(model, lparams);

-        llama_sampling_set_grammar   (result, params.grammar.c_str(), "root");
-        llama_sampling_set_logit_bias(result, params.logit_bias.size(), params.logit_bias.data());
-    }
+    llama_sampling_set_grammar   (result, params.grammar.c_str(), "root");
+    llama_sampling_set_logit_bias(result, params.logit_bias.size(), params.logit_bias.data());

     return result;
 }
@@ -81,6 +78,35 @@ void llama_sampling_cp(llama_sampling * src, llama_sampling * dst) {
     dst = llama_sampling_cp(src);
 }

+llama_token llama_sampling_sample(
+        struct llama_sampling * smpl,
+        struct llama_context * ctx,
+        int idx) {
+    llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));
+
+    // first, sample the token without any grammar constraints
+    const llama_token id = llama_sampling_sample(smpl, nullptr);
+
+    // create an array with a single token data element for the sampled id
+    llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
+    llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
+
+    llama_sampling_grammar(smpl, &single_token_data_array);
+
+    // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
+    const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+    if (is_valid) {
+        return id;
+    }
+
+    // if the token is not valid, sample again, after applying the grammar constraints
+    llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));
+
+    llama_sampling_grammar(smpl, nullptr);
+
+    return llama_sampling_sample(smpl, nullptr);
+}
+
 std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_main, int n) {
     n = std::min(n, llama_sampling_n_prev(smpl));

@@ -152,27 +178,27 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
         { "temp", LLAMA_SAMPLER_TYPE_TEMPERATURE },
     };

-    std::vector<llama_sampler_type> sampler_types;
-    sampler_types.reserve(names.size());
+    std::vector<llama_sampler_type> samplers;
+    samplers.reserve(names.size());

     for (const auto & name : names) {
-        auto sampler_item = sampler_canonical_name_map.find(name);
-        if (sampler_item != sampler_canonical_name_map.end()) {
-            sampler_types.push_back(sampler_item->second);
+        auto sampler = sampler_canonical_name_map.find(name);
+        if (sampler != sampler_canonical_name_map.end()) {
+            samplers.push_back(sampler->second);
         } else {
             if (allow_alt_names) {
-                sampler_item = sampler_alt_name_map.find(name);
-                if (sampler_item != sampler_alt_name_map.end()) {
-                    sampler_types.push_back(sampler_item->second);
+                sampler = sampler_alt_name_map.find(name);
+                if (sampler != sampler_alt_name_map.end()) {
+                    samplers.push_back(sampler->second);
                 }
             }
         }
     }

-    return sampler_types;
+    return samplers;
 }

-std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) {
+std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & chars) {
     std::unordered_map<char, llama_sampler_type> sampler_name_map {
         { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_K), LLAMA_SAMPLER_TYPE_TOP_K },
         { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TFS_Z), LLAMA_SAMPLER_TYPE_TFS_Z },
@@ -182,42 +208,15 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
         { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TEMPERATURE), LLAMA_SAMPLER_TYPE_TEMPERATURE }
     };

-    std::vector<llama_sampler_type> sampler_types;
-    sampler_types.reserve(names_string.size());
-    for (const auto & c : names_string) {
-        const auto sampler_item = sampler_name_map.find(c);
-        if (sampler_item != sampler_name_map.end()) {
-            sampler_types.push_back(sampler_item->second);
-        }
-    }
-    return sampler_types;
-}
+    std::vector<llama_sampler_type> samplers;
+    samplers.reserve(chars.size());

-llama_token llama_sampling_sample(
-        struct llama_sampling * smpl,
-        struct llama_context * ctx,
-        int idx) {
-    llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));
-
-    // first, sample the token without any grammar constraints
-    auto id = llama_sampling_sample(smpl, nullptr);
-
-    // create an array with a single token data element for the sampled id
-    llama_token_data single_token_data = {id, 1.0f, 0.0f};
-    llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
-
-    llama_sampling_grammar(smpl, &single_token_data_array);
-
-    // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
-    const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
-    if (is_valid) {
-        return id;
+    for (const auto & c : chars) {
+        const auto sampler = sampler_name_map.find(c);
+        if (sampler != sampler_name_map.end()) {
+            samplers.push_back(sampler->second);
+        }
     }

-    // if the token is not valid, sample again, after applying the grammar constraints
-    llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));
-
-    llama_sampling_grammar(smpl, nullptr);
-
-    return llama_sampling_sample(smpl, nullptr);
+    return samplers;
 }
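
The relocated llama_sampling_sample is the single entry point callers now use. A minimal sketch of a generation loop built on it, assuming a prepared model/context/batch with caller-side counters (n_past, n_decoded, n_predict), and assuming llama_sampling_accept (to update penalty and grammar state) and llama_token_is_eog behave as elsewhere in this codebase:

while (n_decoded < n_predict) {
    // fast path first: sample unconstrained; fall back to the grammar-filtered
    // distribution only if the sampled token violates the grammar
    const llama_token id = llama_sampling_sample(smpl, ctx, batch.n_tokens - 1);

    if (llama_token_is_eog(model, id)) {
        break;
    }

    // assumption: accepting the token keeps repetition penalties and the
    // grammar parser state in sync for the next step
    llama_sampling_accept(smpl, id, /*apply_grammar=*/true);

    // feed the sampled token back for the next decode step
    llama_batch_clear(batch);
    llama_batch_add(batch, id, n_past++, { 0 }, true);
    llama_decode(ctx, batch);

    n_decoded++;
}

The two-pass structure pays for a full-vocabulary grammar pass only on rejection; when the unconstrained sample already satisfies the grammar, the single-token check is all that runs.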

common/sampling.h

Lines changed: 16 additions & 7 deletions
@@ -39,7 +39,7 @@ typedef struct gpt_sampling_params {
         LLAMA_SAMPLER_TYPE_TEMPERATURE
     };

-    std::string grammar; // optional BNF-like grammar to constrain sampling
+    std::string grammar;      // optional BNF-like grammar to constrain sampling

     std::vector<llama_logit_bias> logit_bias; // logit biases to apply

@@ -55,16 +55,25 @@ struct llama_sampling * llama_sampling_init(const struct llama_model * model, co

 void llama_sampling_cp(llama_sampling * src, llama_sampling * dst);

+// common sampling implementation:
+//
+// - set logits
+// - apply the configured sampling constraints
+// - check if the token fits the grammar (if any)
+// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
+//
+llama_token llama_sampling_sample(
+        struct llama_sampling * smpl,
+        struct llama_context * ctx,
+        int idx);
+
+// helpers
+
 // get a string representation of the last accepted tokens
 std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx, int n);

 char        llama_sampling_type_to_chr(enum llama_sampler_type sampler_type);
 std::string llama_sampling_type_to_str(enum llama_sampler_type sampler_type);

 std::vector<enum llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
-std::vector<enum llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string);
-
-llama_token llama_sampling_sample(
-        struct llama_sampling * smpl,
-        struct llama_context * ctx,
-        int idx);
+std::vector<enum llama_sampler_type> llama_sampling_types_from_chars(const std::string & chars);
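
With the renames, the two helpers read as a pair: from_names parses full sampler names (optionally accepting alternative spellings), while from_chars parses the compact one-character codes produced by llama_sampling_type_to_chr. A hedged usage sketch — the exact character codes and name strings below are assumptions inferred from the map entries in sampling.cpp (e.g. "temp" for temperature):

// e.g. a CLI flag that specifies the sampler chain one character per stage
const std::vector<enum llama_sampler_type> chain =
    llama_sampling_types_from_chars("kpt"); // assumed: top-k, top-p, temperature

// or by name, tolerating alternative spellings
const std::vector<enum llama_sampler_type> chain2 =
    llama_sampling_types_from_names({ "temp" }, /*allow_alt_names=*/true);

Unrecognized characters or names are silently skipped by both helpers, so a typo shortens the chain rather than failing.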

examples/gritlm/gritlm.cpp

Lines changed: 8 additions & 4 deletions
@@ -109,14 +109,18 @@ static std::string generate(llama_context * ctx, llama_sampling * smpl, const st

     while (true) {
         llama_batch_clear(bat);
-        auto n_inputs = (int32_t)inputs.size();
-        for (int32_t i = 0; i < n_inputs; i++) {
-            llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
+        {
+            const int32_t n_inputs = inputs.size();
+
+            for (int32_t i = 0; i < n_inputs; i++) {
+                llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
+            }
         }
         inputs.clear();

         llama_decode(ctx, bat);
-        auto * logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
+
+        const auto * logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);

         llama_sampling_set_logits(smpl, logits);

examples/llama.android/llama/src/main/cpp/llama-android.cpp

Lines changed: 3 additions & 3 deletions
@@ -120,7 +120,8 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo
     LOGi("Using %d threads", n_threads);

     llama_context_params ctx_params = llama_context_default_params();
-    ctx_params.n_ctx = 2048;
+
+    ctx_params.n_ctx           = 2048;
     ctx_params.n_threads       = n_threads;
     ctx_params.n_threads_batch = n_threads;

@@ -393,8 +394,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
     if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
     if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");

-    auto n_vocab = llama_n_vocab(model);
-    auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
+    const auto * logits = llama_get_logits_ith(context, batch->n_tokens - 1);

     llama_sampling_set_logits(sampling, logits);

examples/server/server.cpp

Lines changed: 2 additions & 0 deletions
@@ -2356,6 +2356,8 @@ struct server_context {

             const auto * cur_p = llama_sampling_get_candidates(slot.smpl);

+            // TODO: this logic might have been broken during https://github.com/ggerganov/llama.cpp/pull/8643
+            //       fix if necessary
             for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
                 result.probs.push_back({
                     cur_p->data[i].id,
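
A defensive variant of the loop the TODO points at would clamp the iteration count to the candidates actually present — a sketch under the assumption that cur_p->size reflects the candidate count after sampling:

// assumption: the candidates array may hold fewer than n_probs entries
const size_t n_probs = std::min((size_t) slot.sparams.n_probs, cur_p->size);

for (size_t i = 0; i < n_probs; ++i) {
    // push cur_p->data[i] into result.probs as before
}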

include/llama.h

Lines changed: 1 addition & 1 deletion
@@ -439,8 +439,8 @@ extern "C" {
     // Helpers for getting default parameters
     LLAMA_API struct llama_model_params          llama_model_default_params(void);
     LLAMA_API struct llama_context_params        llama_context_default_params(void);
-    LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
     LLAMA_API struct llama_sampling_params       llama_sampling_default_params(void);
+    LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);

     // Initialize the llama + ggml backend
     // If numa is true, use NUMA optimizations

src/llama.cpp

Lines changed: 18 additions & 18 deletions
@@ -17399,24 +17399,6 @@ struct llama_context_params llama_context_default_params() {
     return result;
 }

-struct llama_model_quantize_params llama_model_quantize_default_params() {
-    struct llama_model_quantize_params result = {
-        /*.nthread                =*/ 0,
-        /*.ftype                  =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
-        /*.output_tensor_type     =*/ GGML_TYPE_COUNT,
-        /*.token_embedding_type   =*/ GGML_TYPE_COUNT,
-        /*.allow_requantize       =*/ false,
-        /*.quantize_output_tensor =*/ true,
-        /*.only_copy              =*/ false,
-        /*.pure                   =*/ false,
-        /*.keep_split             =*/ false,
-        /*.imatrix                =*/ nullptr,
-        /*.kv_overrides           =*/ nullptr,
-    };
-
-    return result;
-}
-
 struct llama_sampling_params llama_sampling_default_params() {
     struct llama_sampling_params result = {
         /*.seed =*/ LLAMA_DEFAULT_SEED,
@@ -17447,6 +17429,24 @@ struct llama_sampling_params llama_sampling_default_params() {
     return result;
 }

+struct llama_model_quantize_params llama_model_quantize_default_params() {
+    struct llama_model_quantize_params result = {
+        /*.nthread                =*/ 0,
+        /*.ftype                  =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
+        /*.output_tensor_type     =*/ GGML_TYPE_COUNT,
+        /*.token_embedding_type   =*/ GGML_TYPE_COUNT,
+        /*.allow_requantize       =*/ false,
+        /*.quantize_output_tensor =*/ true,
+        /*.only_copy              =*/ false,
+        /*.pure                   =*/ false,
+        /*.keep_split             =*/ false,
+        /*.imatrix                =*/ nullptr,
+        /*.kv_overrides           =*/ nullptr,
+    };
+
+    return result;
+}
+
 size_t llama_max_devices(void) {
 #if defined(GGML_USE_RPC)
     return GGML_RPC_MAX_SERVERS;