@@ -443,6 +443,14 @@ void llama_kv_cache_unified::set_full() {
     n = size;
 }
 
+bool llama_kv_cache_unified::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
+    GGML_UNUSED(seq_id);
+    GGML_UNUSED(p0);
+    GGML_UNUSED(p1);
+    // Unified attention cache can always do a sequence removal
+    return true;
+}
+
 llama_sbatch llama_kv_cache_unified::sbatch_init(
         const llama_batch & batch,
         bool logits_all) {
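
The unified cache accepts any removal, so its can_seq_rm is trivially true; the
hook exists so that callers can validate a removal against any cache type
before mutating it. A minimal sketch of the intended check-then-mutate call
pattern (the try_remove wrapper is hypothetical, not part of the patch):

    // Validate first, then mutate: on some cache types seq_rm is destructive,
    // so a caller that must not lose state checks the precondition up front.
    static bool try_remove(llama_kv_cache & kv, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
        if (!kv.can_seq_rm(seq_id, p0, p1)) {
            return false; // nothing has been modified
        }
        const bool removed = kv.seq_rm(seq_id, p0, p1);
        GGML_ASSERT(removed); // should not fail after a successful check
        return true;
    }
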
@@ -1479,39 +1487,33 @@ void llama_kv_cache_recurrent::clear() {
 }
 
 bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    uint32_t new_head = size;
+    if (!can_seq_rm(seq_id, p0, p1)) {
+        // could be fatal
+        return false;
+    }
 
+    uint32_t new_head = size;
     if (p0 < 0) {
         p0 = 0;
     }
-
     if (p1 < 0) {
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    // models like Mamba or RWKV can't have a state partially erased
-    if (seq_id >= (int64_t) size) {
-        // could be fatal
-        return false;
-    }
     if (0 <= seq_id) {
         int32_t & tail_id = cells[seq_id].tail;
         if (tail_id >= 0) {
             const kv_cell & cell = cells[tail_id];
-            // partial intersection is invalid
-            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
-                return false;
-            }
+            // already validated in can_seq_rm
+            GGML_ASSERT(!((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)));
             // invalidate tails which will be cleared
             if (p0 <= cell.pos && cell.pos < p1) {
                 tail_id = -1;
             }
         }
     } else {
-        // seq_id is negative, then the range should include everything or nothing
-        if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
-            return false;
-        }
+        // already validated in can_seq_rm
+        GGML_ASSERT(!(p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())));
     }
 
     for (uint32_t i = 0; i < size; ++i) {
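
For recurrent models the per-sequence state lives in a single tail cell, so
seq_rm can only honor ranges that erase all of a sequence's history or none of
it; anything in between would require a partial state rewind that Mamba- or
RWKV-style models cannot perform. That is the "partial intersection" the
assert above guards against. A self-contained sketch of the predicate,
assuming p0 and p1 have already been clamped as in the function above
(partial_intersection and cell_pos are illustrative names):

    #include <cstdint>

    using llama_pos = int32_t;

    // True when [p0, p1) covers only part of the history up to the tail cell.
    // With cell_pos = 9: (0, 10) is a full erase and (10, 20) touches nothing,
    // so both are fine; (4, 8) would erase the state only partially.
    static bool partial_intersection(llama_pos p0, llama_pos p1, llama_pos cell_pos) {
        return (0 < p0 && p0 <= cell_pos) || (0 < p1 && p1 <= cell_pos);
    }
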
@@ -1712,6 +1714,35 @@ void llama_kv_cache_recurrent::set_full() {
     n = size;
 }
 
+bool llama_kv_cache_recurrent::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+    // models like Mamba or RWKV can't have a state partially erased
+    if (seq_id >= (int64_t) size) {
+        // could be fatal
+        return false;
+    }
+    if (0 <= seq_id) {
+        const int32_t & tail_id = cells[seq_id].tail;
+        if (tail_id >= 0) {
+            const kv_cell & cell = cells[tail_id];
+            // partial intersection is invalid
+            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+                return false;
+            }
+        }
+    // seq_id is negative, then the range should include everything or nothing
+    } else if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+        return false;
+    }
+    return true;
+}
+
 llama_sbatch llama_kv_cache_recurrent::sbatch_init(
         const llama_batch & batch,
         bool logits_all) {
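
The recurrent can_seq_rm duplicates the checks that used to live inline in
seq_rm, so the legality of a removal can now be queried without touching the
cache. The final hunk builds on this hook: it adds llama_kv_cache_hybrid,
which routes each layer to one of several child caches. From the constructor's
usage, child_cache pairs an owned cache (child) with the layer IDs it serves
(layer_ids). A sketch of how a model with two attention layers and one
recurrent layer might be wired up (assumptions: child_cache is nested in the
hybrid class as the constructor signature suggests, the brace-init member
order matches, and the make_* helpers stand in for the real cache
constructors; the contiguity assert compares the highest layer ID to the layer
count, so the IDs here are 1-based):

    std::vector<llama_kv_cache_hybrid::child_cache> children;
    children.push_back({ make_unified_cache(),   { 1, 2 } }); // attention layers
    children.push_back({ make_recurrent_cache(), { 3 } });    // recurrent layer

    // The constructor sorts the children by lowest layer ID, takes ownership
    // of them, and asserts the layer IDs are non-overlapping and contiguous.
    llama_kv_cache_hybrid kv(hparams, std::move(children));
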
@@ -2355,6 +2386,245 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     return true;
 }
 
+//
+// llama_kv_cache_hybrid
+//
+llama_kv_cache_hybrid::llama_kv_cache_hybrid(
+    const llama_hparams & hparams,
+    std::vector<child_cache> children) :
+    m_hparams(hparams),
+    m_layer_cache_map(
+        [](const std::vector<child_cache> & caches) -> std::unordered_map<size_t, llama_kv_cache *> {
+            std::unordered_map<size_t, llama_kv_cache *> map;
+            for (const auto & cache : caches) {
+                for (size_t layer_id : cache.layer_ids) {
+                    map[layer_id] = cache.child.get();
+                }
+            }
+
+            return map;
+        }(children)
+    ),
+    m_children(
+        [](std::vector<child_cache> & caches) -> std::set<std::unique_ptr<llama_kv_cache>> {
+            // Sort the caches by the lowest layer ID so the order is repeatable
+            for (auto & cache : caches) {
+                GGML_ASSERT(cache.layer_ids.size() > 0);
+                std::sort(cache.layer_ids.begin(), cache.layer_ids.end());
+            }
+            std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
+                return a.layer_ids[0] < b.layer_ids[0];
+            });
+            std::set<std::unique_ptr<llama_kv_cache>> unique_caches;
+            for (auto & cache : caches) {
+                unique_caches.emplace(cache.child.release());
+            }
+            return unique_caches;
+        }(children)
+    ),
+    m_has_recurrent(
+        [](const std::set<std::unique_ptr<llama_kv_cache>> & caches) -> bool {
+            for (const auto & cache : caches) {
+                if (dynamic_cast<llama_kv_cache_recurrent *>(cache.get())) {
+                    return true;
+                }
+            }
+            return false;
+        }(m_children)
+    )
+{
+    // Ensure at least one child
+    GGML_ASSERT(m_children.size() > 0);
+
+    // Ensure layers are not overlapping and are contiguous
+    std::set<size_t> seen_layers;
+    size_t max_layer = 0;
+    for (const auto & cache : children) {
+        for (const auto & layer_id : cache.layer_ids) {
+            GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
+            seen_layers.insert(layer_id);
+            if (layer_id > max_layer) {
+                max_layer = layer_id;
+            }
+        }
+    }
+    GGML_ASSERT(max_layer == seen_layers.size());
+}
+
+void llama_kv_cache_hybrid::clear() {
+    for (const auto & cache : m_children) {
+        cache->clear();
+    }
+}
+
+bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // First check if we can do this removal. This checks all children so that
+    // no mutation happens before we know if it's possible
+    if (!can_seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+
+    // Do the removal from each child which should never fail
+    for (const auto & cache : m_children) {
+        const bool removed = cache->seq_rm(seq_id, p0, p1);
+        GGML_ASSERT(removed);
+    }
+    return true;
+}
+
+void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    for (const auto & cache : m_children) {
+        cache->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) {
+    for (const auto & cache : m_children) {
+        cache->seq_keep(seq_id);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    for (const auto & cache : m_children) {
+        cache->seq_add(seq_id, p0, p1, delta);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    for (const auto & cache : m_children) {
+        cache->seq_div(seq_id, p0, p1, d);
+    }
+}
+
+llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const {
+    llama_pos max_pos = 0;
+    for (const auto & cache : m_children) {
+        max_pos = std::max(max_pos, cache->seq_pos_max(seq_id));
+    }
+    return max_pos;
+}
+
+void llama_kv_cache_hybrid::restore() {
+    for (const auto & cache : m_children) {
+        cache->restore();
+    }
+}
+
+void llama_kv_cache_hybrid::commit() {
+    for (const auto & cache : m_children) {
+        cache->commit();
+    }
+}
+
+bool llama_kv_cache_hybrid::update(llama_context & ctx) {
+    bool updated = false;
+    for (const auto & cache : m_children) {
+        updated = cache->update(ctx) || updated;
+    }
+    return updated;
+}
+
+void llama_kv_cache_hybrid::defrag_sched(float thold) {
+    for (const auto & cache : m_children) {
+        cache->defrag_sched(thold);
+    }
+}
+
+void llama_kv_cache_hybrid::set_full() {
+    for (const auto & cache : m_children) {
+        cache->set_full();
+    }
+}
+
+bool llama_kv_cache_hybrid::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
+    for (const auto & cache : m_children) {
+        if (!cache->can_seq_rm(seq_id, p0, p1)) {
+            return false;
+        }
+    }
+    return true;
+}
+
+llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
+    // If any of the caches are recurrent, require equal split
+    return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all);
+}
+
+llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
+    if (embd_pooled) {
+        // Pooled embeddings cannot be split across ubatches (yet)
+        return sbatch.split_seq(n_ubatch);
+    }
+    if (m_has_recurrent) {
+        return sbatch.split_equal(n_ubatch);
+    }
+    return sbatch.split_simple(n_ubatch);
+}
+
+bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) {
+    bool found = true;
+    for (const auto & cache : m_children) {
+        found = cache->find_slot(batch) && found;
+    }
+    return found;
+}
+
+int32_t llama_kv_cache_hybrid::get_n_tokens() const {
+    // The number of tokens should be the same across all child caches
+    int32_t n_tokens = -1;
+    for (const auto & cache : m_children) {
+        const auto cache_n_tokens = cache->get_n_tokens();
+        GGML_ASSERT(n_tokens == -1 || cache_n_tokens == n_tokens);
+        n_tokens = cache_n_tokens;
+    }
+    return n_tokens;
+}
+
+int32_t llama_kv_cache_hybrid::get_used_cells() const {
+    // TODO: Is this correct?
+    // Return the largest number of used cells
+    int32_t used_cells = -1;
+    for (const auto & cache : m_children) {
+        used_cells = std::max(used_cells, cache->get_used_cells());
+    }
+    return used_cells;
+}
+
+llama_pos llama_kv_cache_hybrid::get_pos_max() const {
+    llama_pos pos_max = -1;
+    for (const auto & cache : m_children) {
+        pos_max = std::max(pos_max, cache->get_pos_max());
+    }
+    return pos_max;
+}
+
+bool llama_kv_cache_hybrid::get_can_shift() const {
+    // TODO: Is this correct?
+    // If any children can shift, return true
+    for (const auto & cache : m_children) {
+        if (cache->get_can_shift()) {
+            return true;
+        }
+    }
+    return false;
+}
+
+void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    // Write each cache state in order. Note that order is guaranteed at
+    // initialization by using an ordered set sorted by lowest layer ID
+    for (const auto & cache : m_children) {
+        cache->state_write(io, seq_id);
+    }
+}
+
+void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    // Read each cache state in order. Note that order is guaranteed at
+    // initialization by using an ordered set sorted by lowest layer ID
+    for (const auto & cache : m_children) {
+        cache->state_read(io, seq_id);
+    }
+}
+
 //
 // kv cache view
 //
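
The hybrid seq_rm above is where the can_seq_rm hook pays off: every child is
validated before any child is mutated, so a rejected removal cannot leave the
children out of sync. The same two-phase shape in isolation (the child struct
and apply_all are illustrative, not part of the patch):

    #include <vector>

    struct child {
        bool acceptable = true;
        bool can_apply() const { return acceptable; }
        void apply() { /* mutate this child's state */ }
    };

    // Validate across all children first; only when every one accepts the
    // operation is the mutation fanned out, so failure has no side effects.
    static bool apply_all(std::vector<child> & children) {
        for (const auto & c : children) {
            if (!c.can_apply()) {
                return false; // no child has been touched yet
            }
        }
        for (auto & c : children) {
            c.apply(); // cannot fail after validation
        }
        return true;
    }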