-
Notifications
You must be signed in to change notification settings - Fork 12.4k
llama : adds llama-grammar memoization stacks (#4218) #9833
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 3 commits
cb1632b
901a347
2aa6dd2
17b3a3e
34fc44d
a33fbbe
dc68a59
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -682,6 +682,101 @@ static bool llama_grammar_match_partial_char( | |||||||||
return !is_positive_char; | ||||||||||
} | ||||||||||
|
||||||||||
// transforms a grammar pushdown stack into N possible stacks, all ending | ||||||||||
// at a character range (terminal element) | ||||||||||
// additionally memorizes the stack to its possible stacks by mapping | ||||||||||
// < llama_grammar_stack, llama_grammar_stacks > | ||||||||||
|
||||||||||
static void llama_grammar_advance_stack_memo( | ||||||||||
const llama_grammar_rules & rules, | ||||||||||
const llama_grammar_stack & stack, | ||||||||||
llama_grammar_stacks & new_stacks, | ||||||||||
llama_grammar_stacks_cache & stacks_cache); | ||||||||||
|
||||||||||
static void llama_grammar_advance_stack_memo_impl( | ||||||||||
const llama_grammar_rules & rules, | ||||||||||
const llama_grammar_stack & stack, | ||||||||||
llama_grammar_stacks & new_stacks, | ||||||||||
llama_grammar_stacks_cache & stacks_cache) { | ||||||||||
if (stack.empty()) { | ||||||||||
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { | ||||||||||
new_stacks.emplace_back(stack); | ||||||||||
} | ||||||||||
return; | ||||||||||
} | ||||||||||
|
||||||||||
const llama_grammar_element * pos = stack.back(); | ||||||||||
|
||||||||||
switch (pos->type) { | ||||||||||
case LLAMA_GRETYPE_RULE_REF: { | ||||||||||
const size_t rule_id = static_cast<size_t>(pos->value); | ||||||||||
const llama_grammar_element * subpos = rules[rule_id].data(); | ||||||||||
do { | ||||||||||
// init new stack without the top (pos) | ||||||||||
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); | ||||||||||
if (!llama_grammar_is_end_of_sequence(pos + 1)) { | ||||||||||
// if this rule ref is followed by another element, add that to stack | ||||||||||
new_stack.push_back(pos + 1); | ||||||||||
} | ||||||||||
if (!llama_grammar_is_end_of_sequence(subpos)) { | ||||||||||
// if alternate is nonempty, add to stack | ||||||||||
new_stack.push_back(subpos); | ||||||||||
} | ||||||||||
llama_grammar_advance_stack_memo(rules, new_stack, new_stacks, stacks_cache); | ||||||||||
while (!llama_grammar_is_end_of_sequence(subpos)) { | ||||||||||
// scan to end of alternate def | ||||||||||
subpos++; | ||||||||||
} | ||||||||||
if (subpos->type == LLAMA_GRETYPE_ALT) { | ||||||||||
// there's another alternate def of this rule to process | ||||||||||
subpos++; | ||||||||||
} else { | ||||||||||
break; | ||||||||||
} | ||||||||||
} while (true); | ||||||||||
break; | ||||||||||
} | ||||||||||
case LLAMA_GRETYPE_CHAR: | ||||||||||
case LLAMA_GRETYPE_CHAR_NOT: | ||||||||||
case LLAMA_GRETYPE_CHAR_ANY: | ||||||||||
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { | ||||||||||
// only add the stack if it's not a duplicate of one we already have | ||||||||||
new_stacks.emplace_back(stack); | ||||||||||
} | ||||||||||
break; | ||||||||||
default: | ||||||||||
// end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range | ||||||||||
// (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on | ||||||||||
// those | ||||||||||
GGML_ABORT("fatal error"); | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
static void llama_grammar_advance_stack_memo( | ||||||||||
const llama_grammar_rules & rules, | ||||||||||
const llama_grammar_stack & stack, | ||||||||||
llama_grammar_stacks & new_stacks, | ||||||||||
llama_grammar_stacks_cache & stacks_cache) { | ||||||||||
|
||||||||||
llama_grammar_stacks advanced_stacks; | ||||||||||
// Look if stack is already in memory | ||||||||||
auto it = stacks_cache.find(stack); | ||||||||||
if (it != stacks_cache.end()) { | ||||||||||
advanced_stacks = it->second; | ||||||||||
} else { | ||||||||||
// Advance stacks with memorization | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks, stacks_cache); | ||||||||||
stacks_cache.insert(make_pair(stack, advanced_stacks)); | ||||||||||
} | ||||||||||
// Add the advanced stacks to new_stacks avoiding duplicates | ||||||||||
for (const auto & new_stack : advanced_stacks) { | ||||||||||
if (std::find(new_stacks.begin(), new_stacks.end(), new_stack) == new_stacks.end()) { | ||||||||||
new_stacks.emplace_back(new_stack); | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
} | ||||||||||
|
||||||||||
// transforms a grammar pushdown stack into N possible stacks, all ending | ||||||||||
// at a character range (terminal element) | ||||||||||
static void llama_grammar_advance_stack( | ||||||||||
|
@@ -822,11 +917,16 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) | |||||||||
return grammar->stacks; | ||||||||||
} | ||||||||||
|
||||||||||
// Accessor for the grammar's stack-expansion memoization cache.
llama_grammar_stacks_cache & llama_grammar_get_stacks_cache(struct llama_grammar * grammar) {
    return grammar->stacks_cache;
}
|
||||||||||
void llama_grammar_accept( | ||||||||||
const llama_grammar_rules & rules, | ||||||||||
const llama_grammar_stacks & stacks, | ||||||||||
const uint32_t chr, | ||||||||||
llama_grammar_stacks & stacks_new) { | ||||||||||
llama_grammar_stacks & stacks_new, | ||||||||||
llama_grammar_stacks_cache & stacks_cache) { | ||||||||||
stacks_new.clear(); | ||||||||||
stacks_new.reserve(stacks.size()); | ||||||||||
|
||||||||||
|
@@ -844,7 +944,7 @@ void llama_grammar_accept( | |||||||||
if (!llama_grammar_is_end_of_sequence(pos)) { | ||||||||||
new_stack.push_back(pos); | ||||||||||
} | ||||||||||
llama_grammar_advance_stack(rules, new_stack, stacks_new); | ||||||||||
llama_grammar_advance_stack_memo(rules, new_stack, stacks_new, stacks_cache); | ||||||||||
} | ||||||||||
} | ||||||||||
} | ||||||||||
|
@@ -938,14 +1038,15 @@ struct llama_grammar * llama_grammar_init_impl( | |||||||||
|
||||||||||
// loop over alternates of start rule to build initial stacks | ||||||||||
llama_grammar_stacks stacks; | ||||||||||
llama_grammar_stacks_cache stacks_cache; | ||||||||||
pos = vec_rules[start_rule_index].data(); | ||||||||||
do { | ||||||||||
llama_grammar_stack stack; | ||||||||||
if (!llama_grammar_is_end_of_sequence(pos)) { | ||||||||||
// if alternate is nonempty, add to stack | ||||||||||
stack.push_back(pos); | ||||||||||
} | ||||||||||
llama_grammar_advance_stack(vec_rules, stack, stacks); | ||||||||||
llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache); | ||||||||||
while (!llama_grammar_is_end_of_sequence(pos)) { | ||||||||||
// scan to end of alternate def | ||||||||||
pos++; | ||||||||||
|
@@ -961,7 +1062,7 @@ struct llama_grammar * llama_grammar_init_impl( | |||||||||
// Important: vec_rules has to be moved here, not copied, because stacks contains | ||||||||||
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar | ||||||||||
// then the pointers would be invalidated when the local vec_rules goes out of scope. | ||||||||||
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; | ||||||||||
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), }; | ||||||||||
} | ||||||||||
|
||||||||||
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { | ||||||||||
|
@@ -1016,14 +1117,15 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, | |||||||||
|
||||||||||
// loop over alternates of start rule to build initial stacks | ||||||||||
llama_grammar_stacks stacks; | ||||||||||
llama_grammar_stacks_cache stacks_cache; | ||||||||||
pos = vec_rules[start_rule_index].data(); | ||||||||||
do { | ||||||||||
llama_grammar_stack stack; | ||||||||||
if (!llama_grammar_is_end_of_sequence(pos)) { | ||||||||||
// if alternate is nonempty, add to stack | ||||||||||
stack.push_back(pos); | ||||||||||
} | ||||||||||
llama_grammar_advance_stack(vec_rules, stack, stacks); | ||||||||||
llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache); | ||||||||||
while (!llama_grammar_is_end_of_sequence(pos)) { | ||||||||||
// scan to end of alternate def | ||||||||||
pos++; | ||||||||||
|
@@ -1039,7 +1141,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, | |||||||||
// Important: vec_rules has to be moved here, not copied, because stacks contains | ||||||||||
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar | ||||||||||
// then the pointers would be invalidated when the local vec_rules goes out of scope. | ||||||||||
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; | ||||||||||
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), }; | ||||||||||
} | ||||||||||
|
||||||||||
void llama_grammar_free_impl(struct llama_grammar * grammar) { | ||||||||||
|
@@ -1129,7 +1231,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token | |||||||||
llama_grammar_stacks stacks_new; | ||||||||||
|
||||||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { | ||||||||||
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new); | ||||||||||
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, grammar.stacks_cache); | ||||||||||
grammar.stacks = std::move(stacks_new); | ||||||||||
} | ||||||||||
|
||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
#include "llama-impl.h" | ||
|
||
#include <map> | ||
#include <unordered_map> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can't think of any need to have a sorted key for this -- feels like |
||
|
||
struct llama_vocab; | ||
|
||
|
@@ -58,8 +59,21 @@ using llama_grammar_rules = std::vector<llama_grammar_rule>; | |
using llama_grammar_stacks = std::vector<llama_grammar_stack>; | ||
using llama_grammar_candidates = std::vector<llama_grammar_candidate>; | ||
|
||
struct VectorPointerHash { | ||
size_t operator()(const llama_grammar_stack & v) const { | ||
size_t seed = v.size(); | ||
for (const auto* ptr : v) { | ||
seed ^= std::hash<const llama_grammar_element*>()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2); | ||
} | ||
return seed; | ||
} | ||
}; | ||
|
||
using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>; | ||
|
||
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); | ||
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); | ||
llama_grammar_stacks_cache & llama_grammar_get_stacks_cache( struct llama_grammar * grammar); | ||
|
||
// takes a set of possible pushdown stacks on a grammar, which are required to | ||
// be positioned at a character range (see `llama_grammar_advance_stack`), and | ||
|
@@ -69,7 +83,8 @@ void llama_grammar_accept( | |
const llama_grammar_rules & rules, | ||
const llama_grammar_stacks & stacks, | ||
uint32_t chr, | ||
llama_grammar_stacks & stacks_new); | ||
llama_grammar_stacks & stacks_new, | ||
llama_grammar_stacks_cache & stacks_cache); | ||
|
||
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack( | ||
const llama_grammar_rules & rules, | ||
|
@@ -115,6 +130,8 @@ struct llama_grammar { | |
|
||
// buffer for partially generated UTF-8 sequence from accepted tokens | ||
llama_partial_utf8 partial_utf8; | ||
// cache N possible stacks from a stack | ||
llama_grammar_stacks_cache stacks_cache; | ||
}; | ||
|
||
// | ||
|
Uh oh!
There was an error while loading. Please reload this page.