Skip to content

grammars: cache decoded token codepoints for faster sampling #6811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 19 commits into from
Closed
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2328,6 +2328,11 @@ struct llama_context {
// control vectors
struct llama_control_vector cvec;

// caching token pieces & their decoded codepoints.
std::vector<std::string> token_pieces;
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>>
token_codepoints_without_partial_utf8_prefix;

#ifdef GGML_USE_MPI
ggml_mpi_context * ctx_mpi = NULL;
#endif
Expand Down Expand Up @@ -13037,6 +13042,10 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
}
}

if (next_candidates.empty()) {
return rejects;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is such a small and isolated change, I almost wonder if it shouldn't be pulled out into its own PR so that we can evaluate this performance improvement separate from the other one. As it is, it's difficult to know how much to attribute to each change...?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It happens to be the first commit on the branch so you can git checkout before / after it and compare performance as follows:

( export COMMON_ARGS=(
    -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf
    -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf
    --prompt-cache issue4218.bin
    --grammar-file issue4218.gbnf
    -f issue4218.txt
    -c 3400
  ) && \
  hyperfine --warmup 1 --runs 5 \
    -L branch 98f33bae767dd19e213ef663b22ad99979ca71d7^,98f33bae767dd19e213ef663b22ad99979ca71d7 \
    --setup "\
      git checkout {branch} && \
      make clean && make -j LLAMA_CURL=1 main && \
      rm -f issue4218.bin && \
      ./main ${COMMON_ARGS[*]} -n 1" \
    "BRANCH={branch} \
      ./main ${COMMON_ARGS[*]} -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt" )
show output
Benchmark 1: BRANCH=98f33bae767dd19e213ef663b22ad99979ca71d7^       ./main -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf --prompt-cache issue4218.bin --grammar-file issue4218.gbnf -f issue4218.txt -c 3400 -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt
  Time (mean ± σ):      7.970 s ±  0.060 s    [User: 3.829 s, System: 0.292 s]
  Range (min … max):    7.877 s …  8.025 s    5 runs
 
Benchmark 2: BRANCH=98f33bae767dd19e213ef663b22ad99979ca71d7       ./main -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf --prompt-cache issue4218.bin --grammar-file issue4218.gbnf -f issue4218.txt -c 3400 -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt
  Time (mean ± σ):      5.814 s ±  0.037 s    [User: 1.674 s, System: 0.277 s]
  Range (min … max):    5.758 s …  5.857 s    5 runs
 
Summary
  'BRANCH=98f33bae767dd19e213ef663b22ad99979ca71d7       ./main -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf --prompt-cache issue4218.bin --grammar-file issue4218.gbnf -f issue4218.txt -c 3400 -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt' ran
    1.37 ± 0.01 times faster than 'BRANCH=98f33bae767dd19e213ef663b22ad99979ca71d7^       ./main -mu https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/resolve/main/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf -m models/Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf --prompt-cache issue4218.bin --grammar-file issue4218.gbnf -f issue4218.txt -c 3400 -n 128 --prompt-cache-ro --seed 12345 --no-display-prompt'

It doesn't help w/ all grammars, though.

}

const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;

// update top of stack to next element, if any
Expand Down Expand Up @@ -13615,21 +13624,26 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
}
}

// Store decoded codepoints when they are not cached (happens when there's a partial utf8 string prefix).
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
candidates_decoded.reserve(candidates->size);
if (grammar->partial_utf8.n_remain > 0) {
candidates_decoded.reserve(candidates->size);
}
std::vector<llama_grammar_candidate> candidates_grammar;
candidates_grammar.reserve(candidates->size);

for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const std::string piece = llama_token_to_piece(ctx, id, false);

const auto & piece = ctx->token_pieces[id];
if (llama_token_is_eog(&ctx->model, id)) {
if (!allow_eog) {
candidates->data[i].logit = -INFINITY;
}
} else if (piece.empty() || piece[0] == 0) {
candidates->data[i].logit = -INFINITY;
} else if (grammar->partial_utf8.n_remain == 0){
const auto & decoded = ctx->token_codepoints_without_partial_utf8_prefix.at(id);
candidates_grammar.push_back({ i, decoded.first.data(), decoded.second });
} else {
candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
Expand Down Expand Up @@ -13826,10 +13840,12 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
GGML_ASSERT(false);
}

const std::string piece = llama_token_to_piece(ctx, token, false);
const auto & piece = ctx->token_pieces.at(token);

// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
const auto decoded = grammar->partial_utf8.n_remain == 0
? ctx->token_codepoints_without_partial_utf8_prefix[token]
: decode_utf8(piece, grammar->partial_utf8);
const auto & code_points = decoded.first;
std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
Expand Down Expand Up @@ -15714,6 +15730,18 @@ struct llama_context * llama_new_context_with_model(
}
}

// cache tokens & their decoded codepoints (for common case where there's no partial utf8 prefix bytes) for grammar-constrained sampling.
{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only reservation that I have is that we're caching these data structures whether there is a grammar or not. These data structures are only used in llama_sample_grammar and llama_grammar_accept_token, so if there is no grammar, then this is wasted time (and perhaps more importantly, memory).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair point! A “grammars_enabled” context param might make sense, with a flag to turn it off in server (+ explode when grammar requested), and turned on in main only when grammar or schema flag set. Would save a couple of MBs and a tiny bit of preprocessing. Will add tonight 👌

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could also perhaps restructure this cache building that happens as a memorization step that happens inside of the first calls to decode_utf8 / llama_token_to_piece, rather than when the context is built.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That’s what I initially did (and was storing it in the grammar itself, which was pointed out as a not ideal by @ejones ). Lazy init comes with big concurrency concerns (say two slots start working with grammars at the same time in the server: can’t let them both populate the data lazily if it’s in the context)

Copy link
Collaborator Author

@ochafik ochafik May 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Oh but then, no sure anymore whether context is slot-specific? I’ll look again tonight 😅)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Oh but then, no sure anymore whether context is slot-specific? I’ll look again tonight 😅)

Yeah context is shared. And as @ejones pointed out grammar itself technically could be shared in other contexts.

I played with the idea of allowing to disable the preprocessing (here) but I haven't found a good way to articulate the flag wrt/ all the ways the API can be used. I think it'd be simpler to leave it as is and keep it as an area where to potentially squeeze a couple of MB when times are scarce. wdyt?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies in advance for this comment -- this is very long, and very stream-of-consciousness, so please take everything with a grain of salt. However, I want to keep moving forward on this PR, so I figure at least some reply is better than nothing -- even if it's a bit rambly. :)

I played with the idea of allowing to disable the preprocessing (here) but I haven't found a good way to articulate the flag wrt/ all the ways the API can be used.

Yeah, that's a tricky question. I wonder if instead of a command-line flag, if the precache step at the end of llama_new_context_with_model should maybe be broken out into a separate call -- something like llama_context_precache_grammar_token_pieces? Then, whoever is initializing the context can also add precaching to it...? calling the API can enable this pre-caching feature if they want to or not, but it takes an extra call...?

So I'm back to wondering if there is a way where we can still break this precache functionality into a separate function. Can we check for the need of it in llama_sample_grammar / llama_grammar_accept_token and call the precaching there as-needed?

if (ctx->token_pieces.empty()) {
    llama_context_precache_grammar_token_pieces(ctx);
}
const auto & piece = ctx->token_pieces.at(token);
...

It's annoying to run that check on every token, but the performance improvements when using a grammar should still win out. And we also remove the extra memory usage in the case of no grammar being used.

But as I'm thinking through this now, I think what you said earlier is the problem of ctx being shared across multiple threads / processes, so at the very least, we would need to add a blocking mutex here. Is that feasible / reasonable? We do a number of other little mutexes scattered throughout the code in other places, so it feels like it wouldn't be awful here...?

So yeah, I'm still on the fence about making it as a separate call that needs to be called manually if one wants the performance boost (maybe with an automatic fallback mechanism in the sampler), or else transparent caching on first call at runtime with a mutex (provided we can implement this without it being too annoying to run this check inside such a tight loop).

I think it'd be simpler to leave it as is and keep it as an area where to potentially squeeze a couple of MB when times are scarce. wdyt?

This also sounds not unreasonable, but I don't know how to weigh such things. I know I really like grammar-constrained sampling, but I don't know how popular the feature is overall, and is it worth negatively impacting hyper-resource-constrained usages (such as Raspberry Pis or whatnot) vs. grammars? That's what I'm unable to weigh -- I feel like that's a strategic decision that's a bit above my level.

Copy link

@gabriel-peracio gabriel-peracio May 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know how popular the feature is overall

User here: I've been finding the grammar feature extremely useful and can't live without it. There are also no substitutes anywhere, the competitor's solutions are worse than llama.cpp's implementation IMO

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@HanClinto thanks++ for your reply & sorry it took me so long to get back to it, the Spring has been full of distractions 😅

I had somehow written concurrent synchronization off but as you mention it, seems worth exploring, looking now!

@gabriel-peracio same vibe here, realized I couldn't live without it and then that I couldn't stand how slow it was when pushed to its limits haha! It's nearly there :-)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@HanClinto I've moved the caching back to llama_sample_grammar in lazy & mutex-protected form (thanks for the suggestion!). And sent the early exit change separately -> #7370

auto n_vocab = llama_n_vocab(llama_get_model(ctx));
ctx->token_codepoints_without_partial_utf8_prefix.resize(n_vocab);
ctx->token_pieces.resize(n_vocab);
for (llama_token id = 0; id < n_vocab; ++id) {
const std::string piece = llama_token_to_piece(ctx, id, false);
ctx->token_pieces[id] = piece;
ctx->token_codepoints_without_partial_utf8_prefix[id] = decode_utf8(piece, {0, 0});
}
}

#ifdef GGML_USE_MPI
ctx->ctx_mpi = ggml_mpi_init();

Expand Down