@@ -721,15 +721,17 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
     return res;
 }
 
-int llama_context::encode(llama_batch & inp_batch) {
-    if (inp_batch.n_tokens == 0) {
+int llama_context::encode(const llama_batch & batch_inp) {
+    if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
     // temporary allocate memory for the input batch if needed
     // note: during encode, we always pass the full sequence starting from pos = 0
-    batch_allocr->init(inp_batch, inp_batch.pos ? -1 : 0);
+    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) {
+        return -1;
+    }
 
     const llama_batch & batch = batch_allocr->get_batch();
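Both call sites now treat a false return from batch_allocr->init() as a hard error, which is presumably where the per-token validation removed in the hunks below ends up. A minimal sketch of what such a consolidated check inside llama_batch_allocr::init() could look like; the function shape is inferred from the new call sites (batch, vocab, implicit start position), and the body only mirrors the checks deleted from encode()/decode(), it is not the actual implementation from this change:

// hypothetical sketch: validation consolidated inside the batch allocator
bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
    // ... copy the batch and fill in implicit pos/seq_id/logits fields (using p0) as before ...

    if (batch_inp.token) {
        for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
            if (batch_inp.token[i] < 0 || (uint32_t) batch_inp.token[i] >= vocab.n_tokens()) {
                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch_inp.token[i]);
                return false; // encode()/decode() turn this into their usual -1 error code
            }

            if (batch_inp.seq_id && (batch_inp.seq_id[i][0] < 0 || batch_inp.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d >= %d\n", __func__, i, batch_inp.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
                return false;
            }
        }
    }

    return true;
}

A side effect of the move is that both paths now fail with a plain return value instead of the old `throw -1;` on the encode side.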
@@ -739,21 +741,6 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // TODO: move the validation to the llama_batch_allocr
-    if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-
-            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
-                throw -1;
-            }
-        }
-    }
-
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
     GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
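The assert kept at the end of this hunk is why there is no micro-batching on the encode path: the context must be configured with n_ubatch at least as large as the number of tokens being encoded. A hedged configuration sketch, assuming the standard llama_context_params fields and a placeholder n_prompt_tokens:

#include "llama.h"

// sketch: sizing a context so an entire encoder input fits in a single ubatch
static llama_context * make_encoder_context(llama_model * model, uint32_t n_prompt_tokens) {
    llama_context_params cparams = llama_context_default_params();
    cparams.n_batch  = n_prompt_tokens; // logical batch size
    cparams.n_ubatch = n_prompt_tokens; // physical batch size; encode() asserts n_ubatch >= n_tokens
    return llama_init_from_model(model, cparams);
}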
@@ -897,26 +884,28 @@ int llama_context::encode(llama_batch & inp_batch) {
     return 0;
 }
 
-int llama_context::decode(llama_batch & inp_batch) {
+int llama_context::decode(const llama_batch & batch_inp) {
     if (!memory) {
         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
-        return encode(inp_batch);
+        return encode(batch_inp);
     }
 
-    if (inp_batch.n_tokens == 0) {
+    if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
-    if (!inp_batch.pos) {
-        if (inp_batch.seq_id) {
+    if (!batch_inp.pos) {
+        if (batch_inp.seq_id) {
             LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
             return -1;
         }
     }
 
     // temporary allocate memory for the input batch if needed
-    batch_allocr->init(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
+    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) {
+        return -1;
+    }
 
     const llama_batch & batch = batch_allocr->get_batch();
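The last argument to init() is the position assigned to the first token when the batch carries no explicit pos array: 0 for encode() (the full sequence always starts from the beginning) and memory->seq_pos_max(0) + 1 for decode() (continue right after the last position already stored for sequence 0). A small sketch of that defaulting rule, assuming a simple consecutive fill; the helper below is illustrative only, not the allocator's actual code:

#include <vector>
#include "llama.h" // llama_pos

// hypothetical helper: derive implicit positions from the start position p0
// p0 == -1 means the caller supplied an explicit pos array, so nothing is filled in
static void fill_implicit_pos(std::vector<llama_pos> & pos, int32_t n_tokens, llama_pos p0) {
    if (p0 < 0) {
        return; // explicit positions were provided by the caller
    }
    pos.resize(n_tokens);
    for (int32_t i = 0; i < n_tokens; ++i) {
        pos[i] = p0 + i; // consecutive positions starting at p0
    }
}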
@@ -930,21 +919,6 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // TODO: move the validation to the llama_batch_allocr
-    if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-
-            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
-                return -1;
-            }
-        }
-    }
-
     // this indicates we are doing pooled embedding
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
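From the public API side the behavior is unchanged: llama_encode()/llama_decode() still report failure with a non-zero return value, now also when the allocator rejects the batch. A minimal usage sketch against the C API; the model path and token ids are placeholders, and the entry point names follow recent llama.h (older versions use different names such as llama_load_model_from_file):

#include "llama.h"
#include <cstdio>
#include <vector>

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf" /* placeholder path */, mparams);
    if (!model) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    llama_context * ctx = llama_init_from_model(model, cparams);
    if (!ctx) {
        llama_model_free(model);
        return 1;
    }

    std::vector<llama_token> tokens = { 1, 2, 3 }; // placeholder token ids
    // pos/seq_id are left NULL - the batch allocator fills in the defaults
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

    if (llama_decode(ctx, batch) != 0) {
        fprintf(stderr, "decode failed (invalid batch or internal error)\n");
    }

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}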