
Commit 99be6b7

cont : move validation inside class
ggml-ci
1 parent f164ba9 commit 99be6b7

File tree

src/llama-batch.cpp
src/llama-batch.h
src/llama-context.cpp
src/llama-context.h

4 files changed, +38 -46 lines changed

src/llama-batch.cpp

Lines changed: 22 additions & 4 deletions
@@ -1,5 +1,9 @@
 #include "llama-batch.h"
 
+#include "llama-impl.h"
+#include "llama-cparams.h"
+#include "llama-vocab.h"
+
 #include <cassert>
 #include <cstring>
 #include <algorithm>
@@ -281,12 +285,26 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
 
 llama_batch_allocr::llama_batch_allocr() = default;
 
-bool llama_batch_allocr::init(struct llama_batch in_batch, llama_pos p0) {
-    GGML_ASSERT(in_batch.n_tokens > 0);
-
+bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
     clear();
 
-    batch = in_batch;
+    batch = batch_inp;
+
+    GGML_ASSERT(batch.n_tokens > 0);
+
+    if (batch.token) {
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return false;
+            }
+
+            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
+                return false;
+            }
+        }
+    }
 
     if (!batch.pos) {
         assert(p0 >= 0);
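
For reference, the validation that now lives in llama_batch_allocr::init is plain bounds checking: each token id must index into the vocabulary, and the primary seq_id must fall inside [0, LLAMA_MAX_PARALLEL_SEQUENCES). Below is a minimal standalone sketch of those checks; validate_batch, the vector-based inputs, n_vocab and max_seq are placeholder names for illustration, not the llama.cpp API.

#include <cstdint>
#include <cstdio>
#include <vector>

// Placeholder for the checks init() performs on the llama_batch fields.
static bool validate_batch(const std::vector<int32_t> & tokens,
                           const std::vector<int32_t> & seq_ids,
                           uint32_t n_vocab, int32_t max_seq) {
    for (size_t i = 0; i < tokens.size(); ++i) {
        // token ids must satisfy 0 <= token < n_vocab
        if (tokens[i] < 0 || (uint32_t) tokens[i] >= n_vocab) {
            std::fprintf(stderr, "invalid token[%zu] = %d\n", i, tokens[i]);
            return false;
        }
        // the primary sequence id must satisfy 0 <= seq_id < max_seq
        if (!seq_ids.empty() && (seq_ids[i] < 0 || seq_ids[i] >= max_seq)) {
            std::fprintf(stderr, "invalid seq_id[%zu] = %d\n", i, seq_ids[i]);
            return false;
        }
    }
    return true;
}

int main() {
    // a well-formed toy batch, then one with an out-of-range token
    std::printf("%d\n", validate_batch({1, 2, 3},  {0, 0, 0}, 32000, 64)); // 1
    std::printf("%d\n", validate_batch({1, 99999}, {0, 0},    32000, 64)); // 0
}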

src/llama-batch.h

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ class llama_batch_allocr {
     llama_batch_allocr();
 
     // optionally fulfill the batch returned by llama_batch_get_one
-    bool init(llama_batch in_batch, llama_pos p0);
+    bool init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0);
 
     const llama_batch & get_batch() const;
 
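Since the new parameter is taken by const reference, this header only needs llama_vocab as a declared (incomplete) type — presumably a forward declaration that this diff does not show — while the full definition is required only where its members are called, which matches the new #include "llama-vocab.h" in llama-batch.cpp above. A single-file sketch of that pattern, using placeholder names rather than the real llama.cpp types:

#include <cstdio>

struct toy_vocab;                            // forward declaration, as a header would carry

struct toy_allocr {
    bool init(const toy_vocab & vocab);      // declaring a const-reference parameter
                                             // needs only the incomplete type
};

struct toy_vocab {                           // full definition, as llama-vocab.h would provide
    unsigned n_tokens() const { return 32000; }
};

bool toy_allocr::init(const toy_vocab & vocab) {
    return vocab.n_tokens() > 0;             // member access requires the full definition
}

int main() {
    toy_vocab  vocab;
    toy_allocr allocr;
    std::printf("%d\n", allocr.init(vocab)); // prints 1
}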

src/llama-context.cpp

Lines changed: 13 additions & 39 deletions
@@ -721,15 +721,17 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
     return res;
 }
 
-int llama_context::encode(llama_batch & inp_batch) {
-    if (inp_batch.n_tokens == 0) {
+int llama_context::encode(const llama_batch & batch_inp) {
+    if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
     // temporary allocate memory for the input batch if needed
     // note: during encode, we always pass the full sequence starting from pos = 0
-    batch_allocr->init(inp_batch, inp_batch.pos ? -1 : 0);
+    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) {
+        return -1;
+    }
 
     const llama_batch & batch = batch_allocr->get_batch();
 
@@ -739,21 +741,6 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // TODO: move the validation to the llama_batch_allocr
-    if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-
-            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
-                throw -1;
-            }
-        }
-    }
-
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
     GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
 
@@ -897,26 +884,28 @@ int llama_context::encode(llama_batch & inp_batch) {
     return 0;
 }
 
-int llama_context::decode(llama_batch & inp_batch) {
+int llama_context::decode(const llama_batch & batch_inp) {
     if (!memory) {
         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
-        return encode(inp_batch);
+        return encode(batch_inp);
     }
 
-    if (inp_batch.n_tokens == 0) {
+    if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
-    if (!inp_batch.pos) {
-        if (inp_batch.seq_id) {
+    if (!batch_inp.pos) {
+        if (batch_inp.seq_id) {
             LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
             return -1;
         }
     }
 
     // temporary allocate memory for the input batch if needed
-    batch_allocr->init(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
+    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) {
+        return -1;
+    }
 
     const llama_batch & batch = batch_allocr->get_batch();
 
@@ -930,21 +919,6 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // TODO: move the validation to the llama_batch_allocr
-    if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-
-            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
-                return -1;
-            }
-        }
-    }
-
     // this indicates we are doing pooled embedding
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
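Net effect in this file: the duplicated per-token validation loops (including the stray throw -1 in the encode path) are gone, and both encode() and decode() now report a rejected batch through the same return -1 after llama_batch_allocr::init fails. A hedged caller-side sketch at the public C API level follows; it assumes ctx is an already-created llama_context, tokens may contain out-of-range ids, and the specific negative error value is not part of the contract.

#include <cstdio>
#include <vector>
#include "llama.h"   // public llama.cpp API; assumes the program links against libllama

// Sketch of a caller: build a single-sequence batch and let the library
// validate it.  ctx and tokens are assumed to come from elsewhere.
static bool try_decode(llama_context * ctx, std::vector<llama_token> & tokens) {
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

    if (llama_decode(ctx, batch) < 0) {
        // an out-of-range token or seq_id is now rejected up front by the
        // batch allocator instead of asserting or throwing inside decode()
        std::fprintf(stderr, "decode failed: invalid batch\n");
        return false;
    }
    return true;
}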

src/llama-context.h

Lines changed: 2 additions & 2 deletions
@@ -102,8 +102,8 @@ struct llama_context {
             llama_memory_state_i * mstate,
             ggml_status & ret);
 
-    int encode(llama_batch & inp_batch);
-    int decode(llama_batch & inp_batch);
+    int encode(const llama_batch & batch_inp);
+    int decode(const llama_batch & batch_inp);
 
     //
     // state save/load
