@@ -343,29 +343,23 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
     GGML_UNUSED(embd_all);

     do {
-        balloc.split_reset();
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);

         std::vector<llama_ubatch> ubatches;
-        while (true) {
-            auto ubatch = balloc.split_simple(n_ubatch);
-
-            if (ubatch.n_tokens == 0) {
-                break;
-            }
-
-            ubatches.push_back(std::move(ubatch)); // NOLINT
+        while (sbatch.n_tokens > 0) {
+            ubatches.push_back(sbatch.split_simple(n_ubatch));
         }

         auto heads = prepare(ubatches);
         if (heads.empty()) {
             break;
         }

-        return std::make_unique<llama_kv_cache_unified_context>(
-                this, std::move(heads), std::move(ubatches));
+        return std::make_unique<llama_kv_cache_unified_state>(
+                this, std::move(sbatch), std::move(heads), std::move(ubatches));
     } while (false);

-    return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }

 llama_memory_context_ptr llama_kv_cache_unified::init_full() {
@@ -559,7 +553,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     }

     if (debug > 0) {
-        LLAMA_LOG_CONT("\n");
         LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);

         if ((debug == 2 && n_swa > 0) || debug > 2) {
@@ -685,29 +678,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
685
678
}
686
679
687
680
void llama_kv_cache_unified::apply_ubatch (uint32_t head_cur, const llama_ubatch & ubatch) {
681
+ if (debug > 0 ) {
682
+ LLAMA_LOG_DEBUG (" %s: ubatch info:\n " , __func__);
683
+ LLAMA_LOG_DEBUG (" %s: n_tokens = %d, equal_seqs = %d\n " , __func__, ubatch.n_tokens , ubatch.equal_seqs );
684
+ LLAMA_LOG_DEBUG (" %s: n_seq_tokens = %d, n_seqs = %d\n " , __func__, ubatch.n_seq_tokens , ubatch.n_seqs );
685
+ }
686
+
688
687
// keep track of the max sequence position that we would overwrite with this ubatch
689
688
// for non-SWA cache, this would be always empty
690
689
llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
691
690
for (int s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
692
691
seq_pos_max_rm[s] = -1 ;
693
692
}
694
693
695
- for (uint32_t i = 0 ; i < ubatch.n_tokens ; ++i ) {
696
- if (!cells. is_empty (head_cur + i) ) {
697
- assert (cells. seq_count (head_cur + i) == 1 ) ;
694
+ for (uint32_t s = 0 ; s < ubatch.n_seqs ; ++s ) {
695
+ for ( uint32_t j = 0 ; j < ubatch. n_seq_tokens ; ++j ) {
696
+ const uint32_t idx = s*ubatch. n_seq_tokens + j ;
698
697
699
- const llama_seq_id seq_id = cells.seq_get (head_cur + i);
700
- const llama_pos pos = cells.pos_get (head_cur + i );
698
+ if (! cells.is_empty (head_cur + idx)) {
699
+ assert ( cells.seq_count (head_cur + idx) == 1 );
701
700
702
- seq_pos_max_rm[seq_id] = std::max (seq_pos_max_rm[seq_id], pos);
701
+ const llama_seq_id seq_id = cells.seq_get (head_cur + idx);
702
+ const llama_pos pos = cells.pos_get (head_cur + idx);
703
703
704
- cells.rm (head_cur + i);
705
- }
704
+ seq_pos_max_rm[seq_id] = std::max (seq_pos_max_rm[seq_id], pos);
706
705
707
- cells.pos_set (head_cur + i, ubatch.pos [i]);
706
+ cells.rm (head_cur + idx);
707
+ }
708
+
709
+ cells.pos_set (head_cur + idx, ubatch.pos [idx]);
708
710
709
- for (int32_t j = 0 ; j < ubatch.n_seq_id [i]; j++) {
710
- cells.seq_add (head_cur + i, ubatch.seq_id [i][j]);
711
+ for (int32_t i = 0 ; i < ubatch.n_seq_id [s]; i++) {
712
+ cells.seq_add (head_cur + idx, ubatch.seq_id [s][i]);
713
+ }
711
714
}
712
715
}
713
716
@@ -726,7 +729,6 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
             seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
         }
     }
-
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
@@ -823,11 +825,13 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 }

 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    const uint32_t n_tokens = ubatch->n_tokens;
+    const uint32_t n_tokens     = ubatch->n_tokens;
+    const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
+    const uint32_t n_seqs       = ubatch->n_seqs;

     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;

     const int64_t n_kv = dst->ne[0];

     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
@@ -843,10 +848,13 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //   xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
     for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t i = 0; i < n_tokens; ++i) {
-            const llama_seq_id seq_id = ubatch->seq_id[i][0];
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+            for (uint32_t j = 0; j < n_seq_tokens; ++j) {
+                const uint32_t idx = s*n_seq_tokens + j;

-            const llama_pos p1 = ubatch->pos[i];
+                const llama_pos p1 = ubatch->pos[idx];

-            for (uint32_t j = 0; j < n_kv; ++j) {
+                for (uint32_t i = 0; i < n_kv; ++i) {
                     float f = 0.0f;
@@ -876,15 +884,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
                         f = -INFINITY;
                     }

-                data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
+                    data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
+                }
             }
         }

         // mask padded tokens
         if (data) {
-            for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+            for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
+                for (uint32_t i = 0; i < n_kv; ++i) {
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
                 }
             }
         }
@@ -1534,9 +1543,12 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell

         seq_rm(dest_seq_id, -1, -1);

-        llama_batch_allocr balloc(hparams.n_pos_per_embd());
+        llama_sbatch sbatch;
+        llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);

-        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
+        ubatch.n_tokens     = cell_count;
+        ubatch.n_seq_tokens = cell_count;
+        ubatch.n_seqs       = 1;

         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
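The indexing scheme the patch restores is easy to misread in diff form: a ubatch is laid out as n_seqs groups of n_seq_tokens tokens, flattened with idx = s*n_seq_tokens + j, and the KQ mask tensor stores one row of n_kv floats per flat token index. Below is a minimal standalone sketch of that mapping (not part of the patch and not llama.cpp code; all names and dimensions are invented for the demo):

// sketch: flat token indexing and row-major mask layout, mirroring the
// restored loops in apply_ubatch()/set_input_kq_mask()
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t n_seqs       = 2; // sequences packed into the ubatch
    const uint32_t n_seq_tokens = 3; // tokens per sequence
    const uint32_t n_tokens     = n_seqs*n_seq_tokens;
    const uint32_t n_kv         = 4; // KV cells covered by the mask

    // one mask row of n_kv floats per ubatch token, as in the dst tensor
    std::vector<float> data(n_kv*n_tokens, 0.0f);

    for (uint32_t s = 0; s < n_seqs; ++s) {
        for (uint32_t j = 0; j < n_seq_tokens; ++j) {
            // flatten (s, j) into a single token index
            const uint32_t idx = s*n_seq_tokens + j;

            for (uint32_t i = 0; i < n_kv; ++i) {
                // row-major: token idx owns data[idx*n_kv .. idx*n_kv + n_kv)
                data[idx*n_kv + i] = (float) idx;
            }
        }
    }

    // each printed row is filled with its own flat token index,
    // confirming that no two (s, j) pairs collide
    for (uint32_t t = 0; t < n_tokens; ++t) {
        for (uint32_t i = 0; i < n_kv; ++i) {
            printf("%.0f ", data[t*n_kv + i]);
        }
        printf("\n");
    }

    return 0;
}

This also shows why the padded-token loop in set_input_kq_mask can simply run the flat index from n_tokens up to GGML_PAD(n_tokens, GGML_KQ_MASK_PAD): padding rows sit past the last real row in the same row-major buffer and are filled with -INFINITY wholesale.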