@@ -20,7 +20,7 @@ llama_context::llama_context(
         const llama_model & model,
               llama_context_params params) :
     model(model),
-    balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
+    batch_allocr(std::make_unique<llama_batch_allocr>()) {
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);

     t_start_us = model.t_start_us;
@@ -722,26 +722,23 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
 }

 int llama_context::encode(const llama_batch & batch_inp) {
-    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
-
     if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }

-    const auto & hparams = model.hparams;
-
-    const int64_t n_embd = hparams.n_embd;
-
+    // temporary allocate memory for the input batch if needed
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }

-    const uint32_t n_tokens = balloc->get_n_tokens();
+    const llama_batch & batch = batch_allocr->get_batch();

-    const llama_ubatch ubatch = balloc->split_simple(n_tokens);
+    const uint32_t n_tokens = batch.n_tokens;
+
+    const int64_t n_embd = hparams.n_embd;

     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
     GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
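The hunk above also changes how `encode()` seeds token positions: the `batch_allocr->init(..., batch_inp.pos ? -1 : 0)` call auto-assigns positions from 0 whenever the caller leaves `batch.pos` unset. A minimal caller-side sketch of that case, assuming the public `llama.h` API (`llama_batch_get_one`, `llama_encode`); the helper name is hypothetical:

```cpp
#include "llama.h"

#include <vector>

// Encode a prompt without explicit positions: llama_batch_get_one() leaves
// batch.pos == nullptr, so the batch allocator assigns positions 0..n-1,
// matching the "full sequence starting from pos = 0" note in the hunk above.
static int encode_prompt(llama_context * ctx, std::vector<llama_token> & tokens) {
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

    return llama_encode(ctx, batch); // 0 on success
}
```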
@@ -755,6 +752,8 @@ int llama_context::encode(const llama_batch & batch_inp) {

     n_queued_tokens += n_tokens;

+    const auto & hparams = model.hparams;
+
     const int64_t n_embd = hparams.n_embd;

     llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
@@ -824,6 +823,12 @@ int llama_context::encode(const llama_batch & batch_inp) {
                         const llama_seq_id seq_id = ubatch.seq_id_unq[s];
                         const int32_t seq_idx = ubatch.seq_idx[seq_id];

+                    // TODO: fix indexing [UBATCH_IDX]
+                    for (uint32_t i = 0; i < n_tokens; i++) {
+                        const llama_seq_id seq_id = ubatch.seq_id[i][0];
+                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                            continue;
+                        }
                         embd_seq_out[seq_id].resize(n_embd);
                         ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
                     }
@@ -835,10 +840,12 @@ int llama_context::encode(const llama_batch & batch_inp) {

                     const uint32_t n_cls_out = hparams.n_cls_out;

-                    for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
-                        const llama_seq_id seq_id = ubatch.seq_id_unq[s];
-                        const int32_t seq_idx = ubatch.seq_idx[seq_id];
-
+                    // TODO: fix indexing [UBATCH_IDX]
+                    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                        const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                            continue;
+                        }
                         embd_seq_out[seq_id].resize(n_cls_out);
                         ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
                     }
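The two hunks above fill `embd_seq_out` with one pooled vector (or `n_cls_out` rank scores) per sequence id. On the caller side those buffers are read back through `llama_get_embeddings_seq()`; a hedged sketch, assuming embeddings and a non-`NONE` pooling type were enabled on the context, with a hypothetical helper name:

```cpp
#include "llama.h"

#include <cstdio>

// Print the pooled embedding of one sequence after llama_encode()/llama_decode().
// For LLAMA_POOLING_TYPE_RANK the buffer holds n_cls_out scores instead of n_embd values.
static void print_seq_embedding(llama_context * ctx, const llama_model * model, llama_seq_id seq_id) {
    const float * embd = llama_get_embeddings_seq(ctx, seq_id);
    if (!embd) {
        fprintf(stderr, "no pooled embedding for seq %d\n", seq_id);
        return;
    }

    const int32_t n_embd = llama_n_embd(model);
    for (int32_t i = 0; i < n_embd; ++i) {
        printf("%.6f ", embd[i]);
    }
    printf("\n");
}
```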
@@ -868,15 +875,11 @@ int llama_context::encode(const llama_batch & batch_inp) {
         const auto & batch = balloc->get_batch();

         // remember the sequence ids used during the encoding - needed for cross attention later
-        // TODO: the seuqence indexing here is likely not correct in the general case
-        //       probably works only for split_simple
         cross.seq_ids_enc.resize(n_tokens);
         for (uint32_t i = 0; i < n_tokens; i++) {
             cross.seq_ids_enc[i].clear();
-
             for (int s = 0; s < batch.n_seq_id[i]; s++) {
-                const llama_seq_id seq_id = batch.seq_id[i][s];
-
+                llama_seq_id seq_id = batch.seq_id[i][s];
                 cross.seq_ids_enc[i].insert(seq_id);
             }
         }
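The hunk above records, per encoder token, which sequence ids were used (`cross.seq_ids_enc`), so later decoder batches can cross-attend to the matching encoder output. A hedged encoder-decoder sketch using the public API; error handling and sampling are omitted and the helper name is hypothetical:

```cpp
#include "llama.h"

#include <vector>

// Encode the prompt, then start decoding from the model's decoder start token.
// The cross-attention state captured during llama_encode() (including the
// remembered sequence ids) is reused by the subsequent llama_decode() call.
static int run_enc_dec(llama_context * ctx, const llama_model * model, std::vector<llama_token> & prompt) {
    if (llama_encode(ctx, llama_batch_get_one(prompt.data(), (int32_t) prompt.size())) != 0) {
        return -1;
    }

    llama_token dec_start = llama_model_decoder_start_token(model);

    return llama_decode(ctx, llama_batch_get_one(&dec_start, 1));
}
```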
@@ -886,57 +889,44 @@ int llama_context::encode(const llama_batch & batch_inp) {
 }

 int llama_context::decode(const llama_batch & batch_inp) {
-    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
-
     if (!memory) {
         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
         return encode(batch_inp);
+        return encode(batch_inp);
     }

     if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }

+    // temporary allocate memory for the input batch if needed
+    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+        return -1;
+    }
+
+    const llama_batch & batch = batch_allocr->get_batch();
+
     const auto & vocab = model.vocab;
     const auto & hparams = model.hparams;

     const int32_t n_vocab = vocab.n_tokens();
+    const int64_t n_embd = hparams.n_embd;

-    const int64_t n_tokens_all = batch.n_tokens;
-    const int64_t n_embd = hparams.n_embd;
+    const uint32_t n_tokens_all = batch.n_tokens;

     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

-    // TODO: move the validation to the llama_batch_allocr
-    if (batch.token) {
-        for (int64_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-
-            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
-                return -1;
-            }
-        }
-    }
-
     // this indicates we are doing pooled embedding
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

-    int64_t n_outputs_all = 0;
-
-    // count outputs
-    for (uint32_t i = 0; i < n_tokens_all; ++i) {
-        n_outputs_all += batch.logits[i] != 0;
-    }
+    const uint32_t n_outputs_all = batch_allocr->get_n_outputs();

     if (embd_pooled) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
-            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
+            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
                     __func__, n_outputs_all, n_tokens_all);
             return -1;
         }
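In the `decode()` hunk above, the new `init` call falls back to `memory->seq_pos_max(0) + 1` when `batch.pos` is null, i.e. new tokens are appended right after what is already stored for the sequence. A caller-side sketch of that behaviour, again assuming the public API and a hypothetical helper name:

```cpp
#include "llama.h"

// Append a single token without specifying positions or sequence ids:
// its position is derived from the current end of sequence 0.
static int append_token(llama_context * ctx, llama_token tok) {
    llama_batch batch = llama_batch_get_one(&tok, 1);

    return llama_decode(ctx, batch); // 0 on success
}
```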
@@ -1044,6 +1034,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     pos_min[s] = std::numeric_limits<llama_pos>::max();
                 }

+                // TODO: fix sequence indexing
                 for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                     const auto & seq_id = ubatch.seq_id[i][0];

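The `// TODO: fix sequence indexing` added above flags that only `seq_id[i][0]` is consulted when computing the minimum failing position per sequence. One possible shape of a fix, sketched under the assumption that `llama_ubatch` exposes the `n_tokens`, `n_seq_id`, `seq_id`, and `pos` fields used elsewhere in this diff; this is not the project's actual resolution:

```cpp
#include <algorithm>
#include <limits>
#include <vector>

// Track the minimum position per sequence, visiting every sequence id attached
// to each token instead of only seq_id[i][0].
static std::vector<llama_pos> min_pos_per_seq(const llama_ubatch & ubatch, int32_t n_seq_max) {
    std::vector<llama_pos> pos_min(n_seq_max, std::numeric_limits<llama_pos>::max());

    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
        for (int32_t s = 0; s < ubatch.n_seq_id[i]; ++s) {
            const llama_seq_id seq_id = ubatch.seq_id[i][s];

            pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
        }
    }

    return pos_min;
}
```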