Skip to content

Commit 2aa6dd2

Browse files
committed
add stacks cache into llama_grammar
1 parent 901a347 commit 2aa6dd2

File tree

4 files changed

+15
-9
lines changed

4 files changed

+15
-9
lines changed

examples/gbnf-validator/gbnf-validator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st
1313

1414
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
1515
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
16+
llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar);
1617

1718
size_t pos = 0;
18-
llama_grammar_stacks_cache stacks_cache;
1919
for (const auto & cpt : cpts) {
2020
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
2121

src/llama-grammar.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,10 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
917917
return grammar->stacks;
918918
}
919919

920+
// Accessor for the stacks cache stored inside `grammar`.
// Returns a non-const reference to the `stacks_cache` member, so callers
// read and modify the grammar's own cache rather than a copy.
// `grammar` is dereferenced without a null check.
llama_grammar_stacks_cache & llama_grammar_get_stacks_cache(struct llama_grammar * grammar) {
921+
return grammar->stacks_cache;
922+
}
923+
920924
void llama_grammar_accept(
921925
const llama_grammar_rules & rules,
922926
const llama_grammar_stacks & stacks,
@@ -1058,7 +1062,7 @@ struct llama_grammar * llama_grammar_init_impl(
10581062
// Important: vec_rules has to be moved here, not copied, because stacks contains
10591063
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
10601064
// then the pointers would be invalidated when the local vec_rules goes out of scope.
1061-
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
1065+
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), };
10621066
}
10631067

10641068
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
@@ -1137,7 +1141,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
11371141
// Important: vec_rules has to be moved here, not copied, because stacks contains
11381142
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
11391143
// then the pointers would be invalidated when the local vec_rules goes out of scope.
1140-
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
1144+
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), };
11411145
}
11421146

11431147
void llama_grammar_free_impl(struct llama_grammar * grammar) {
@@ -1225,10 +1229,9 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
12251229
const auto & code_points = decoded.first;
12261230

12271231
llama_grammar_stacks stacks_new;
1228-
llama_grammar_stacks_cache stacks_cache;
12291232

12301233
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1231-
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, stacks_cache);
1234+
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, grammar.stacks_cache);
12321235
grammar.stacks = std::move(stacks_new);
12331236
}
12341237

src/llama-grammar.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,6 @@ using llama_grammar_rules = std::vector<llama_grammar_rule>;
5959
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
6060
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
6161

62-
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
63-
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
64-
6562
struct VectorPointerHash {
6663
size_t operator()(const llama_grammar_stack & v) const {
6764
size_t seed = v.size();
@@ -74,6 +71,10 @@ struct VectorPointerHash {
7471

7572
using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>;
7673

74+
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
75+
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
76+
llama_grammar_stacks_cache & llama_grammar_get_stacks_cache( struct llama_grammar * grammar);
77+
7778
// takes a set of possible pushdown stacks on a grammar, which are required to
7879
// be positioned at a character range (see `llama_grammar_advance_stack`), and
7980
// produces the N possible stacks if the given char is accepted at those
@@ -129,6 +130,8 @@ struct llama_grammar {
129130

130131
// buffer for partially generated UTF-8 sequence from accepted tokens
131132
llama_partial_utf8 partial_utf8;
133+
// cache mapping a single stack to the N possible stacks produced from it
134+
llama_grammar_stacks_cache stacks_cache;
132135
};
133136

134137
//

tests/test-grammar-integration.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
3434

3535
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
3636
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
37+
llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar);
3738

38-
llama_grammar_stacks_cache stacks_cache;
3939
for (const auto & cpt : cpts) {
4040
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
4141

0 commit comments

Comments
 (0)