@@ -31,44 +31,41 @@ std::string gpt_sampling_params::print_samplers() const {
31
31
32
32
return result;
33
33
}
34
+
34
35
struct llama_sampling * llama_sampling_init (const struct llama_model * model, const struct gpt_sampling_params & params) {
35
- struct llama_sampling * result = nullptr ;
36
-
37
- {
38
- auto lparams = llama_sampling_default_params ();
39
-
40
- lparams.seed = params.seed ;
41
- lparams.n_prev = params.n_prev ;
42
- lparams.n_probs = params.n_probs ;
43
- lparams.min_keep = params.min_keep ;
44
- lparams.top_k = params.top_k ;
45
- lparams.top_p = params.top_p ;
46
- lparams.min_p = params.min_p ;
47
- lparams.tfs_z = params.tfs_z ;
48
- lparams.typ_p = params.typ_p ;
49
- lparams.temp = params.temp ;
50
- lparams.dynatemp_range = params.dynatemp_range ;
51
- lparams.dynatemp_exponent = params.dynatemp_exponent ;
52
- lparams.penalty_last_n = params.penalty_last_n ;
53
- lparams.penalty_repeat = params.penalty_repeat ;
54
- lparams.penalty_freq = params.penalty_freq ;
55
- lparams.penalty_present = params.penalty_present ;
56
- lparams.mirostat = params.mirostat ;
57
- lparams.mirostat_tau = params.mirostat_tau ;
58
- lparams.mirostat_eta = params.mirostat_eta ;
59
- lparams.penalize_nl = params.penalize_nl ;
60
- lparams.ignore_eos = params.ignore_eos ;
61
-
62
- lparams.n_samplers = params.samplers .size ();
63
- for (int i = 0 ; i < lparams.n_samplers ; i++) {
64
- lparams.samplers [i] = params.samplers [i];
65
- }
36
+ llama_sampling_params lparams = llama_sampling_default_params ();
37
+
38
+ lparams.seed = params.seed ;
39
+ lparams.n_prev = params.n_prev ;
40
+ lparams.n_probs = params.n_probs ;
41
+ lparams.min_keep = params.min_keep ;
42
+ lparams.top_k = params.top_k ;
43
+ lparams.top_p = params.top_p ;
44
+ lparams.min_p = params.min_p ;
45
+ lparams.tfs_z = params.tfs_z ;
46
+ lparams.typ_p = params.typ_p ;
47
+ lparams.temp = params.temp ;
48
+ lparams.dynatemp_range = params.dynatemp_range ;
49
+ lparams.dynatemp_exponent = params.dynatemp_exponent ;
50
+ lparams.penalty_last_n = params.penalty_last_n ;
51
+ lparams.penalty_repeat = params.penalty_repeat ;
52
+ lparams.penalty_freq = params.penalty_freq ;
53
+ lparams.penalty_present = params.penalty_present ;
54
+ lparams.mirostat = params.mirostat ;
55
+ lparams.mirostat_tau = params.mirostat_tau ;
56
+ lparams.mirostat_eta = params.mirostat_eta ;
57
+ lparams.penalize_nl = params.penalize_nl ;
58
+ lparams.ignore_eos = params.ignore_eos ;
59
+
60
+ lparams.n_samplers = params.samplers .size ();
61
+ for (int i = 0 ; i < lparams.n_samplers ; i++) {
62
+ lparams.samplers [i] = params.samplers [i];
63
+ }
66
64
67
- result = llama_sampling_init (model, lparams);
65
+ struct llama_sampling * result = llama_sampling_init (model, lparams);
68
66
69
- llama_sampling_set_grammar (result, params.grammar .c_str (), " root" );
70
- llama_sampling_set_logit_bias (result, params.logit_bias .size (), params.logit_bias .data ());
71
- }
67
+ llama_sampling_set_grammar (result, params.grammar .c_str (), " root" );
68
+ llama_sampling_set_logit_bias (result, params.logit_bias .size (), params.logit_bias .data ());
72
69
73
70
return result;
74
71
}
@@ -81,6 +78,35 @@ void llama_sampling_cp(llama_sampling * src, llama_sampling * dst) {
81
78
dst = llama_sampling_cp (src);
82
79
}
83
80
81
+ llama_token llama_sampling_sample (
82
+ struct llama_sampling * smpl,
83
+ struct llama_context * ctx,
84
+ int idx) {
85
+ llama_sampling_set_logits (smpl, llama_get_logits_ith (ctx, idx));
86
+
87
+ // first, sample the token without any grammar constraints
88
+ const llama_token id = llama_sampling_sample (smpl, nullptr );
89
+
90
+ // create an array with a single token data element for the sampled id
91
+ llama_token_data single_token_data = { id, 1 .0f , 0 .0f };
92
+ llama_token_data_array single_token_data_array = { &single_token_data, 1 , false };
93
+
94
+ llama_sampling_grammar (smpl, &single_token_data_array);
95
+
96
+ // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
97
+ const bool is_valid = single_token_data_array.data [0 ].logit != -INFINITY;
98
+ if (is_valid) {
99
+ return id;
100
+ }
101
+
102
+ // if the token is not valid, sample again, after applying the grammar constraints
103
+ llama_sampling_set_logits (smpl, llama_get_logits_ith (ctx, idx));
104
+
105
+ llama_sampling_grammar (smpl, nullptr );
106
+
107
+ return llama_sampling_sample (smpl, nullptr );
108
+ }
109
+
84
110
std::string llama_sampling_prev_str (llama_sampling * smpl, llama_context * ctx_main, int n) {
85
111
n = std::min (n, llama_sampling_n_prev (smpl));
86
112
@@ -152,27 +178,27 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
152
178
{ " temp" , LLAMA_SAMPLER_TYPE_TEMPERATURE },
153
179
};
154
180
155
- std::vector<llama_sampler_type> sampler_types ;
156
- sampler_types .reserve (names.size ());
181
+ std::vector<llama_sampler_type> samplers ;
182
+ samplers .reserve (names.size ());
157
183
158
184
for (const auto & name : names) {
159
- auto sampler_item = sampler_canonical_name_map.find (name);
160
- if (sampler_item != sampler_canonical_name_map.end ()) {
161
- sampler_types .push_back (sampler_item ->second );
185
+ auto sampler = sampler_canonical_name_map.find (name);
186
+ if (sampler != sampler_canonical_name_map.end ()) {
187
+ samplers .push_back (sampler ->second );
162
188
} else {
163
189
if (allow_alt_names) {
164
- sampler_item = sampler_alt_name_map.find (name);
165
- if (sampler_item != sampler_alt_name_map.end ()) {
166
- sampler_types .push_back (sampler_item ->second );
190
+ sampler = sampler_alt_name_map.find (name);
191
+ if (sampler != sampler_alt_name_map.end ()) {
192
+ samplers .push_back (sampler ->second );
167
193
}
168
194
}
169
195
}
170
196
}
171
197
172
- return sampler_types ;
198
+ return samplers ;
173
199
}
174
200
175
- std::vector<llama_sampler_type> llama_sampling_types_from_chars (const std::string & names_string ) {
201
+ std::vector<llama_sampler_type> llama_sampling_types_from_chars (const std::string & chars ) {
176
202
std::unordered_map<char , llama_sampler_type> sampler_name_map {
177
203
{ llama_sampling_type_to_chr (LLAMA_SAMPLER_TYPE_TOP_K), LLAMA_SAMPLER_TYPE_TOP_K },
178
204
{ llama_sampling_type_to_chr (LLAMA_SAMPLER_TYPE_TFS_Z), LLAMA_SAMPLER_TYPE_TFS_Z },
@@ -182,42 +208,15 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
182
208
{ llama_sampling_type_to_chr (LLAMA_SAMPLER_TYPE_TEMPERATURE), LLAMA_SAMPLER_TYPE_TEMPERATURE }
183
209
};
184
210
185
- std::vector<llama_sampler_type> sampler_types;
186
- sampler_types.reserve (names_string.size ());
187
- for (const auto & c : names_string) {
188
- const auto sampler_item = sampler_name_map.find (c);
189
- if (sampler_item != sampler_name_map.end ()) {
190
- sampler_types.push_back (sampler_item->second );
191
- }
192
- }
193
- return sampler_types;
194
- }
211
+ std::vector<llama_sampler_type> samplers;
212
+ samplers.reserve (chars.size ());
195
213
196
- llama_token llama_sampling_sample (
197
- struct llama_sampling * smpl,
198
- struct llama_context * ctx,
199
- int idx) {
200
- llama_sampling_set_logits (smpl, llama_get_logits_ith (ctx, idx));
201
-
202
- // first, sample the token without any grammar constraints
203
- auto id = llama_sampling_sample (smpl, nullptr );
204
-
205
- // create an array with a single token data element for the sampled id
206
- llama_token_data single_token_data = {id, 1 .0f , 0 .0f };
207
- llama_token_data_array single_token_data_array = { &single_token_data, 1 , false };
208
-
209
- llama_sampling_grammar (smpl, &single_token_data_array);
210
-
211
- // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
212
- const bool is_valid = single_token_data_array.data [0 ].logit != -INFINITY;
213
- if (is_valid) {
214
- return id;
214
+ for (const auto & c : chars) {
215
+ const auto sampler = sampler_name_map.find (c);
216
+ if (sampler != sampler_name_map.end ()) {
217
+ samplers.push_back (sampler->second );
218
+ }
215
219
}
216
220
217
- // if the token is not valid, sample again, after applying the grammar constraints
218
- llama_sampling_set_logits (smpl, llama_get_logits_ith (ctx, idx));
219
-
220
- llama_sampling_grammar (smpl, nullptr );
221
-
222
- return llama_sampling_sample (smpl, nullptr );
221
+ return samplers;
223
222
}
0 commit comments