Skip to content

Commit 74edb42

Browse files
committed
llama : rename batch.logits to batch.output
This commit renames the `logits` field of the `llama_batch` struct to `output`. The motivation for this change (apart from the existing TODO comment) is that the `logits` field is actually used to specify whether output should be generated for a token at all. For example, when generating embeddings, having to set `logits` to true is confusing, since no logits are involved in that case.
1 parent 873279b commit 74edb42

File tree

16 files changed

+49
-49
lines changed

16 files changed

+49
-49
lines changed

common/common.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
554554
<< ":pos " << std::to_string(batch.pos[i])
555555
<< ":n_seq_id " << std::to_string(batch.n_seq_id[i])
556556
<< ":seq_id " << std::to_string(batch.seq_id[i][0])
557-
<< ":logits " << std::to_string(batch.logits[i]);
557+
<< ":output " << std::to_string(batch.output[i]);
558558
}
559559

560560
buf << " ]";
@@ -1480,7 +1480,7 @@ void common_batch_add(
14801480
llama_token id,
14811481
llama_pos pos,
14821482
const std::vector<llama_seq_id> & seq_ids,
1483-
bool logits) {
1483+
bool output) {
14841484
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
14851485

14861486
batch.token [batch.n_tokens] = id;
@@ -1489,7 +1489,7 @@ void common_batch_add(
14891489
for (size_t i = 0; i < seq_ids.size(); ++i) {
14901490
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
14911491
}
1492-
batch.logits [batch.n_tokens] = logits;
1492+
batch.output [batch.n_tokens] = output;
14931493

14941494
batch.n_tokens++;
14951495
}

examples/batched-bench/batched-bench.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ int main(int argc, char ** argv) {
7373
batch.pos + i,
7474
batch.n_seq_id + i,
7575
batch.seq_id + i,
76-
batch.logits + i,
76+
batch.output + i,
7777
};
7878

7979
const int ret = llama_decode(ctx, batch_view);
@@ -128,7 +128,7 @@ int main(int argc, char ** argv) {
128128
common_batch_add(batch, 0, i, { j }, false);
129129
}
130130
}
131-
batch.logits[batch.n_tokens - 1] = true;
131+
batch.output[batch.n_tokens - 1] = true;
132132

133133
const auto t_pp_start = ggml_time_us();
134134

examples/batched.swift/Sources/main.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,11 @@ for (i, token) in tokens.enumerated() {
9999
if let seq_id = batch.seq_id[i] {
100100
seq_id[0] = 0
101101
}
102-
batch.logits[i] = 0
102+
batch.output[i] = 0
103103
}
104104

105105
// llama_decode will output logits only for the last token of the prompt
106-
batch.logits[Int(batch.n_tokens) - 1] = 1
106+
batch.output[Int(batch.n_tokens) - 1] = 1
107107

108108
if llama_decode(context, batch) != 0 {
109109
print("llama_decode() failed")
@@ -166,7 +166,7 @@ while n_cur <= n_len {
166166
if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
167167
seq_id[0] = Int32(i)
168168
}
169-
batch.logits[Int(batch.n_tokens)] = 1
169+
batch.output[Int(batch.n_tokens)] = 1
170170

171171
i_batch[i] = batch.n_tokens
172172

examples/batched/batched.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ int main(int argc, char ** argv) {
128128
}
129129

130130
// llama_decode will output logits only for the last token of the prompt
131-
batch.logits[batch.n_tokens - 1] = true;
131+
batch.output[batch.n_tokens - 1] = true;
132132

133133
if (llama_decode(ctx, batch) != 0) {
134134
LOG_ERR("%s: llama_decode() failed\n", __func__);

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
5454
}
5555

5656
for (int i = 0; i < batch.n_tokens; i++) {
57-
if (!batch.logits[i]) {
57+
if (!batch.output[i]) {
5858
continue;
5959
}
6060

examples/llama.android/llama/src/main/cpp/llama-android.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
193193
common_batch_add(*batch, 0, i, { 0 }, false);
194194
}
195195

196-
batch->logits[batch->n_tokens - 1] = true;
196+
batch->output[batch->n_tokens - 1] = true;
197197
llama_kv_cache_clear(context);
198198

199199
const auto t_pp_start = ggml_time_us();
@@ -297,7 +297,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
297297
for (int i = 0; i < n_tokens; ++i) {
298298
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
299299
}
300-
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
300+
batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
301301

302302
return reinterpret_cast<jlong>(batch);
303303
}
@@ -377,7 +377,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
377377
}
378378

379379
// llama_decode will output logits only for the last token of the prompt
380-
batch->logits[batch->n_tokens - 1] = true;
380+
batch->output[batch->n_tokens - 1] = true;
381381

382382
if (llama_decode(context, *batch) != 0) {
383383
LOGe("llama_decode() failed");

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ actor LlamaContext {
137137
let i = Int(i1)
138138
llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
139139
}
140-
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
140+
batch.output[Int(batch.n_tokens) - 1] = 1 // true
141141

142142
if llama_decode(context, batch) != 0 {
143143
print("llama_decode() failed")
@@ -206,7 +206,7 @@ actor LlamaContext {
206206
for i in 0..<n_tokens {
207207
llama_batch_add(&batch, 0, Int32(i), [0], false)
208208
}
209-
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
209+
batch.output[Int(batch.n_tokens) - 1] = 1 // true
210210

211211
llama_kv_cache_clear(context)
212212

examples/llava/llava.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -406,13 +406,13 @@ struct llava_embd_batch {
406406
std::vector<int32_t> n_seq_id;
407407
std::vector<llama_seq_id> seq_id_0;
408408
std::vector<llama_seq_id *> seq_ids;
409-
std::vector<int8_t> logits;
409+
std::vector<int8_t> outputs;
410410
llama_batch batch;
411411
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
412412
pos .resize(n_tokens);
413413
n_seq_id.resize(n_tokens);
414414
seq_ids .resize(n_tokens + 1);
415-
logits .resize(n_tokens);
415+
outputs .resize(n_tokens);
416416
seq_id_0.resize(1);
417417
seq_id_0[0] = seq_id;
418418
seq_ids [n_tokens] = nullptr;
@@ -423,13 +423,13 @@ struct llava_embd_batch {
423423
/*pos =*/ pos.data(),
424424
/*n_seq_id =*/ n_seq_id.data(),
425425
/*seq_id =*/ seq_ids.data(),
426-
/*logits =*/ logits.data(),
426+
/*output =*/ outputs.data(),
427427
};
428428
for (int i = 0; i < n_tokens; i++) {
429429
batch.pos [i] = pos_0 + i;
430430
batch.n_seq_id[i] = 1;
431431
batch.seq_id [i] = seq_id_0.data();
432-
batch.logits [i] = false;
432+
batch.output [i] = false;
433433
}
434434
}
435435
};

examples/parallel/parallel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ int main(int argc, char ** argv) {
264264

265265
// extract the logits only for the last token
266266
if (batch.n_tokens > 0) {
267-
batch.logits[batch.n_tokens - 1] = true;
267+
batch.output[batch.n_tokens - 1] = true;
268268
}
269269

270270
client.n_prompt = tokens_prompt.size();
@@ -307,7 +307,7 @@ int main(int argc, char ** argv) {
307307
batch.pos + i,
308308
batch.n_seq_id + i,
309309
batch.seq_id + i,
310-
batch.logits + i,
310+
batch.output + i,
311311
};
312312

313313
const int ret = llama_decode(ctx, batch_view);

examples/passkey/passkey.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ int main(int argc, char ** argv) {
144144
}
145145

146146
if (i + n_batch >= n_tokens_all) {
147-
batch.logits[batch.n_tokens - 1] = true;
147+
batch.output[batch.n_tokens - 1] = true;
148148
}
149149

150150
if (llama_decode(ctx, batch) != 0) {
@@ -178,7 +178,7 @@ int main(int argc, char ** argv) {
178178
}
179179

180180
if (i + n_batch >= n_tokens_all) {
181-
batch.logits[batch.n_tokens - 1] = true;
181+
batch.output[batch.n_tokens - 1] = true;
182182
}
183183

184184
if (llama_decode(ctx, batch) != 0) {

0 commit comments

Comments (0)