From a4556f0f2d4200fb213f2d5b27d41f6abc9a1d36 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Mon, 20 May 2024 22:22:29 +0100
Subject: [PATCH 1/2] grammars: fix resampling logic

---
 common/sampling.cpp    | 14 ++++++++------
 examples/main/main.cpp |  4 ++--
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index f0f1b92d37f59..5583c52588688 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -179,7 +179,7 @@ static llama_token llama_sampling_sample_impl(
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
                   const int idx,
-                  bool is_resampling) { // Add a parameter to indicate if we are resampling
+                  bool is_resampling) {
     const llama_sampling_params & params = ctx_sampling->params;
 
     const float   temp            = params.temp;
@@ -188,8 +188,8 @@ static llama_token llama_sampling_sample_impl(
     const float   mirostat_eta    = params.mirostat_eta;
 
     std::vector<float> original_logits;
-    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
-    if (!is_resampling) {
+    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
+    if (ctx_sampling->grammar != NULL && !is_resampling) {
         GGML_ASSERT(!original_logits.empty());
     }
     llama_token id = 0;
@@ -252,7 +252,7 @@ static llama_token llama_sampling_sample_impl(
             // Restore logits from the copy
             std::copy(original_logits.begin(), original_logits.end(), logits);
 
-            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
+            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
         }
     }
 
@@ -285,8 +285,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    if (apply_grammar && original_logits != NULL) {
+    if (ctx_sampling->grammar != NULL && !apply_grammar) {
+        GGML_ASSERT(original_logits != NULL);
         // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
+        // TODO: if idx >= 0 then use ctx->output_ids.size() as upper bound?
         *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
     }
@@ -342,7 +344,7 @@ llama_token llama_sampling_sample(
                   struct llama_context * ctx_cfg,
                   const int idx) {
     // Call the implementation function with is_resampling set to false by default
-    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
+    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
 }
 
 llama_token_data_array llama_sampling_prepare(
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 9dee41001f12c..832b51ee086be 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -707,7 +707,7 @@ int main(int argc, char ** argv) {
 
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
 
-            llama_sampling_accept(ctx_sampling, ctx, id, true);
+            llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
 
             LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
 
@@ -728,7 +728,7 @@ int main(int argc, char ** argv) {
 
                 // push the prompt in the sampling context in order to apply repetition penalties later
                 // for the prompt, we don't apply grammar rules
-                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
+                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false);
 
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {

From fbe6bc5c99ce52a552c2269633ab449ded468fb0 Mon Sep 17 00:00:00 2001
From: Olivier Chafik
Date: Tue, 21 May 2024 15:27:44 +0100
Subject: [PATCH 2/2] grammars: remove todo

---
 common/sampling.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 5583c52588688..7fc2e2158d5c4 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -288,7 +288,6 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     if (ctx_sampling->grammar != NULL && !apply_grammar) {
         GGML_ASSERT(original_logits != NULL);
         // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
-        // TODO: if idx >= 0 then use ctx->output_ids.size() as upper bound?
         *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
     }
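
Reviewer note: below is a minimal, self-contained sketch of the control flow
these two patches fix. All names here are hypothetical stand-ins (sample_from,
grammar_accepts and apply_grammar_penalties do not exist in the codebase and
replace the real sampling chain and grammar API); only the shape of the
resampling logic is meant to match the patch.

    #include <cstdio>
    #include <vector>

    using token = int;

    // Stand-in for the full sampling chain (temperature, top-k, top-p, ...).
    static token sample_from(const std::vector<float> & logits) {
        token best = 0;
        for (size_t i = 1; i < logits.size(); i++) {
            if (logits[i] > logits[best]) best = (token) i;
        }
        return best;
    }

    // Stand-in for checking a candidate token against the grammar.
    static bool grammar_accepts(token id) {
        return id != 2; // pretend token 2 violates the grammar
    }

    // Stand-in for llama_sample_grammar-style masking of rejected tokens.
    static void apply_grammar_penalties(std::vector<float> & logits) {
        logits[2] = -1e9f;
    }

    static token sample(std::vector<float> & logits, bool is_resampling) {
        // First pass: keep a copy of the original logits, because the chain
        // (and the grammar penalties on the resample) mutates them in place.
        std::vector<float> original_logits;
        if (!is_resampling) {
            original_logits = logits;
        } else {
            apply_grammar_penalties(logits);
        }

        token id = sample_from(logits);

        // If the unconstrained sample violates the grammar, restore the
        // saved logits and resample once with the constraints applied.
        if (!is_resampling && !grammar_accepts(id)) {
            logits = original_logits;
            return sample(logits, /* is_resampling= */ true);
        }
        return id;
    }

    int main() {
        std::vector<float> logits = {0.1f, 0.5f, 2.0f, 0.3f}; // argmax is 2, which the grammar rejects
        printf("sampled token: %d\n", sample(logits, /* is_resampling= */ false)); // prints 1
    }

The patch follows the same shape: the original logits are copied only when a
grammar is active and this is the first, unconstrained pass (grammar != NULL
&& !apply_grammar), while grammar constraints are applied only on the second
pass (/* apply_grammar= */ is_resampling) — the inversion that the old
!is_resampling argument got backwards.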