Commit bbbcaae
llama : rename batch.logits to batch.output
This commit renames the `logits` field of the `llama_batch` struct to `output`. The motivation for the change (apart from the existing TODO comment) is that the `logits` field is actually used to specify that output should be generated for a token. For example, when generating embeddings, having to set `logits` to true is confusing, since no logits are involved in producing embeddings.
1 parent: c421ac0
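For callers, the rename is mechanical but clarifies intent. Below is a minimal sketch of prompt evaluation after this commit, assembled from the hunks that follow (ctx, batch, and the tokens vector are assumed to already exist; compare examples/batched and examples/save-load-state):

    // queue the prompt tokens on sequence 0; none of them need output
    for (size_t i = 0; i < tokens.size(); i++) {
        common_batch_add(batch, tokens[i], i, { 0 }, false);
    }

    // request output for the last prompt token only; before this commit the
    // flag was called `logits`, which read oddly when the decode produced
    // embeddings rather than logits
    batch.output[batch.n_tokens - 1] = true;

    if (llama_decode(ctx, batch) != 0) {
        LOG_ERR("%s: llama_decode() failed\n", __func__);
    }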

File tree: 13 files changed, +41 −41 lines

common/common.cpp
Lines changed: 3 additions & 3 deletions

@@ -554,7 +554,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
             << ":pos " << std::to_string(batch.pos[i])
             << ":n_seq_id " << std::to_string(batch.n_seq_id[i])
             << ":seq_id " << std::to_string(batch.seq_id[i][0])
-            << ":logits " << std::to_string(batch.logits[i]);
+            << ":output " << std::to_string(batch.output[i]);
     }

     buf << " ]";

@@ -1480,7 +1480,7 @@ void common_batch_add(
                        llama_token   id,
                          llama_pos   pos,
    const std::vector<llama_seq_id> & seq_ids,
-                               bool   logits) {
+                               bool   output) {
     GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");

     batch.token [batch.n_tokens] = id;

@@ -1489,7 +1489,7 @@ void common_batch_add(
     for (size_t i = 0; i < seq_ids.size(); ++i) {
         batch.seq_id[batch.n_tokens][i] = seq_ids[i];
     }
-    batch.logits [batch.n_tokens] = logits;
+    batch.output [batch.n_tokens] = output;

     batch.n_tokens++;
 }

examples/batched-bench/batched-bench.cpp
Lines changed: 2 additions & 2 deletions

@@ -73,7 +73,7 @@ int main(int argc, char ** argv) {
             batch.pos      + i,
             batch.n_seq_id + i,
             batch.seq_id   + i,
-            batch.logits   + i,
+            batch.output   + i,
         };

         const int ret = llama_decode(ctx, batch_view);

@@ -128,7 +128,7 @@ int main(int argc, char ** argv) {
                 common_batch_add(batch, 0, i, { j }, false);
             }
         }
-        batch.logits[batch.n_tokens - 1] = true;
+        batch.output[batch.n_tokens - 1] = true;

         const auto t_pp_start = ggml_time_us();
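A recurring pattern in these examples is decoding a large batch in windows by offsetting each array pointer, including the renamed output array. Here is a condensed sketch of that pattern, with a hypothetical helper name and the field order taken from the llava_embd_batch initializer further down (embd is null for a token batch):

    #include <algorithm>
    #include "llama.h"

    // decode `batch` in windows of at most n_batch tokens
    static int decode_in_chunks(llama_context * ctx, llama_batch & batch, int32_t n_batch) {
        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

            llama_batch batch_view = {
                n_tokens,
                batch.token    + i,
                nullptr,            // embd: token batch, no input embeddings
                batch.pos      + i,
                batch.n_seq_id + i,
                batch.seq_id   + i,
                batch.output   + i, // formerly batch.logits
            };

            const int ret = llama_decode(ctx, batch_view);
            if (ret != 0) {
                return ret; // decode failed for this window
            }
        }
        return 0;
    }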

examples/batched/batched.cpp
Lines changed: 1 addition & 1 deletion

@@ -128,7 +128,7 @@ int main(int argc, char ** argv) {
     }

     // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    batch.output[batch.n_tokens - 1] = true;

     if (llama_decode(ctx, batch) != 0) {
         LOG_ERR("%s: llama_decode() failed\n", __func__);

examples/embedding/embedding.cpp
Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }

     for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
+        if (!batch.output[i]) {
             continue;
         }
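The embedding path is where the old name was most confusing: the flag selects which tokens produce an embedding, and no logits exist at all. A simplified sketch of the loop this hunk sits in, using the public per-token accessor llama_get_embeddings_ith (destination buffer handling is elided):

    for (int i = 0; i < batch.n_tokens; i++) {
        if (!batch.output[i]) {
            continue; // no output was requested for this token
        }

        // valid only because batch.output[i] was set before llama_decode
        const float * embd = llama_get_embeddings_ith(ctx, i);
        // ... copy/normalize n_embd floats into the caller's output buffer
    }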

examples/llava/llava.cpp
Lines changed: 4 additions & 4 deletions

@@ -406,13 +406,13 @@ struct llava_embd_batch {
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id>   seq_id_0;
     std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t>         logits;
+    std::vector<int8_t>         outputs;
     llama_batch batch;
     llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
         pos     .resize(n_tokens);
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
+        outputs .resize(n_tokens);
         seq_id_0.resize(1);
         seq_id_0[0] = seq_id;
         seq_ids [n_tokens] = nullptr;

@@ -423,13 +423,13 @@ struct llava_embd_batch {
             /*pos      =*/ pos.data(),
             /*n_seq_id =*/ n_seq_id.data(),
             /*seq_id   =*/ seq_ids.data(),
-            /*logits   =*/ logits.data(),
+            /*output   =*/ outputs.data(),
         };
         for (int i = 0; i < n_tokens; i++) {
             batch.pos     [i] = pos_0 + i;
             batch.n_seq_id[i] = 1;
             batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
+            batch.output  [i] = false;
         }
     }
 };

examples/parallel/parallel.cpp
Lines changed: 2 additions & 2 deletions

@@ -264,7 +264,7 @@ int main(int argc, char ** argv) {

         // extract the logits only for the last token
         if (batch.n_tokens > 0) {
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
         }

         client.n_prompt = tokens_prompt.size();

@@ -307,7 +307,7 @@ int main(int argc, char ** argv) {
                 batch.pos      + i,
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
-                batch.logits   + i,
+                batch.output   + i,
             };

             const int ret = llama_decode(ctx, batch_view);

examples/passkey/passkey.cpp
Lines changed: 2 additions & 2 deletions

@@ -144,7 +144,7 @@ int main(int argc, char ** argv) {
         }

         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
         }

         if (llama_decode(ctx, batch) != 0) {

@@ -178,7 +178,7 @@ int main(int argc, char ** argv) {
         }

         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
         }

         if (llama_decode(ctx, batch) != 0) {

examples/perplexity/perplexity.cpp
Lines changed: 7 additions & 7 deletions

@@ -615,9 +615,9 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
             batch.pos     [idx]    = j*n_batch + k;
             batch.n_seq_id[idx]    = 1;
             batch.seq_id  [idx][0] = seq;
-            batch.logits  [idx]    = batch.pos[idx] >= first ? 1 : 0;
+            batch.output  [idx]    = batch.pos[idx] >= first ? 1 : 0;

-            n_outputs += batch.logits[idx] != 0;
+            n_outputs += batch.output[idx] != 0;
         }
         batch.n_tokens += batch_size;

@@ -712,7 +712,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
             batch.pos      + i,
             batch.n_seq_id + i,
             batch.seq_id   + i,
-            batch.logits   + i,
+            batch.output   + i,
         };

         const int ret = llama_decode(ctx, batch_view);

@@ -723,7 +723,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<

         int n_outputs = 0;
         for (int i = 0; i < n_tokens; ++i) {
-            n_outputs += batch_view.logits[i] != 0;
+            n_outputs += batch_view.output[i] != 0;
         }

         memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));

@@ -936,7 +936,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
         for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
             common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
         }
-        batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+        batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
         n_logits += 1;

         for (int s = 0; s < 4; ++s) {

@@ -1215,7 +1215,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
         for (size_t i = 0; i < data[i1].common_prefix; ++i) {
             common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
         }
-        batch.logits[batch.n_tokens - 1] = true;
+        batch.output[batch.n_tokens - 1] = true;
         n_logits += 1;

         for (int s = 0; s < 2; ++s) {

@@ -1581,7 +1581,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
             //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
             common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
         }
-        batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+        batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
         n_logits += 1;

         for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {

examples/retrieval/retrieval.cpp
Lines changed: 1 addition & 1 deletion

@@ -92,7 +92,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }

     for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
+        if (!batch.output[i]) {
             continue;
         }

examples/save-load-state/save-load-state.cpp
Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < tokens.size(); i++) {
         common_batch_add(batch, tokens[i], i, {0}, false);
     }
-    batch.logits[batch.n_tokens - 1] = true; // generate next token
+    batch.output[batch.n_tokens - 1] = true; // generate next token

     // evaluate prompt
     llama_decode(ctx, batch);
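As a usage note, the flag pairs with llama_get_logits_ith: after decoding, logits can only be read at indices whose output flag was set. A hedged sketch of the step that follows the hunk above (greedy argmax for brevity; the real example uses a sampler):

    llama_decode(ctx, batch); // output requested for the final token only

    const float * logits  = llama_get_logits_ith(ctx, batch.n_tokens - 1);
    const int     n_vocab = llama_n_vocab(llama_get_model(ctx));

    // pick the most likely next token
    llama_token next = 0;
    for (llama_token t = 1; t < n_vocab; ++t) {
        if (logits[t] > logits[next]) {
            next = t;
        }
    }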
