Skip to content

Commit 901a347

Browse files
committed
move cache stack to advance stack
1 parent cb1632b commit 901a347

File tree

4 files changed

+39
-36
lines changed

4 files changed

+39
-36
lines changed

examples/gbnf-validator/gbnf-validator.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st
1515
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
1616

1717
size_t pos = 0;
18+
llama_grammar_stacks_cache stacks_cache;
1819
for (const auto & cpt : cpts) {
1920
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
2021

21-
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
22+
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache);
2223

2324
if (stacks_cur.empty()) {
2425
error_pos = pos;

src/llama-grammar.cpp

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -687,31 +687,17 @@ static bool llama_grammar_match_partial_char(
687687
// additionally memorizes the stack to its possible stacks by mapping
688688
// < llama_grammar_stack, llama_grammar_stacks >
689689

690-
struct VectorPointerHash {
691-
size_t operator()(const llama_grammar_stack & v) const {
692-
size_t seed = v.size();
693-
for (const auto* ptr : v) {
694-
seed ^= std::hash<const llama_grammar_element*>()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
695-
}
696-
return seed;
697-
}
698-
};
699-
700-
static std::unordered_map<
701-
llama_grammar_stack,
702-
llama_grammar_stacks,
703-
VectorPointerHash>
704-
llama_grammar_stacks_cache = {};
705-
706690
static void llama_grammar_advance_stack_memo(
707691
const llama_grammar_rules & rules,
708692
const llama_grammar_stack & stack,
709-
llama_grammar_stacks & new_stacks);
693+
llama_grammar_stacks & new_stacks,
694+
llama_grammar_stacks_cache & stacks_cache);
710695

711696
static void llama_grammar_advance_stack_memo_impl(
712697
const llama_grammar_rules & rules,
713698
const llama_grammar_stack & stack,
714-
llama_grammar_stacks & new_stacks) {
699+
llama_grammar_stacks & new_stacks,
700+
llama_grammar_stacks_cache & stacks_cache) {
715701
if (stack.empty()) {
716702
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
717703
new_stacks.emplace_back(stack);
@@ -736,7 +722,7 @@ static void llama_grammar_advance_stack_memo_impl(
736722
// if alternate is nonempty, add to stack
737723
new_stack.push_back(subpos);
738724
}
739-
llama_grammar_advance_stack_memo(rules, new_stack, new_stacks);
725+
llama_grammar_advance_stack_memo(rules, new_stack, new_stacks, stacks_cache);
740726
while (!llama_grammar_is_end_of_sequence(subpos)) {
741727
// scan to end of alternate def
742728
subpos++;
@@ -769,17 +755,18 @@ static void llama_grammar_advance_stack_memo_impl(
769755
static void llama_grammar_advance_stack_memo(
770756
const llama_grammar_rules & rules,
771757
const llama_grammar_stack & stack,
772-
llama_grammar_stacks & new_stacks) {
758+
llama_grammar_stacks & new_stacks,
759+
llama_grammar_stacks_cache & stacks_cache) {
773760

774761
llama_grammar_stacks advanced_stacks;
775762
// Look if stack is already in memory
776-
auto it = llama_grammar_stacks_cache.find(stack);
777-
if (it != llama_grammar_stacks_cache.end()) {
763+
auto it = stacks_cache.find(stack);
764+
if (it != stacks_cache.end()) {
778765
advanced_stacks = it->second;
779766
} else {
780767
// Advance stacks with memorization
781-
llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks);
782-
llama_grammar_stacks_cache.insert(make_pair(stack, advanced_stacks));
768+
llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks, stacks_cache);
769+
stacks_cache.insert(make_pair(stack, advanced_stacks));
783770
}
784771
// Add the advanced stacks to new_stacks avoiding duplicates
785772
for (const auto & new_stack : advanced_stacks) {
@@ -934,7 +921,8 @@ void llama_grammar_accept(
934921
const llama_grammar_rules & rules,
935922
const llama_grammar_stacks & stacks,
936923
const uint32_t chr,
937-
llama_grammar_stacks & stacks_new) {
924+
llama_grammar_stacks & stacks_new,
925+
llama_grammar_stacks_cache & stacks_cache) {
938926
stacks_new.clear();
939927
stacks_new.reserve(stacks.size());
940928

@@ -952,7 +940,7 @@ void llama_grammar_accept(
952940
if (!llama_grammar_is_end_of_sequence(pos)) {
953941
new_stack.push_back(pos);
954942
}
955-
llama_grammar_advance_stack_memo(rules, new_stack, stacks_new);
943+
llama_grammar_advance_stack_memo(rules, new_stack, stacks_new, stacks_cache);
956944
}
957945
}
958946
}
@@ -1019,8 +1007,6 @@ struct llama_grammar * llama_grammar_init_impl(
10191007
const llama_grammar_element ** rules,
10201008
size_t n_rules,
10211009
size_t start_rule_index) {
1022-
// Clear stacks cache
1023-
llama_grammar_stacks_cache.clear();
10241010
const llama_grammar_element * pos;
10251011

10261012
// copy rule definitions into vectors
@@ -1048,14 +1034,15 @@ struct llama_grammar * llama_grammar_init_impl(
10481034

10491035
// loop over alternates of start rule to build initial stacks
10501036
llama_grammar_stacks stacks;
1037+
llama_grammar_stacks_cache stacks_cache;
10511038
pos = vec_rules[start_rule_index].data();
10521039
do {
10531040
llama_grammar_stack stack;
10541041
if (!llama_grammar_is_end_of_sequence(pos)) {
10551042
// if alternate is nonempty, add to stack
10561043
stack.push_back(pos);
10571044
}
1058-
llama_grammar_advance_stack_memo(vec_rules, stack, stacks);
1045+
llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache);
10591046
while (!llama_grammar_is_end_of_sequence(pos)) {
10601047
// scan to end of alternate def
10611048
pos++;
@@ -1075,8 +1062,6 @@ struct llama_grammar * llama_grammar_init_impl(
10751062
}
10761063

10771064
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
1078-
// Clear stacks cache
1079-
llama_grammar_stacks_cache.clear();
10801065
llama_grammar_parser parser;
10811066

10821067
// if there is a grammar, parse it
@@ -1128,14 +1113,15 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
11281113

11291114
// loop over alternates of start rule to build initial stacks
11301115
llama_grammar_stacks stacks;
1116+
llama_grammar_stacks_cache stacks_cache;
11311117
pos = vec_rules[start_rule_index].data();
11321118
do {
11331119
llama_grammar_stack stack;
11341120
if (!llama_grammar_is_end_of_sequence(pos)) {
11351121
// if alternate is nonempty, add to stack
11361122
stack.push_back(pos);
11371123
}
1138-
llama_grammar_advance_stack_memo(vec_rules, stack, stacks);
1124+
llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache);
11391125
while (!llama_grammar_is_end_of_sequence(pos)) {
11401126
// scan to end of alternate def
11411127
pos++;
@@ -1239,9 +1225,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
12391225
const auto & code_points = decoded.first;
12401226

12411227
llama_grammar_stacks stacks_new;
1228+
llama_grammar_stacks_cache stacks_cache;
12421229

12431230
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1244-
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
1231+
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, stacks_cache);
12451232
grammar.stacks = std::move(stacks_new);
12461233
}
12471234

src/llama-grammar.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "llama-impl.h"
44

55
#include <map>
6+
#include <unordered_map>
67

78
struct llama_vocab;
89

@@ -61,6 +62,18 @@ using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
6162
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
6263
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
6364

65+
struct VectorPointerHash {
66+
size_t operator()(const llama_grammar_stack & v) const {
67+
size_t seed = v.size();
68+
for (const auto* ptr : v) {
69+
seed ^= std::hash<const llama_grammar_element*>()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
70+
}
71+
return seed;
72+
}
73+
};
74+
75+
using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>;
76+
6477
// takes a set of possible pushdown stacks on a grammar, which are required to
6578
// be positioned at a character range (see `llama_grammar_advance_stack`), and
6679
// produces the N possible stacks if the given char is accepted at those
@@ -69,7 +82,8 @@ void llama_grammar_accept(
6982
const llama_grammar_rules & rules,
7083
const llama_grammar_stacks & stacks,
7184
uint32_t chr,
72-
llama_grammar_stacks & stacks_new);
85+
llama_grammar_stacks & stacks_new,
86+
llama_grammar_stacks_cache & stacks_cache);
7387

7488
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
7589
const llama_grammar_rules & rules,

tests/test-grammar-integration.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
3535
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
3636
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
3737

38+
llama_grammar_stacks_cache stacks_cache;
3839
for (const auto & cpt : cpts) {
3940
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
4041

41-
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
42+
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache);
4243

4344
if (stacks_cur.empty()) {
4445
// no stacks means that the grammar failed to match at this point

0 commit comments

Comments
 (0)