
Commit 75a0d52

cont : move output counting to class

ggml-ci

1 parent 99be6b7

File tree

3 files changed: +18 -6 lines changed

src/llama-batch.cpp

Lines changed: 10 additions & 0 deletions

@@ -339,14 +339,24 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         batch.logits = output.data();
     }
 
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        n_outputs += batch.logits[i] != 0;
+    }
+
     return true;
 }
 
 const llama_batch & llama_batch_allocr::get_batch() const {
     return batch;
 }
 
+uint32_t llama_batch_allocr::get_n_outputs() const {
+    return n_outputs;
+}
+
 void llama_batch_allocr::clear() {
+    n_outputs = 0;
+
     batch = {};
     pos.clear();
     n_seq_id.clear();
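
The counting added to init() above is just a tally of the non-zero entries in the per-token output flags. A minimal, self-contained sketch of that idiom, where toy_batch and the sample data are hypothetical stand-ins for llama_batch rather than llama.cpp code:

#include <cstdint>
#include <cstdio>
#include <vector>

// hypothetical stand-in for the output flags carried by llama_batch
struct toy_batch {
    int32_t n_tokens;
    std::vector<int8_t> logits; // non-zero => output (logits/embeddings) requested for this token
};

// same idiom as the loop added to llama_batch_allocr::init()
static uint32_t count_outputs(const toy_batch & batch) {
    uint32_t n_outputs = 0;
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        n_outputs += batch.logits[i] != 0;
    }
    return n_outputs;
}

int main() {
    const toy_batch batch = { 4, { 0, 0, 0, 1 } }; // typical decode: only the last token is output
    printf("n_outputs = %u\n", count_outputs(batch)); // prints: n_outputs = 1
    return 0;
}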

src/llama-batch.h

Lines changed: 5 additions & 0 deletions

@@ -87,12 +87,17 @@ class llama_batch_allocr {
 
     const llama_batch & get_batch() const;
 
+    uint32_t get_n_outputs() const;
+
 private:
     void clear();
 
     llama_batch batch;
 
+    uint32_t n_outputs;
+
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
+
     std::vector<llama_pos> pos;
     std::vector<int32_t> n_seq_id;
     std::vector<llama_seq_id *> seq_id;
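
Taken with the .cpp change above, the header follows a simple caching pattern: the count is computed once in init(), stored in a private member, exposed through a const accessor, and reset in clear(). A stand-alone toy with the same shape (toy names, not the actual llama_batch_allocr interface):

#include <cstdint>
#include <vector>

class toy_allocr {
public:
    // validate the batch and cache the number of requested outputs
    bool init(const std::vector<int8_t> & output_flags) {
        clear();
        for (int8_t f : output_flags) {
            n_outputs += f != 0;
        }
        return true;
    }

    // callers read the cached count instead of re-scanning the flags
    uint32_t get_n_outputs() const { return n_outputs; }

private:
    void clear() { n_outputs = 0; }

    uint32_t n_outputs = 0;
};

Encapsulating the count this way means every consumer of the allocator sees the same value without re-walking batch.logits, which is what the llama-context.cpp change below relies on.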

src/llama-context.cpp

Lines changed: 3 additions & 6 deletions

@@ -730,6 +730,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     // temporary allocate memory for the input batch if needed
     // note: during encode, we always pass the full sequence starting from pos = 0
     if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
@@ -904,6 +905,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // temporary allocate memory for the input batch if needed
     if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
@@ -922,12 +924,7 @@
     // this indicates we are doing pooled embedding
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
-    uint32_t n_outputs_all = 0;
-
-    // count outputs
-    for (uint32_t i = 0; i < n_tokens_all; ++i) {
-        n_outputs_all += batch.logits[i] != 0;
-    }
+    const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
 
     if (embd_pooled) {
         // require that all tokens are output