Skip to content

Commit 23f0802

Browse files
committed
sampling : convert mirostat samplers to constraints
ggml-ci
1 parent 697a20f commit 23f0802

File tree

5 files changed

+305
-215
lines changed

5 files changed

+305
-215
lines changed

common/sampling.cpp

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
4747

4848
lparams.seed = params.seed;
4949
lparams.n_prev = params.n_prev;
50-
lparams.mirostat = params.mirostat;
51-
lparams.mirostat_tau = params.mirostat_tau;
52-
lparams.mirostat_eta = params.mirostat_eta;
5350

5451
auto * result = new gpt_sampler {
5552
/* .params = */ params,
@@ -69,29 +66,39 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
6966
/* .smpl = */ llama_sampler_init(model, lparams)
7067
};
7168

72-
for (const auto & cnstr : params.constraints) {
73-
switch (cnstr) {
74-
case GPT_CONSTRAINT_TYPE_TOP_K:
75-
llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep));
76-
break;
77-
case GPT_CONSTRAINT_TYPE_TOP_P:
78-
llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep));
79-
break;
80-
case GPT_CONSTRAINT_TYPE_MIN_P:
81-
llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep));
82-
break;
83-
case GPT_CONSTRAINT_TYPE_TFS_Z:
84-
llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
85-
break;
86-
case GPT_CONSTRAINT_TYPE_TYPICAL_P:
87-
llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep));
88-
break;
89-
case GPT_CONSTRAINT_TYPE_TEMPERATURE:
90-
llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
91-
break;
92-
default:
93-
GGML_ASSERT(false && "unknown constraint type");
69+
if (params.mirostat == 0) {
70+
for (const auto & cnstr : params.constraints) {
71+
switch (cnstr) {
72+
case GPT_CONSTRAINT_TYPE_TOP_K:
73+
llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep));
74+
break;
75+
case GPT_CONSTRAINT_TYPE_TOP_P:
76+
llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep));
77+
break;
78+
case GPT_CONSTRAINT_TYPE_MIN_P:
79+
llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep));
80+
break;
81+
case GPT_CONSTRAINT_TYPE_TFS_Z:
82+
llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
83+
break;
84+
case GPT_CONSTRAINT_TYPE_TYPICAL_P:
85+
llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep));
86+
break;
87+
case GPT_CONSTRAINT_TYPE_TEMPERATURE:
88+
llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
89+
break;
90+
default:
91+
GGML_ASSERT(false && "unknown constraint type");
92+
}
9493
}
94+
} else if (params.mirostat == 1) {
95+
llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp));
96+
llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
97+
} else if (params.mirostat == 2) {
98+
llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp));
99+
llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
100+
} else {
101+
GGML_ASSERT(false && "unknown mirostat version");
95102
}
96103

97104
return result;
@@ -153,7 +160,6 @@ static llama_token gpt_sampler_sample(
153160
struct llama_sampler * smpl,
154161
struct llama_token_data_array * cur_p,
155162
float temp,
156-
int mirostat,
157163
int n_probs) {
158164
llama_token res = 0;
159165

@@ -167,24 +173,20 @@ static llama_token gpt_sampler_sample(
167173
// apply all sampling constraints and then sample
168174
llama_sampler_apply(smpl, cur_p);
169175

170-
if (mirostat != 0) {
171-
res = llama_sampler_sample_mirostat(smpl, cur_p);
172-
} else {
173-
res = llama_sampler_sample_dist(smpl, cur_p);
176+
res = llama_sampler_sample_dist(smpl, cur_p);
174177

175-
//{
176-
// const int n_top = 10;
177-
// LOG("top %d candidates:\n", n_top);
178+
//{
179+
// const int n_top = 10;
180+
// LOG("top %d candidates:\n", n_top);
178181

179-
// for (int i = 0; i < n_top; i++) {
180-
// const llama_token id = cur_p.data[i].id;
181-
// (void)id; // To avoid a warning that id is unused when logging is disabled.
182-
// LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
183-
// }
184-
//}
182+
// for (int i = 0; i < n_top; i++) {
183+
// const llama_token id = cur_p.data[i].id;
184+
// (void)id; // To avoid a warning that id is unused when logging is disabled.
185+
// LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
186+
// }
187+
//}
185188

186-
//LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str());
187-
}
189+
//LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str());
188190
}
189191

190192
return res;
@@ -208,7 +210,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
208210
llama_constraint_apply(pnlt, cur_p);
209211

210212
// first, sample the token without any grammar constraints
211-
const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.mirostat, params.n_probs);
213+
const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.n_probs);
212214

213215
// check if it the sampled token fits the grammar
214216
{
@@ -231,7 +233,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
231233
llama_constraint_apply(pnlt, cur_p);
232234
llama_constraint_apply(grmr, cur_p);
233235

234-
return gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs);
236+
return gpt_sampler_sample(smpl, cur_p, params.temp, params.n_probs);
235237
}
236238

237239
void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) {

include/llama.h

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -367,16 +367,18 @@ extern "C" {
367367
float bias;
368368
} llama_logit_bias;
369369

370+
enum llama_sampler_type {
371+
LLAMA_SAMPLER_TYPE_GREEDY = 0,
372+
LLAMA_SAMPLER_TYPE_DIST = 1,
373+
};
374+
370375
typedef struct llama_sampler_params {
371376
uint32_t seed; // the seed used to initialize the rng of the sampler
372377

373378
int32_t n_prev; // size of ring buffer to keep previous accepted tokens (needed for llama_sampler_prev_ API)
374379

375-
int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
376-
float mirostat_tau; // target entropy
377-
float mirostat_eta; // learning rate
378-
379-
// TODO: add type of sampler: greedy, dist, mirostat, etc.
380+
// TODO: will be used by the llama_decode_with_sampler() API in the future
381+
enum llama_sampler_type type;
380382
} llama_sampler_params;
381383

382384
// performance timing information
@@ -1003,17 +1005,18 @@ extern "C" {
10031005
//
10041006
// - Samplers
10051007
// The llama_sampler samples a token based on the candidate token probabilities. Before the actual sampling, the
1006-
// sampler can apply a sequence of constraints to the candidate tokens.
1008+
// sampler can apply a sequence of constraints in order to modify the probabilities of the candidates.
10071009
//
10081010
// The llama_sampler object contains the entire sampling information:
10091011
//
10101012
// - RNG state (seed and generator)
10111013
// - Custom set of constraints (see llama_sampler_add_constraint)
1012-
// - Sampling method (greedy, dist, mirostat)
1014+
// - Sampling method (greedy, dist)
10131015
// - Previous tokens
10141016
//
10151017
// In the future, it will be utilized offload the sampling to the backends (e.g. GPU).
10161018
//
1019+
// TODO: in the future, the entire API should be changed to accept llama_vocab, instead of llama_model
10171020

10181021
// constraints
10191022

@@ -1039,14 +1042,23 @@ extern "C" {
10391042
llama_constraint_context_t ctx;
10401043
};
10411044

1042-
LLAMA_API struct llama_constraint * llama_constraint_init_softmax (void);
1043-
LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep);
1044-
LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep);
1045-
LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep);
1046-
LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep);
1047-
LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep);
1048-
LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t);
1049-
LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent);
1045+
LLAMA_API struct llama_constraint * llama_constraint_init_softmax (void);
1046+
LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep);
1047+
LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep);
1048+
LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep);
1049+
LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep);
1050+
LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep);
1051+
LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t);
1052+
LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent);
1053+
1054+
LLAMA_API struct llama_constraint * llama_constraint_init_mirostat(
1055+
const struct llama_model * model,
1056+
float tau,
1057+
float eta);
1058+
1059+
LLAMA_API struct llama_constraint * llama_constraint_init_mirostat_v2(
1060+
float tau,
1061+
float eta);
10501062

10511063
LLAMA_API struct llama_constraint * llama_constraint_init_grammar(
10521064
const struct llama_model * model,
@@ -1093,9 +1105,8 @@ extern "C" {
10931105
LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
10941106
LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p);
10951107

1096-
LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * cur_p);
1097-
LLAMA_API llama_token llama_sampler_sample_greedy (struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs);
1098-
LLAMA_API llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * cur_p);
1108+
LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * cur_p);
1109+
LLAMA_API llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs);
10991110

11001111
/// @details Get the number of accepted tokens so far (max of n_prev)
11011112
LLAMA_API int llama_sampler_n_prev(const struct llama_sampler * smpl);

0 commit comments

Comments
 (0)