From 291a7855873c89be709868fbae1d076bc37bfde3 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 22 Oct 2024 15:03:00 +0200 Subject: [PATCH 1/2] llama : rename batch.logits to batch.output This commit renames the `logits` field of the `llama_batch` struct to `output`. The motivation for this change (apart from the TODO comment) is that the `logits` field is actually used to specify that output should be generated. For example, in the case of generating embeddings, setting logits to true can be confusing since the logits are not used when generating embeddings. --- common/common.cpp | 6 +++--- examples/batched-bench/batched-bench.cpp | 4 ++-- examples/batched.swift/Sources/main.swift | 6 +++--- examples/batched/batched.cpp | 2 +- examples/embedding/embedding.cpp | 2 +- .../llama/src/main/cpp/llama-android.cpp | 6 +++--- .../llama.cpp.swift/LibLlama.swift | 8 ++++---- examples/llava/llava.cpp | 8 ++++---- examples/parallel/parallel.cpp | 4 ++-- examples/passkey/passkey.cpp | 4 ++-- examples/perplexity/perplexity.cpp | 14 +++++++------- examples/retrieval/retrieval.cpp | 2 +- examples/save-load-state/save-load-state.cpp | 2 +- examples/server/server.cpp | 8 ++++---- examples/tts/tts.cpp | 2 +- include/llama.h | 2 +- src/llama-batch.cpp | 18 +++++++++--------- src/llama-batch.h | 2 +- src/llama.cpp | 5 ++--- 19 files changed, 52 insertions(+), 53 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 8661e164ada6b..859e726afb1b8 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -607,7 +607,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat << ", pos " << std::to_string(batch.pos[i]) << ", n_seq_id " << std::to_string(batch.n_seq_id[i]) << ", seq_id " << std::to_string(batch.seq_id[i][0]) - << ", logits " << std::to_string(batch.logits[i]); + << ", output " << std::to_string(batch.output[i]); } buf << " ]"; @@ -1617,7 +1617,7 @@ void common_batch_add( llama_token id, llama_pos pos, const std::vector & seq_ids, - bool logits) { + bool output) { GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); batch.token [batch.n_tokens] = id; @@ -1626,7 +1626,7 @@ void common_batch_add( for (size_t i = 0; i < seq_ids.size(); ++i) { batch.seq_id[batch.n_tokens][i] = seq_ids[i]; } - batch.logits [batch.n_tokens] = logits; + batch.output [batch.n_tokens] = output; batch.n_tokens++; } diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 0659ab6f119a7..1f1c956274f57 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -73,7 +73,7 @@ int main(int argc, char ** argv) { batch.pos + i, batch.n_seq_id + i, batch.seq_id + i, - batch.logits + i, + batch.output + i, }; const int ret = llama_decode(ctx, batch_view); @@ -128,7 +128,7 @@ int main(int argc, char ** argv) { common_batch_add(batch, 0, i, { j }, false); } } - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; const auto t_pp_start = ggml_time_us(); diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 55c31166ca278..18b6a21d8ca49 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -104,11 +104,11 @@ for (i, token) in tokens.enumerated() { if let seq_id = batch.seq_id[i] { seq_id[0] = 0 } - batch.logits[i] = 0 + batch.output[i] = 0 } // llama_decode will output logits only for the last token of the prompt -batch.logits[Int(batch.n_tokens) - 1] = 1 +batch.output[Int(batch.n_tokens) - 1] = 1 if llama_decode(context, batch) != 0 { print("llama_decode() failed") @@ -171,7 +171,7 @@ while n_cur <= n_len { if let seq_id = batch.seq_id[Int(batch.n_tokens)] { seq_id[0] = Int32(i) } - batch.logits[Int(batch.n_tokens)] = 1 + batch.output[Int(batch.n_tokens)] = 1 i_batch[i] = batch.n_tokens diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 21b95ef5e4e83..7d2a82b518099 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -131,7 +131,7 @@ int main(int argc, char ** argv) { } // llama_decode will output logits only for the last token of the prompt - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; if (llama_decode(ctx, batch) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 38d22c90f82bb..95445b5ef68d3 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -54,7 +54,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } for (int i = 0; i < batch.n_tokens; i++) { - if (!batch.logits[i]) { + if (!batch.output[i]) { continue; } diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 2a73983a9832f..1718d6b4f525d 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( common_batch_add(*batch, 0, i, { 0 }, false); } - batch->logits[batch->n_tokens - 1] = true; + batch->output[batch->n_tokens - 1] = true; llama_kv_cache_clear(context); const auto t_pp_start = ggml_time_us(); @@ -297,7 +297,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, for (int i = 0; i < n_tokens; ++i) { batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); } - batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens); return reinterpret_cast(batch); } @@ -381,7 +381,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init( } // llama_decode will output logits only for the last token of the prompt - batch->logits[batch->n_tokens - 1] = true; + batch->output[batch->n_tokens - 1] = true; if (llama_decode(context, *batch) != 0) { LOGe("llama_decode() failed"); diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index ee7141a663224..dfece7761e9e1 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -9,14 +9,14 @@ func llama_batch_clear(_ batch: inout llama_batch) { batch.n_tokens = 0 } -func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) { +func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ outputs: Bool) { batch.token [Int(batch.n_tokens)] = id batch.pos [Int(batch.n_tokens)] = pos batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count) for i in 0.. n_seq_id; std::vector seq_id_0; std::vector seq_ids; - std::vector logits; + std::vector outputs; llama_batch batch; llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { pos .resize(n_tokens); n_seq_id.resize(n_tokens); seq_ids .resize(n_tokens + 1); - logits .resize(n_tokens); + outputs .resize(n_tokens); seq_id_0.resize(1); seq_id_0[0] = seq_id; seq_ids [n_tokens] = nullptr; @@ -458,13 +458,13 @@ struct llava_embd_batch { /*pos =*/ pos.data(), /*n_seq_id =*/ n_seq_id.data(), /*seq_id =*/ seq_ids.data(), - /*logits =*/ logits.data(), + /*output =*/ outputs.data(), }; for (int i = 0; i < n_tokens; i++) { batch.pos [i] = pos_0 + i; batch.n_seq_id[i] = 1; batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; + batch.output [i] = false; } } }; diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 7ef43d5e12876..3f87c0a1aa53e 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -266,7 +266,7 @@ int main(int argc, char ** argv) { // extract the logits only for the last token if (batch.n_tokens > 0) { - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; } client.n_prompt = tokens_prompt.size(); @@ -309,7 +309,7 @@ int main(int argc, char ** argv) { batch.pos + i, batch.n_seq_id + i, batch.seq_id + i, - batch.logits + i, + batch.output + i, }; const int ret = llama_decode(ctx, batch_view); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 5953928d47d33..15f99bcdd9087 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -146,7 +146,7 @@ int main(int argc, char ** argv) { } if (i + n_batch >= n_tokens_all) { - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; } if (llama_decode(ctx, batch) != 0) { @@ -180,7 +180,7 @@ int main(int argc, char ** argv) { } if (i + n_batch >= n_tokens_all) { - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; } if (llama_decode(ctx, batch) != 0) { diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 9bf6c57433ab2..2b194b8d9bc74 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -572,9 +572,9 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & batch.pos [idx] = j*n_batch + k; batch.n_seq_id[idx] = 1; batch.seq_id [idx][0] = seq; - batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0; + batch.output [idx] = batch.pos[idx] >= first ? 1 : 0; - n_outputs += batch.logits[idx] != 0; + n_outputs += batch.output[idx] != 0; } batch.n_tokens += batch_size; @@ -669,7 +669,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector< batch.pos + i, batch.n_seq_id + i, batch.seq_id + i, - batch.logits + i, + batch.output + i, }; const int ret = llama_decode(ctx, batch_view); @@ -680,7 +680,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector< int n_outputs = 0; for (int i = 0; i < n_tokens; ++i) { - n_outputs += batch_view.logits[i] != 0; + n_outputs += batch_view.output[i] != 0; } memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float)); @@ -896,7 +896,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { for (size_t i = 0; i < hs_cur.common_prefix; ++i) { common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); } - batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix + batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix n_logits += 1; for (int s = 0; s < 4; ++s) { @@ -1177,7 +1177,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) for (size_t i = 0; i < data[i1].common_prefix; ++i) { common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); } - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; n_logits += 1; for (int s = 0; s < 2; ++s) { @@ -1545,7 +1545,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false); } - batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix + batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix n_logits += 1; for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 2439022a229b7..2c5b5e4862228 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -92,7 +92,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } for (int i = 0; i < batch.n_tokens; i++) { - if (!batch.logits[i]) { + if (!batch.output[i]) { continue; } diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index cf7cbd8159cf8..2e5a2b5181eff 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -52,7 +52,7 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < tokens.size(); i++) { common_batch_add(batch, tokens[i], i, {0}, false); } - batch.logits[batch.n_tokens - 1] = true; // generate next token + batch.output[batch.n_tokens - 1] = true; // generate next token // evaluate prompt llama_decode(ctx, batch); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 9cdf2058fd037..f6642e5c820da 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2413,7 +2413,7 @@ struct server_context { std::vector embd_res(n_embd, 0.0f); for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + if (!batch.output[i] || batch.seq_id[i][0] != slot.id) { continue; } @@ -2451,7 +2451,7 @@ struct server_context { res->n_tokens = slot.n_prompt_tokens; for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + if (!batch.output[i] || batch.seq_id[i][0] != slot.id) { continue; } @@ -3109,7 +3109,7 @@ struct server_context { } // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; slot.n_decoded = 0; slot.i_batch = batch.n_tokens - 1; @@ -3149,7 +3149,7 @@ struct server_context { batch.pos + i, batch.n_seq_id + i, batch.seq_id + i, - batch.logits + i, + batch.output + i, }; const int ret = llama_decode(ctx, batch_view); diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index f78f763033a23..f700229853834 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -722,7 +722,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size()); // llama_decode will output logits only for the last token of the prompt - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; if (llama_decode(ctx_ttc, batch) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); diff --git a/include/llama.h b/include/llama.h index 61907ed404dbf..516953a729fb2 100644 --- a/include/llama.h +++ b/include/llama.h @@ -252,7 +252,7 @@ extern "C" { llama_pos * pos; int32_t * n_seq_id; llama_seq_id ** seq_id; - int8_t * logits; // TODO: rename this to "output" + int8_t * output; } llama_batch; enum llama_model_kv_override_type { diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 01d5ca57fd82b..ba2127be66b6e 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -102,17 +102,17 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s ubatch.output[ubatch.n_tokens + i] = 1; out_ids.push_back(ids[seq.offset + i]); } - } else if (batch->logits) { + } else if (batch->output) { if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { size_t id = ids[seq.offset + i]; - int8_t is_output = batch->logits[id]; + int8_t is_output = batch->output[id]; ubatch.output[ubatch.n_tokens + i] = is_output; if (is_output) { out_ids.push_back(id); } } } else { // simple split - ubatch.output = batch->logits + seq.offset; + ubatch.output = batch->output + seq.offset; for (size_t i = 0; i < length; ++i) { if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); } } @@ -298,10 +298,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0 } batch.seq_id = seq_id.data(); } - if (!batch.logits) { - logits.resize(batch.n_tokens); - logits[logits.size() - 1] = true; - batch.logits = logits.data(); + if (!batch.output) { + outputs.resize(batch.n_tokens); + outputs[outputs.size() - 1] = true; + batch.output = outputs.data(); } } @@ -348,7 +348,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ } batch.seq_id[n_tokens_alloc] = nullptr; - batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); + batch.output = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); return batch; } @@ -364,5 +364,5 @@ void llama_batch_free(struct llama_batch batch) { } free(batch.seq_id); } - if (batch.logits) free(batch.logits); + if (batch.output) free(batch.output); } diff --git a/src/llama-batch.h b/src/llama-batch.h index 773c3808b770f..002a8a62f844a 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -81,7 +81,7 @@ struct llama_batch_allocr { std::vector pos; std::vector n_seq_id; std::vector seq_id; - std::vector logits; + std::vector outputs; // optionally fulfill the batch returned by llama_batch_get_one llama_batch_allocr(struct llama_batch in_batch, llama_pos p0); diff --git a/src/llama.cpp b/src/llama.cpp index aae3c69b5a653..e24c39465c41c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8473,9 +8473,9 @@ static int llama_prepare_sbatch( lctx.embd_seq.clear(); // count outputs - if (batch.logits && !embd_pooled) { + if (batch.output && !embd_pooled) { for (uint32_t i = 0; i < n_tokens_all; ++i) { - n_outputs += batch.logits[i] != 0; + n_outputs += batch.output[i] != 0; } } else if (lctx.logits_all || embd_pooled) { n_outputs = n_tokens_all; @@ -9972,7 +9972,6 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) { return llama_kv_cache_can_shift(ctx->kv_self); } -/// int32_t llama_encode( struct llama_context * ctx, From 27f59dbaaa33bb0941a533ac4fd257e0cc9c564b Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 6 Feb 2025 08:00:30 +0100 Subject: [PATCH 2/2] squash! llama : rename batch.logits to batch.output Fix incorrectly named field in LibLlama.swift, outputs -> output. --- examples/llama.swiftui/llama.cpp.swift/LibLlama.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index dfece7761e9e1..7b4a55f2fc9da 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -16,7 +16,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama for i in 0..