@@ -47,9 +47,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
 
     lparams.seed         = params.seed;
     lparams.n_prev       = params.n_prev;
-    lparams.mirostat     = params.mirostat;
-    lparams.mirostat_tau = params.mirostat_tau;
-    lparams.mirostat_eta = params.mirostat_eta;
 
     auto * result = new gpt_sampler {
         /* .params = */ params,
@@ -69,29 +66,39 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
         /* .smpl = */ llama_sampler_init(model, lparams)
     };
 
-    for (const auto & cnstr : params.constraints) {
-        switch (cnstr) {
-            case GPT_CONSTRAINT_TYPE_TOP_K:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k    (params.top_k, params.min_keep));
-                break;
-            case GPT_CONSTRAINT_TYPE_TOP_P:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p    (params.top_p, params.min_keep));
-                break;
-            case GPT_CONSTRAINT_TYPE_MIN_P:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p    (params.min_p, params.min_keep));
-                break;
-            case GPT_CONSTRAINT_TYPE_TFS_Z:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
-                break;
-            case GPT_CONSTRAINT_TYPE_TYPICAL_P:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical  (params.typ_p, params.min_keep));
-                break;
-            case GPT_CONSTRAINT_TYPE_TEMPERATURE:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
-                break;
-            default:
-                GGML_ASSERT(false && "unknown constraint type");
+    if (params.mirostat == 0) {
+        for (const auto & cnstr : params.constraints) {
+            switch (cnstr) {
+                case GPT_CONSTRAINT_TYPE_TOP_K:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k    (params.top_k, params.min_keep));
+                    break;
+                case GPT_CONSTRAINT_TYPE_TOP_P:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p    (params.top_p, params.min_keep));
+                    break;
+                case GPT_CONSTRAINT_TYPE_MIN_P:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p    (params.min_p, params.min_keep));
+                    break;
+                case GPT_CONSTRAINT_TYPE_TFS_Z:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
+                    break;
+                case GPT_CONSTRAINT_TYPE_TYPICAL_P:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical  (params.typ_p, params.min_keep));
+                    break;
+                case GPT_CONSTRAINT_TYPE_TEMPERATURE:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                    break;
+                default:
+                    GGML_ASSERT(false && "unknown constraint type");
+            }
         }
+    } else if (params.mirostat == 1) {
+        llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp));
+        llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
+    } else if (params.mirostat == 2) {
+        llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp));
+        llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
+    } else {
+        GGML_ASSERT(false && "unknown mirostat version");
     }
 
     return result;
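
Note (not part of this commit): with this change the mirostat mode is decided once, when the sampler is built, instead of on every sample call. A minimal configuration sketch, assuming the params struct is named gpt_sampler_params (the hunk header above is truncated) and that its fields match those used in the diff; the concrete values are illustrative only:

    // hedged sketch: type name assumed, field names taken from the diff above
    gpt_sampler_params sparams;
    sparams.mirostat     = 2;    // 0 = use params.constraints, 1 = mirostat v1, 2 = mirostat v2
    sparams.mirostat_tau = 5.0f; // illustrative target surprise
    sparams.mirostat_eta = 0.1f; // illustrative learning rate
    sparams.temp         = 0.8f; // init adds a temp constraint ahead of the mirostat constraint

    // gpt_sampler_init now bakes the mirostat choice into the constraint chain,
    // so the rest of the sampling code no longer has to look at params.mirostat
    struct gpt_sampler * gsmpl = gpt_sampler_init(model, sparams);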
@@ -153,7 +160,6 @@ static llama_token gpt_sampler_sample(
         struct llama_sampler * smpl,
         struct llama_token_data_array * cur_p,
         float temp,
-        int mirostat,
         int n_probs) {
     llama_token res = 0;
 
@@ -167,24 +173,20 @@ static llama_token gpt_sampler_sample(
         // apply all sampling constraints and then sample
         llama_sampler_apply(smpl, cur_p);
 
-        if (mirostat != 0) {
-            res = llama_sampler_sample_mirostat(smpl, cur_p);
-        } else {
-            res = llama_sampler_sample_dist(smpl, cur_p);
+        res = llama_sampler_sample_dist(smpl, cur_p);
 
-            // {
-            //     const int n_top = 10;
-            //     LOG("top %d candidates:\n", n_top);
+        // {
+        //     const int n_top = 10;
+        //     LOG("top %d candidates:\n", n_top);
 
-            //     for (int i = 0; i < n_top; i++) {
-            //         const llama_token id = cur_p.data[i].id;
-            //         (void)id; // To avoid a warning that id is unused when logging is disabled.
-            //         LOG("  - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
-            //     }
-            // }
+        //     for (int i = 0; i < n_top; i++) {
+        //         const llama_token id = cur_p.data[i].id;
+        //         (void)id; // To avoid a warning that id is unused when logging is disabled.
+        //         LOG("  - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
+        //     }
+        // }
 
-            // LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str());
-        }
+        // LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str());
     }
 
     return res;
@@ -208,7 +210,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     llama_constraint_apply(pnlt, cur_p);
 
     // first, sample the token without any grammar constraints
-    const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.mirostat, params.n_probs);
+    const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.n_probs);
 
     // check if it the sampled token fits the grammar
     {
@@ -231,7 +233,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     llama_constraint_apply(pnlt, cur_p);
     llama_constraint_apply(grmr, cur_p);
 
-    return gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs);
+    return gpt_sampler_sample(smpl, cur_p, params.temp, params.n_probs);
 }
 
 void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) {
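
Note (not part of this commit): because the mirostat constraint now sits in the chain that llama_sampler_apply runs, the static helper can always finish with llama_sampler_sample_dist, and its callers simply drop the params.mirostat argument. A before/after sketch of the call shape, with smpl, cur_p, and params as in the hunks above:

    // before: the helper received the mirostat mode and branched on it internally
    // llama_token id = gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs);

    // after: mirostat (if any) was configured in gpt_sampler_init, so the call shrinks to
    llama_token id = gpt_sampler_sample(smpl, cur_p, params.temp, params.n_probs);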