grammars: cache decoded token codepoints for faster sampling #6811

Closed
wants to merge 19 commits
Changes from 5 commits
57 changes: 42 additions & 15 deletions llama.cpp
@@ -12727,6 +12727,10 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
}
}

if (next_candidates.empty()) {
return rejects;
Collaborator:

This is such a small and isolated change, I almost wonder if it shouldn't be pulled out into its own PR so that we can evaluate this performance improvement separate from the other one. As it is, it's difficult to know how much to attribute to each change...?

Collaborator (Author):

It happens to be the first commit on the branch, so you can git checkout before/after it and compare performance as follows:

( export COMMON_ARGS=(
    -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf
    -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf
    --prompt-cache issue4218.bin
    --grammar-file issue4218.gbnf
    -f issue4218.txt
    -c 3400
  ) && \
  hyperfine --warmup 1 --runs 5 \
    -L branch 98f33bae767dd19e213ef663b22ad99979ca71d7^,98f33bae767dd19e213ef663b22ad99979ca71d7 \
    --setup "\
      git checkout {branch} && \
      make clean && make -j LLAMA_CURL=1 main && \
      rm -f issue4218.bin && \
      ./main ${COMMON_ARGS[*]} -n 1" \
    "BRANCH={branch} \
      ./main ${COMMON_ARGS[*]} -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt" )
Benchmark 1: BRANCH=98f33bae767dd19e213ef663b22ad99979ca71d7^       ./main -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf --prompt-cache issue4218.bin --grammar-file issue4218.gbnf -f issue4218.txt -c 3400 -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt
  Time (mean ± σ):      7.970 s ±  0.060 s    [User: 3.829 s, System: 0.292 s]
  Range (min … max):    7.877 s …  8.025 s    5 runs
 
Benchmark 2: BRANCH=98f33bae767dd19e213ef663b22ad99979ca71d7       ./main -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf --prompt-cache issue4218.bin --grammar-file issue4218.gbnf -f issue4218.txt -c 3400 -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt
  Time (mean ± σ):      5.814 s ±  0.037 s    [User: 1.674 s, System: 0.277 s]
  Range (min … max):    5.758 s …  5.857 s    5 runs
 
Summary
  'BRANCH=98f33bae767dd19e213ef663b22ad99979ca71d7       ./main -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf --prompt-cache issue4218.bin --grammar-file issue4218.gbnf -f issue4218.txt -c 3400 -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt' ran
    1.37 ± 0.01 times faster than 'BRANCH=98f33bae767dd19e213ef663b22ad99979ca71d7^       ./main -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf --prompt-cache issue4218.bin --grammar-file issue4218.gbnf -f issue4218.txt -c 3400 -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt'

It doesn't help w/ all grammars, though.

}

const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;

// update top of stack to next element, if any
@@ -12804,26 +12808,32 @@ struct llama_grammar * llama_grammar_init(
}
} while (true);

return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {}, {}, {} };
}

void llama_grammar_free(struct llama_grammar * grammar) {
delete grammar;
}

struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8, grammar->token_pieces, grammar->token_codepoints };

std::unordered_map<const llama_grammar_element *, const llama_grammar_element *> element_map;
element_map.reserve(std::accumulate(
grammar->rules.begin(), grammar->rules.end(), 0,
[](size_t acc, const std::vector<llama_grammar_element> & rule) {
return acc + rule.size();
}));
for (size_t ir = 0; ir < grammar->rules.size(); ir++) {
for (size_t ie = 0; ie < grammar->rules[ir].size(); ie++) {
element_map[&grammar->rules[ir][ie]] = &result->rules[ir][ie];
}
}

// redirect elements in stacks to point to new rules
for (size_t is = 0; is < result->stacks.size(); is++) {
for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
result->stacks[is][ie] = &result->rules[ir0][ir1];
}
}
}
result->stacks[is][ie] = element_map.at(grammar->stacks[is][ie]);
}
}

@@ -13293,7 +13303,7 @@ void llama_sample_repetition_penalties(
}
}

void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, struct llama_grammar * grammar) {
GGML_ASSERT(ctx);
const int64_t t_start_sample_us = ggml_time_us();

@@ -13305,21 +13315,36 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
}
}

// lazily build the per-token piece / codepoint caches on first use
if (grammar->token_codepoints.empty()) {
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
grammar->token_codepoints.resize(n_vocab);
grammar->token_pieces.resize(n_vocab);
for (llama_token id = 0; id < n_vocab; ++id) {
const std::string piece = llama_token_to_piece(ctx, id, false);
grammar->token_pieces[id] = piece;
grammar->token_codepoints[id] = decode_utf8(piece, {0, 0});
}
}

std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
candidates_decoded.reserve(candidates->size);
// per-candidate decoding is only needed while a partial UTF-8 sequence is pending
if (grammar->partial_utf8.n_remain > 0) {
candidates_decoded.reserve(candidates->size);
}
std::vector<llama_grammar_candidate> candidates_grammar;
candidates_grammar.reserve(candidates->size);

for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const std::string piece = llama_token_to_piece(ctx, id, false);

const auto & piece = grammar->token_pieces[id];
if (llama_token_is_eog(&ctx->model, id)) {
if (!allow_eog) {
candidates->data[i].logit = -INFINITY;
}
} else if (piece.empty() || piece[0] == 0) {
candidates->data[i].logit = -INFINITY;
} else if (grammar->partial_utf8.n_remain == 0){
const auto & decoded = grammar->token_codepoints.at(id);
candidates_grammar.push_back({ i, decoded.first.data(), decoded.second });
} else {
candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
@@ -13513,10 +13538,12 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
GGML_ASSERT(false);
}

const std::string piece = llama_token_to_piece(ctx, token, false);
const auto & piece = grammar->token_pieces.at(token);

// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
const auto decoded = grammar->partial_utf8.n_remain == 0
? grammar->token_codepoints[token]
: decode_utf8(piece, grammar->partial_utf8);
const auto & code_points = decoded.first;
std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
7 changes: 6 additions & 1 deletion llama.h
@@ -953,7 +953,7 @@ extern "C" {
LLAMA_API void llama_sample_grammar(
struct llama_context * ctx,
llama_token_data_array * candidates,
const struct llama_grammar * grammar);
struct llama_grammar * grammar);

/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
@@ -1090,6 +1090,11 @@ struct llama_grammar {

// buffer for partially generated UTF-8 sequence from accepted tokens
llama_partial_utf8 partial_utf8;

// cache of the token pieces & their decoded codepoints.
std::vector<std::string> token_pieces;
std::vector<std::pair<std::vector<uint32_t>,
llama_partial_utf8>> token_codepoints;
};

struct llama_grammar_candidate {
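
To summarize the pattern this PR introduces, here is a condensed, self-contained sketch of lazily caching detokenized token pieces and their decoded codepoints so the sampling hot path can reuse them. All names here (grammar_cache, ensure_cache, decode_piece, k_vocab) and the toy byte-wise decoder are illustrative stand-ins, not the actual llama.cpp code:

// Illustrative sketch only: simplified names and a toy vocab, not llama.cpp code.
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

struct grammar_cache {
    std::vector<std::string>           token_pieces;     // detokenized piece per token id
    std::vector<std::vector<uint32_t>> token_codepoints; // decoded once, reused every sample
};

// Toy stand-ins: a 4-token vocabulary and a byte-wise "decoder"
// (the real code decodes UTF-8 and tracks partial sequences).
static const std::vector<std::string> k_vocab = { "Hello", " ", "world", "!" };

static std::vector<uint32_t> decode_piece(const std::string & piece) {
    std::vector<uint32_t> cps;
    for (unsigned char c : piece) {
        cps.push_back(c); // simplification: treat each byte as a codepoint
    }
    return cps;
}

// Build the caches lazily, the first time the grammar is used for sampling.
static void ensure_cache(grammar_cache & cache) {
    if (!cache.token_codepoints.empty()) {
        return; // already built
    }
    const size_t n_vocab = k_vocab.size();
    cache.token_pieces.resize(n_vocab);
    cache.token_codepoints.resize(n_vocab);
    for (size_t id = 0; id < n_vocab; ++id) {
        cache.token_pieces[id]     = k_vocab[id];
        cache.token_codepoints[id] = decode_piece(cache.token_pieces[id]);
    }
}

int main() {
    grammar_cache cache;
    // Sampling hot path: each step reuses the cached codepoints instead of
    // re-detokenizing and re-decoding every candidate token.
    for (int step = 0; step < 3; ++step) {
        ensure_cache(cache);
        for (size_t id = 0; id < cache.token_codepoints.size(); ++id) {
            std::printf("step %d, token %zu -> %zu codepoints\n",
                        step, id, cache.token_codepoints[id].size());
        }
    }
    return 0;
}

In the PR itself, the cache lives on struct llama_grammar and is filled on first use inside llama_sample_grammar; candidates are only re-decoded with decode_utf8 when a partial UTF-8 sequence is pending (partial_utf8.n_remain > 0).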