@@ -687,31 +687,17 @@ static bool llama_grammar_match_partial_char(
687
687
// additionally memoizes the stack to its possible stacks by mapping
688
688
// < llama_grammar_stack, llama_grammar_stacks >
689
689
690
- struct VectorPointerHash {
691
- size_t operator ()(const llama_grammar_stack & v) const {
692
- size_t seed = v.size ();
693
- for (const auto * ptr : v) {
694
- seed ^= std::hash<const llama_grammar_element*>()(ptr) + 0x9e3779b9 + (seed << 6 ) + (seed >> 2 );
695
- }
696
- return seed;
697
- }
698
- };
699
-
700
- static std::unordered_map<
701
- llama_grammar_stack,
702
- llama_grammar_stacks,
703
- VectorPointerHash>
704
- llama_grammar_stacks_cache = {};
705
-
706
690
static void llama_grammar_advance_stack_memo (
707
691
const llama_grammar_rules & rules,
708
692
const llama_grammar_stack & stack,
709
- llama_grammar_stacks & new_stacks);
693
+ llama_grammar_stacks & new_stacks,
694
+ llama_grammar_stacks_cache & stacks_cache);
710
695
711
696
static void llama_grammar_advance_stack_memo_impl (
712
697
const llama_grammar_rules & rules,
713
698
const llama_grammar_stack & stack,
714
- llama_grammar_stacks & new_stacks) {
699
+ llama_grammar_stacks & new_stacks,
700
+ llama_grammar_stacks_cache & stacks_cache) {
715
701
if (stack.empty ()) {
716
702
if (std::find (new_stacks.begin (), new_stacks.end (), stack) == new_stacks.end ()) {
717
703
new_stacks.emplace_back (stack);
@@ -736,7 +722,7 @@ static void llama_grammar_advance_stack_memo_impl(
736
722
// if alternate is nonempty, add to stack
737
723
new_stack.push_back (subpos);
738
724
}
739
- llama_grammar_advance_stack_memo (rules, new_stack, new_stacks);
725
+ llama_grammar_advance_stack_memo (rules, new_stack, new_stacks, stacks_cache );
740
726
while (!llama_grammar_is_end_of_sequence (subpos)) {
741
727
// scan to end of alternate def
742
728
subpos++;
@@ -769,17 +755,18 @@ static void llama_grammar_advance_stack_memo_impl(
769
755
static void llama_grammar_advance_stack_memo (
770
756
const llama_grammar_rules & rules,
771
757
const llama_grammar_stack & stack,
772
- llama_grammar_stacks & new_stacks) {
758
+ llama_grammar_stacks & new_stacks,
759
+ llama_grammar_stacks_cache & stacks_cache) {
773
760
774
761
llama_grammar_stacks advanced_stacks;
775
762
// Check whether the stack is already in the cache
776
- auto it = llama_grammar_stacks_cache .find (stack);
777
- if (it != llama_grammar_stacks_cache .end ()) {
763
+ auto it = stacks_cache .find (stack);
764
+ if (it != stacks_cache .end ()) {
778
765
advanced_stacks = it->second ;
779
766
} else {
780
767
// Advance stacks with memoization
781
- llama_grammar_advance_stack_memo_impl (rules, stack, advanced_stacks);
782
- llama_grammar_stacks_cache .insert (make_pair (stack, advanced_stacks));
768
+ llama_grammar_advance_stack_memo_impl (rules, stack, advanced_stacks, stacks_cache );
769
+ stacks_cache .insert (make_pair (stack, advanced_stacks));
783
770
}
784
771
// Add the advanced stacks to new_stacks avoiding duplicates
785
772
for (const auto & new_stack : advanced_stacks) {
@@ -934,7 +921,8 @@ void llama_grammar_accept(
934
921
const llama_grammar_rules & rules,
935
922
const llama_grammar_stacks & stacks,
936
923
const uint32_t chr,
937
- llama_grammar_stacks & stacks_new) {
924
+ llama_grammar_stacks & stacks_new,
925
+ llama_grammar_stacks_cache & stacks_cache) {
938
926
stacks_new.clear ();
939
927
stacks_new.reserve (stacks.size ());
940
928
@@ -952,7 +940,7 @@ void llama_grammar_accept(
952
940
if (!llama_grammar_is_end_of_sequence (pos)) {
953
941
new_stack.push_back (pos);
954
942
}
955
- llama_grammar_advance_stack_memo (rules, new_stack, stacks_new);
943
+ llama_grammar_advance_stack_memo (rules, new_stack, stacks_new, stacks_cache );
956
944
}
957
945
}
958
946
}
@@ -1019,8 +1007,6 @@ struct llama_grammar * llama_grammar_init_impl(
1019
1007
const llama_grammar_element ** rules,
1020
1008
size_t n_rules,
1021
1009
size_t start_rule_index) {
1022
- // Clear stacks cache
1023
- llama_grammar_stacks_cache.clear ();
1024
1010
const llama_grammar_element * pos;
1025
1011
1026
1012
// copy rule definitions into vectors
@@ -1048,14 +1034,15 @@ struct llama_grammar * llama_grammar_init_impl(
1048
1034
1049
1035
// loop over alternates of start rule to build initial stacks
1050
1036
llama_grammar_stacks stacks;
1037
+ llama_grammar_stacks_cache stacks_cache;
1051
1038
pos = vec_rules[start_rule_index].data ();
1052
1039
do {
1053
1040
llama_grammar_stack stack;
1054
1041
if (!llama_grammar_is_end_of_sequence (pos)) {
1055
1042
// if alternate is nonempty, add to stack
1056
1043
stack.push_back (pos);
1057
1044
}
1058
- llama_grammar_advance_stack_memo (vec_rules, stack, stacks);
1045
+ llama_grammar_advance_stack_memo (vec_rules, stack, stacks, stacks_cache );
1059
1046
while (!llama_grammar_is_end_of_sequence (pos)) {
1060
1047
// scan to end of alternate def
1061
1048
pos++;
@@ -1075,8 +1062,6 @@ struct llama_grammar * llama_grammar_init_impl(
1075
1062
}
1076
1063
1077
1064
struct llama_grammar * llama_grammar_init_impl (const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
1078
- // Clear stacks cache
1079
- llama_grammar_stacks_cache.clear ();
1080
1065
llama_grammar_parser parser;
1081
1066
1082
1067
// if there is a grammar, parse it
@@ -1128,14 +1113,15 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
1128
1113
1129
1114
// loop over alternates of start rule to build initial stacks
1130
1115
llama_grammar_stacks stacks;
1116
+ llama_grammar_stacks_cache stacks_cache;
1131
1117
pos = vec_rules[start_rule_index].data ();
1132
1118
do {
1133
1119
llama_grammar_stack stack;
1134
1120
if (!llama_grammar_is_end_of_sequence (pos)) {
1135
1121
// if alternate is nonempty, add to stack
1136
1122
stack.push_back (pos);
1137
1123
}
1138
- llama_grammar_advance_stack_memo (vec_rules, stack, stacks);
1124
+ llama_grammar_advance_stack_memo (vec_rules, stack, stacks, stacks_cache );
1139
1125
while (!llama_grammar_is_end_of_sequence (pos)) {
1140
1126
// scan to end of alternate def
1141
1127
pos++;
@@ -1239,9 +1225,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
1239
1225
const auto & code_points = decoded.first ;
1240
1226
1241
1227
llama_grammar_stacks stacks_new;
1228
+ llama_grammar_stacks_cache stacks_cache;
1242
1229
1243
1230
for (auto it = code_points.begin (), end = code_points.end () - 1 ; it != end; ++it) {
1244
- llama_grammar_accept (grammar.rules , grammar.stacks , *it, stacks_new);
1231
+ llama_grammar_accept (grammar.rules , grammar.stacks , *it, stacks_new, stacks_cache );
1245
1232
grammar.stacks = std::move (stacks_new);
1246
1233
}
1247
1234
0 commit comments