Commit 850e853

wip

1 parent 201a190 commit 850e853

33 files changed: +249 -221 lines changed

common/common.cpp

Lines changed: 1 addition & 1 deletion
@@ -2127,7 +2127,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
-        llama_reset_timings(lctx, nullptr, nullptr);
+        llama_reset_timings(lctx, nullptr);
     }

     return std::make_tuple(model, lctx);
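
Note: the same two-argument pattern recurs in the example programs below. Judging from the call sites in this commit, llama_reset_timings and llama_print_timings now take the context plus a single optional llama_sampling pointer, replacing the previous (ctx, smpl, grammar) triple. A minimal sketch of the new calling convention as inferred from these call sites (the helper name run_and_report is ours, not part of the commit):

#include "llama.h"

// Hypothetical helper illustrating the two-argument timing API used in this commit.
// Pass nullptr for smpl when no sampler is involved (e.g. embedding, imatrix).
static void run_and_report(llama_context * ctx, llama_sampling * smpl) {
    llama_reset_timings(ctx, smpl); // clear previous measurements

    // ... decode / sample ...

    llama_print_timings(ctx, smpl); // context timings, plus sampler timings if given
}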

common/sampling.cpp

Lines changed: 24 additions & 65 deletions
@@ -2,41 +2,20 @@

 #include <random>

-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl) {
-    struct llama_sampling_context * result = new llama_sampling_context();
-
-    result->params  = params;
-    result->smpl    = smpl;
-    result->grammar = nullptr;
+struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model) {
+    auto result = llama_sampling_init(params, llama_sampling_init(model, params.grammar.c_str(), "root"));

-    // if there is a grammar, parse it
-    if (!params.grammar.empty()) {
-        result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+    result->owned = true;

-        // will be empty (default) if there are parse errors
-        if (result->parsed_grammar.rules.empty()) {
-            fprintf(stderr, "%s: failed to parse grammar\n", __func__);
-            delete result;
-            return nullptr;
-        }
+    return result;
+}

-        // Ensure that there is a "root" node.
-        if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) {
-            fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
-            delete result;
-            return nullptr;
-        }
+struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl) {
+    struct llama_sampling_context * result = new llama_sampling_context();

-        std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
-
-        struct llama_grammar * grammar = llama_grammar_init(
-                grammar_rules.data(),
-                grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
-        if (grammar == nullptr) {
-            throw std::runtime_error("Failed to initialize llama_grammar");
-        }
-        result->grammar = grammar;
-    }
+    result->params = params;
+    result->owned  = false;
+    result->smpl   = smpl;

     result->prev.resize(params.n_prev);

@@ -48,46 +27,27 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 }

 void llama_sampling_free(struct llama_sampling_context * ctx) {
-    if (ctx->grammar != NULL) {
-        llama_grammar_free(ctx->grammar);
+    if (ctx->owned) {
+        llama_sampling_free(ctx->smpl);
     }

     delete ctx;
 }

 void llama_sampling_reset(llama_sampling_context * ctx) {
-    if (ctx->grammar != NULL) {
-        llama_grammar_free(ctx->grammar);
-        ctx->grammar = NULL;
-    }
-
-    if (!ctx->parsed_grammar.rules.empty()) {
-        std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
-
-        struct llama_grammar * grammar = llama_grammar_init(
-                grammar_rules.data(),
-                grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
-        if (grammar == nullptr) {
-            throw std::runtime_error("Failed to initialize llama_grammar");
-        }
-        ctx->grammar = grammar;
-    }
+    llama_sampling_reset(ctx->smpl, ctx->params.grammar.c_str(), "root");

     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
     ctx->n_valid = 0;
 }

 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
-    if (dst->grammar) {
-        llama_grammar_free(dst->grammar);
-        dst->grammar = nullptr;
-    }
-
-    if (src->grammar) {
-        dst->grammar = llama_grammar_copy(src->grammar);
+    if (dst->smpl) {
+        llama_sampling_free(dst->smpl);
     }

+    dst->smpl = llama_sampling_cp(src->smpl);
     dst->prev = src->prev;
 }

@@ -277,7 +237,7 @@ static llama_token llama_sampling_sample_impl(

     std::vector<float> original_logits;
     auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
-    if (ctx_sampling->grammar != NULL && !is_resampling) {
+    if (!is_resampling) {
         GGML_ASSERT(!original_logits.empty());
     }
     llama_token id = 0;

@@ -320,7 +280,7 @@ static llama_token llama_sampling_sample_impl(
         }
     }

-    if (ctx_sampling->grammar != NULL && !is_resampling) {
+    if (!is_resampling) {
         // Get a pointer to the logits
         float * logits = llama_get_logits_ith(ctx_main, idx);

@@ -329,7 +289,7 @@ static llama_token llama_sampling_sample_impl(
         llama_token_data_array single_token_data_array = { &single_token_data, 1, false };

         // Apply grammar constraints to the single token
-        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
+        llama_sampling_grammar(ctx_sampling->smpl, &single_token_data_array);

         // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
         bool is_valid = single_token_data_array.data[0].logit != -INFINITY;

@@ -376,7 +336,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);

-    if (ctx_sampling->grammar != NULL && !apply_grammar) {
+    if (!apply_grammar) {
         GGML_ASSERT(original_logits != NULL);
         // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
         *original_logits = {logits, logits + n_vocab};

@@ -421,8 +381,8 @@ static llama_token_data_array llama_sampling_prepare_impl(
     }

     // apply grammar checks before sampling logic
-    if (apply_grammar && ctx_sampling->grammar != NULL) {
-        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
+    if (apply_grammar) {
+        llama_sampling_grammar(ctx_sampling->smpl, &cur_p);
     }

     return cur_p;

@@ -449,13 +409,12 @@ llama_token_data_array llama_sampling_prepare(

 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
         llama_token id,
         bool apply_grammar) {
     ctx_sampling->prev.erase(ctx_sampling->prev.begin());
     ctx_sampling->prev.push_back(id);

-    if (ctx_sampling->grammar != NULL && apply_grammar) {
-        llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
+    if (apply_grammar) {
+        llama_sampling_accept(ctx_sampling->smpl, id);
     }
 }
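
Note: taken together, the sampling.cpp changes move all grammar state behind the llama_sampling object — common code no longer owns a llama_grammar, and llama_sampling_accept drops its llama_context parameter. A rough caller-side sketch of the resulting loop, assuming the overloads introduced above (generate, ctx, model and n_predict are placeholders, not code from the commit):

#include "common.h"
#include "sampling.h"

static void generate(llama_context * ctx, const llama_model * model,
                     const llama_sampling_params & sparams, int n_predict) {
    // new overload: builds the internal llama_sampling (grammar included)
    // from the model and sparams.grammar; the context owns it
    llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, model);

    for (int i = 0; i < n_predict; ++i) {
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);

        // no llama_context argument anymore; grammar state advances internally
        llama_sampling_accept(ctx_sampling, id, /* apply_grammar= */ true);

        // ... decode id, break on EOS ...
    }

    llama_sampling_free(ctx_sampling); // also frees the owned llama_sampling
}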

common/sampling.h

Lines changed: 3 additions & 5 deletions
@@ -71,11 +71,9 @@ struct llama_sampling_context {
     // mirostat sampler state
     float mirostat_mu;

-    llama_sampling * smpl;
-    llama_grammar * grammar;
+    bool owned;

-    // internal
-    grammar_parser::parse_state parsed_grammar;
+    llama_sampling * smpl;

     // TODO: replace with ring-buffer
     std::vector<llama_token> prev;

@@ -87,6 +85,7 @@ struct llama_sampling_context {
 #include "common.h"

 // Create a new sampling context instance.
+struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model);
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl);

 void llama_sampling_free(struct llama_sampling_context * ctx);

@@ -150,6 +149,5 @@ llama_token_data_array llama_sampling_prepare(

 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
         llama_token id,
         bool apply_grammar);
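
Note: the new owned flag records which llama_sampling_init overload created the context, so llama_sampling_free knows whether the wrapped llama_sampling is its to release. A small sketch of the two ownership paths implied by this header (ownership_demo, model and sparams are placeholders):

#include "common.h"
#include "sampling.h"

static void ownership_demo(const llama_model * model, const llama_sampling_params & sparams) {
    // Path 1: the context builds its own sampler -> owned = true
    llama_sampling_context * a = llama_sampling_init(sparams, model);
    llama_sampling_free(a); // frees the internal llama_sampling too

    // Path 2: the caller supplies an existing sampler -> owned = false
    llama_sampling * smpl = llama_sampling_init(model, sparams.grammar.c_str(), "root");
    llama_sampling_context * b = llama_sampling_init(sparams, smpl);
    llama_sampling_free(b);    // frees only the wrapper
    llama_sampling_free(smpl); // the raw sampler remains the caller's to free
}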

examples/batched-bench/batched-bench.cpp

Lines changed: 1 addition & 1 deletion
@@ -200,7 +200,7 @@ int main(int argc, char ** argv) {
         }
     }

-    llama_print_timings(ctx, nullptr, nullptr);
+    llama_print_timings(ctx, nullptr);

     llama_batch_free(batch);

examples/batched/batched.cpp

Lines changed: 1 addition & 1 deletion
@@ -245,7 +245,7 @@ int main(int argc, char ** argv) {
     LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

-    llama_print_timings(ctx, smpl, nullptr);
+    llama_print_timings(ctx, smpl);

     fprintf(stderr, "\n");

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
@@ -258,7 +258,7 @@ int main(int argc, char ** argv) {
     }

     // clean up
-    llama_print_timings(ctx, nullptr, nullptr);
+    llama_print_timings(ctx, nullptr);
     llama_batch_free(batch);
     llama_free(ctx);
     llama_free_model(model);

examples/eval-callback/eval-callback.cpp

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ int main(int argc, char ** argv) {
         return 1;
     }

-    llama_print_timings(ctx, nullptr, nullptr);
+    llama_print_timings(ctx, nullptr);

     llama_free(ctx);
     llama_free_model(model);

examples/gbnf-validator/gbnf-validator.cpp

Lines changed: 5 additions & 24 deletions
@@ -1,7 +1,7 @@
-#include "grammar-parser.h"
 #include "ggml.h"
 #include "llama.h"
-#include "llama-impl.h"
+#include "llama-vocab.h" // TMP
+#include "llama-grammar.h"
 #include "unicode.h"

 #include <cstdio>

@@ -84,27 +84,8 @@ int main(int argc, char** argv) {
         grammar_str = buffer.str();
     }

-    // Parse the GBNF grammar
-    auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
-
-    // will be empty (default) if there are parse errors
-    if (parsed_grammar.rules.empty()) {
-        fprintf(stdout, "%s: failed to parse grammar\n", __func__);
-        return 1;
-    }
-
-    // Ensure that there is a "root" node.
-    if (parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()) {
-        fprintf(stdout, "%s: grammar does not contain a 'root' symbol\n", __func__);
-        return 1;
-    }
-
-    std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-
-    // Create the LLAMA grammar
-    auto grammar = llama_grammar_init(
-        grammar_rules.data(),
-        grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    llama_vocab vocab; // TMP
+    llama_grammar * grammar = llama_grammar_init_impl(vocab, grammar_str.c_str(), "root");
     if (grammar == nullptr) {
         throw std::runtime_error("Failed to initialize llama_grammar");
     }

@@ -130,7 +111,7 @@ int main(int argc, char** argv) {
     }

     // Clean up
-    llama_grammar_free(grammar);
+    llama_grammar_free_impl(grammar);

     return 0;
 }

examples/imatrix/imatrix.cpp

Lines changed: 1 addition & 1 deletion
@@ -638,7 +638,7 @@ int main(int argc, char ** argv) {

     g_collector.save_imatrix();

-    llama_print_timings(ctx, nullptr, nullptr);
+    llama_print_timings(ctx, nullptr);

     llama_free(ctx);
     llama_free_model(model);

examples/infill/infill.cpp

Lines changed: 4 additions & 4 deletions
@@ -93,7 +93,7 @@ static void sigint_handler(int signo) {
     } else {
         console::cleanup();
         printf("\n");
-        llama_print_timings(*g_ctx, (*g_ctx_sampling)->smpl, (*g_ctx_sampling)->grammar);
+        llama_print_timings(*g_ctx, (*g_ctx_sampling)->smpl);
         write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
         _exit(130);
     }

@@ -422,7 +422,7 @@ int main(int argc, char ** argv) {
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);

-            llama_sampling_accept(ctx_sampling, ctx, id, true);
+            llama_sampling_accept(ctx_sampling, id, true);

             LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());

@@ -443,7 +443,7 @@ int main(int argc, char ** argv) {

                 // push the prompt in the sampling context in order to apply repetition penalties later
                 // for the prompt, we don't apply grammar rules
-                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
+                llama_sampling_accept(ctx_sampling, embd_inp[n_consumed], false);

                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {

@@ -637,7 +637,7 @@ int main(int argc, char ** argv) {
         fflush(stdout);
     }

-    llama_print_timings(ctx, ctx_sampling->smpl, ctx_sampling->grammar);
+    llama_print_timings(ctx, ctx_sampling->smpl);
     write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);

     llama_free(ctx);
