
Commit 97731bf

cont : simplify common/sampling
ggml-ci
1 parent 694c4b1 commit 97731bf

File tree

12 files changed: +68 -127 lines changed


common/sampling.cpp

Lines changed: 35 additions & 60 deletions
@@ -2,6 +2,35 @@
 
 #include "common.h"
 
+std::string gpt_sampling_params::print_all() const {
+    char result[1024];
+
+    snprintf(result, sizeof(result),
+            "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
+            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
+            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
+            penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
+            top_k, tfs_z, top_p, min_p, typ_p, temp,
+            mirostat, mirostat_eta, mirostat_tau);
+
+    return std::string(result);
+}
+
+std::string gpt_sampling_params::print_samplers() const {
+    std::string result = "CFG -> Penalties ";
+    if (mirostat == 0) {
+        for (const auto & sampler : samplers) {
+            const auto name = llama_sampling_type_to_str(sampler);
+            if (!name.empty()) {
+                result += "-> " + name + " ";
+            }
+        }
+    } else {
+        result += "-> mirostat ";
+    }
+
+    return result;
+}
+
 struct llama_sampling_context * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
@@ -52,10 +81,6 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
     delete ctx;
 }
 
-void llama_sampling_reset(llama_sampling_context * ctx) {
-    llama_sampling_reset(ctx->smpl);
-}
-
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
     if (dst->smpl) {
         llama_sampling_free(dst->smpl);
@@ -89,38 +114,8 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama
     return result;
 }
 
-std::string llama_sampling_print(const gpt_sampling_params & params) {
-    char result[1024];
-
-    snprintf(result, sizeof(result),
-            "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
-            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
-            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
-            params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
-            params.top_k, params.tfs_z, params.top_p, params.min_p, params.typ_p, params.temp,
-            params.mirostat, params.mirostat_eta, params.mirostat_tau);
-
-    return std::string(result);
-}
-
-std::string llama_sampling_order_print(const gpt_sampling_params & params) {
-    std::string result = "CFG -> Penalties ";
-    if (params.mirostat == 0) {
-        for (auto sampler_type : params.samplers) {
-            const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
-            if (!sampler_type_name.empty()) {
-                result += "-> " + sampler_type_name + " ";
-            }
-        }
-    } else {
-        result += "-> mirostat ";
-    }
-
-    return result;
-}
-
-char llama_sampling_type_to_chr(llama_sampler_type sampler_type) {
-    switch (sampler_type) {
+char llama_sampling_type_to_chr(llama_sampler_type sampler) {
+    switch (sampler) {
         case LLAMA_SAMPLER_TYPE_TOP_K: return 'k';
         case LLAMA_SAMPLER_TYPE_TFS_Z: return 'f';
         case LLAMA_SAMPLER_TYPE_TYPICAL_P: return 'y';
@@ -131,8 +126,8 @@ char llama_sampling_type_to_chr(llama_sampler_type sampler_type) {
     }
 }
 
-std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
-    switch (sampler_type) {
+std::string llama_sampling_type_to_str(llama_sampler_type sampler) {
+    switch (sampler) {
        case LLAMA_SAMPLER_TYPE_TOP_K: return "top_k";
        case LLAMA_SAMPLER_TYPE_TFS_Z: return "tfs_z";
        case LLAMA_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
@@ -210,35 +205,15 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
     return sampler_types;
 }
 
-void llama_sampling_prepare(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        int idx) {
-    llama_sampling_set_logits(ctx_sampling->smpl, llama_get_logits_ith(ctx_main, idx));
-}
-
-static llama_token llama_sampling_sample(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_token_data_array * cur_p) {
-    return llama_sampling_sample(ctx_sampling->smpl, cur_p);
-}
-
 llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
         int idx) {
-    llama_sampling_prepare(ctx_sampling, ctx_main, idx);
+    llama_sampling_set_logits(ctx_sampling->smpl, llama_get_logits_ith(ctx_main, idx));
 
     auto * cur_p = llama_sampling_get_candidates(ctx_sampling->smpl);
 
     llama_sampling_grammar(ctx_sampling->smpl, cur_p);
 
-    return llama_sampling_sample(ctx_sampling, cur_p);
-}
-
-void llama_sampling_accept(
-        struct llama_sampling_context * ctx_sampling,
-        llama_token id,
-        bool apply_grammar) {
-    llama_sampling_accept(ctx_sampling->smpl, id, apply_grammar);
+    return llama_sampling_sample(ctx_sampling->smpl, cur_p);
 }
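
For reference, a minimal sketch (not part of this commit) of how an example program might log the sampling configuration with the two new gpt_sampling_params member functions instead of the removed llama_sampling_print / llama_sampling_order_print free functions. It assumes the common headers from this repository and the LOG_TEE macro used by the examples; the helper name log_sampling_config is hypothetical.

// Hypothetical helper: logs the sampling setup using the new member functions.
#include "common.h"
#include "sampling.h"

static void log_sampling_config(const gpt_sampling_params & sparams) {
    // numeric parameters (penalties, top_k, top_p, mirostat, ...)
    LOG_TEE("sampling params: \n%s\n", sparams.print_all().c_str());
    // the configured sampler chain, e.g. "CFG -> Penalties -> top_k -> ..."
    LOG_TEE("sampling order: \n%s\n", sparams.print_samplers().c_str());
}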

common/sampling.h

Lines changed: 6 additions & 42 deletions
@@ -42,6 +42,12 @@ typedef struct gpt_sampling_params {
     std::string grammar; // optional BNF-like grammar to constrain sampling
 
     std::vector<llama_logit_bias> logit_bias; // logit biases to apply
+
+    // print the parameters into a string
+    std::string print_all() const;
+
+    // print the samplers into a string
+    std::string print_samplers() const;
 } gpt_sampling_params;
 
 // general sampler context
@@ -58,11 +64,6 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_model * m
 
 void llama_sampling_free(struct llama_sampling_context * ctx);
 
-// Reset the sampler context
-// - clear prev tokens
-// - reset grammar
-void llama_sampling_reset(llama_sampling_context * ctx);
-
 // Copy the sampler context
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
 
@@ -72,50 +73,13 @@ llama_token llama_sampling_last(llama_sampling_context * ctx);
 // Get a string representation of the last accepted tokens
 std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
 
-// Print sampling parameters into a string
-std::string llama_sampling_print(const gpt_sampling_params & params);
-
-// Print sampling order into a string
-std::string llama_sampling_order_print(const gpt_sampling_params & params);
-
 char llama_sampling_type_to_chr(llama_sampler_type sampler_type);
 std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);
 
 std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
 std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string);
 
-// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
-void llama_sampling_prepare(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        int idx);
-
-// this is a common sampling function used across the examples for convenience
-// it can serve as a starting point for implementing your own sampling function
-// Note: When using multiple sequences, it is the caller's responsibility to call
-//       llama_sampling_reset when a sequence ends
-//
-// required:
-//  - ctx_main:     context to use for sampling
-//  - ctx_sampling: sampling-specific context
-//
-// optional:
-//  - idx: sample from llama_get_logits_ith(ctx, idx)
-//
-// returns:
-//  - token:      sampled token
-//  - candidates: vector of candidate tokens
-//
-//llama_token llama_sampling_sample(
-//        struct llama_sampling_context * ctx_sampling,
-//        struct llama_token_data_array * cur_p);
-
 llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
         int idx = -1);
-
-void llama_sampling_accept(
-        struct llama_sampling_context * ctx_sampling,
-        llama_token id,
-        bool apply_grammar);

examples/infill/infill.cpp

Lines changed: 4 additions & 4 deletions
@@ -301,7 +301,7 @@ int main(int argc, char ** argv) {
             LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
         }
     }
-    LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
+    LOG_TEE("sampling: \n%s\n", sparams.print_all().c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
 
@@ -419,7 +419,7 @@ int main(int argc, char ** argv) {
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx);
 
-            llama_sampling_accept(ctx_sampling, id, true);
+            llama_sampling_accept(ctx_sampling->smpl, id, true);
 
             // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
 
@@ -440,7 +440,7 @@ int main(int argc, char ** argv) {
 
             // push the prompt in the sampling context in order to apply repetition penalties later
             // for the prompt, we don't apply grammar rules
-            llama_sampling_accept(ctx_sampling, embd_inp[n_consumed], false);
+            llama_sampling_accept(ctx_sampling->smpl, embd_inp[n_consumed], false);
 
             ++n_consumed;
             if ((int) embd.size() >= params.n_batch) {
@@ -611,7 +611,7 @@ int main(int argc, char ** argv) {
 
         if (n_past > 0) {
             if (is_interacting) {
-                llama_sampling_reset(ctx_sampling);
+                llama_sampling_reset(ctx_sampling->smpl);
             }
             is_interacting = false;
         }

examples/llava/llava-cli.cpp

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
                            struct llama_context * ctx_llama,
                            int * n_past) {
     const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama);
-    llama_sampling_accept(ctx_sampling, id, true);
+    llama_sampling_accept(ctx_sampling->smpl, id, true);
     static std::string ret;
     if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
         ret = "</s>";

examples/llava/minicpmv-cli.cpp

Lines changed: 1 addition & 1 deletion
@@ -167,7 +167,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
                            struct llama_context * ctx_llama,
                            int * n_past) {
     const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama);
-    llama_sampling_accept(ctx_sampling, id, true);
+    llama_sampling_accept(ctx_sampling->smpl, id, true);
     static std::string ret;
     if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
         ret = "</s>";

examples/lookahead/lookahead.cpp

Lines changed: 2 additions & 2 deletions
@@ -160,7 +160,7 @@ int main(int argc, char ** argv) {
     {
         id = llama_sampling_sample(ctx_sampling, ctx, 0);
 
-        llama_sampling_accept(ctx_sampling, id, true);
+        llama_sampling_accept(ctx_sampling->smpl, id, true);
 
         {
             const std::string token_str = llama_token_to_piece(ctx, id);
@@ -285,7 +285,7 @@ int main(int argc, char ** argv) {
             // sample the next token
             id = llama_sampling_sample(ctx_sampling, ctx, i_batch);
 
-            llama_sampling_accept(ctx_sampling, id, true);
+            llama_sampling_accept(ctx_sampling->smpl, id, true);
 
             // print
             {

examples/lookup/lookup.cpp

Lines changed: 1 addition & 1 deletion
@@ -130,7 +130,7 @@ int main(int argc, char ** argv){
         // sample from the target model
         llama_token id = llama_sampling_sample(ctx_sampling, ctx, i_dft);
 
-        llama_sampling_accept(ctx_sampling, id, true);
+        llama_sampling_accept(ctx_sampling->smpl, id, true);
 
         const std::string token_str = llama_token_to_piece(ctx, id);
 

examples/main/main.cpp

Lines changed: 5 additions & 5 deletions
@@ -426,8 +426,8 @@ int main(int argc, char ** argv) {
             }
         }
     }
-    LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
-    LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
+    LOG_TEE("sampling params: \n%s\n", sparams.print_all().c_str());
+    LOG_TEE("sampling order: \n%s\n", sparams.print_samplers().c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
 
     // group-attention state
@@ -652,7 +652,7 @@ int main(int argc, char ** argv) {
 
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx);
 
-            llama_sampling_accept(ctx_sampling, id, /* apply_grammar= */ true);
+            llama_sampling_accept(ctx_sampling->smpl, id, /* apply_grammar= */ true);
 
             // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
 
@@ -673,7 +673,7 @@ int main(int argc, char ** argv) {
 
             // push the prompt in the sampling context in order to apply repetition penalties later
             // for the prompt, we don't apply grammar rules
-            llama_sampling_accept(ctx_sampling, embd_inp[n_consumed], /* apply_grammar= */ false);
+            llama_sampling_accept(ctx_sampling->smpl, embd_inp[n_consumed], /* apply_grammar= */ false);
 
             ++n_consumed;
             if ((int) embd.size() >= params.n_batch) {
@@ -872,7 +872,7 @@ int main(int argc, char ** argv) {
 
         if (n_past > 0) {
             if (is_interacting) {
-                llama_sampling_reset(ctx_sampling);
+                llama_sampling_reset(ctx_sampling->smpl);
            }
             is_interacting = false;
         }

examples/parallel/parallel.cpp

Lines changed: 2 additions & 2 deletions
@@ -253,7 +253,7 @@ int main(int argc, char ** argv) {
             client.prompt = client.input + "\nAssistant:";
             client.response = "";
 
-            llama_sampling_reset(client.ctx_sampling);
+            llama_sampling_reset(client.ctx_sampling->smpl);
 
             // do not prepend BOS because we have a system prompt!
             std::vector<llama_token> tokens_prompt;
@@ -343,7 +343,7 @@ int main(int argc, char ** argv) {
 
            const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, client.i_batch - i);
 
-            llama_sampling_accept(client.ctx_sampling, id, true);
+            llama_sampling_accept(client.ctx_sampling->smpl, id, true);
 
            if (client.n_decoded == 1) {
                // start measuring generation time after the first token to make sure all concurrent clients

examples/server/server.cpp

Lines changed: 4 additions & 4 deletions
@@ -2098,7 +2098,7 @@ struct server_context {
                     GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                 }
 
-                llama_sampling_reset(slot.ctx_sampling);
+                llama_sampling_reset(slot.ctx_sampling->smpl);
 
                 if (!slot.params.cache_prompt) {
                     slot.n_past_se = 0;
@@ -2111,7 +2111,7 @@ struct server_context {
 
                     // push the prompt into the sampling context (do not apply grammar)
                     for (int i = 0; i < slot.n_past; ++i) {
-                        llama_sampling_accept(slot.ctx_sampling, slot.cache_tokens[i], false);
+                        llama_sampling_accept(slot.ctx_sampling->smpl, slot.cache_tokens[i], false);
                     }
                 }
             }
@@ -2164,7 +2164,7 @@ struct server_context {
                         slot.n_past_se = 0;
                         slot.ga_i = 0;
                         // TODO: is the system prompt ever in the sampling context?
-                        llama_sampling_reset(slot.ctx_sampling);
+                        llama_sampling_reset(slot.ctx_sampling->smpl);
                     }
 
                     // remove the non-common part from the cache
@@ -2343,7 +2343,7 @@ struct server_context {
             completion_token_output result;
             const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, slot.i_batch - i);
 
-            llama_sampling_accept(slot.ctx_sampling, id, true);
+            llama_sampling_accept(slot.ctx_sampling->smpl, id, true);
 
             slot.n_decoded += 1;
             if (slot.n_decoded == 1) {
