@@ -2354,6 +2354,231 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
    return true;
}

+//
+// llama_kv_cache_hybrid
+//
+llama_kv_cache_hybrid::llama_kv_cache_hybrid(
+    const llama_hparams & hparams,
+    const std::vector<child_cache> & children) :
+    m_hparams(hparams),
+    m_layer_cache_map(
+        [](const std::vector<child_cache> & caches) -> std::unordered_map<size_t, llama_kv_cache*> {
+            std::unordered_map<size_t, llama_kv_cache*> map;
+            for (const auto & cache : caches) {
+                for (size_t layer_id : cache.layer_ids) {
+                    map[layer_id] = cache.child;
+                }
+            }
+
+            return map;
+        }(children)
+    ),
+    m_children(
+        [](std::vector<child_cache> caches) -> std::set<llama_kv_cache*> {
+            // Sort the caches by the lowest layer ID so the order is repeatable
+            for (auto & cache : caches) {
+                GGML_ASSERT(cache.layer_ids.size() > 0);
+                std::sort(cache.layer_ids.begin(), cache.layer_ids.end());
+            }
+            std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
+                return a.layer_ids[0] < b.layer_ids[0];
+            });
+            std::set<llama_kv_cache*> unique_caches;
+            for (const auto & cache : caches) {
+                unique_caches.insert(cache.child);
+            }
+            return unique_caches;
+        }(children)
+    ),
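+    // NOTE: a std::set of raw pointers iterates in pointer-address order,
+    // so the sorting above does not control how m_children is traversed.
+    // If a repeatable lowest-layer-first traversal is required (state_write
+    // and state_read below assume one), an order-preserving container such
+    // as a sorted std::vector may be what is actually needed here.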
+    m_has_recurrent(
+        [](const std::vector<child_cache> & caches) -> bool {
+            for (const auto & cache : caches) {
+                if (dynamic_cast<llama_kv_cache_recurrent *>(cache.child)) {
+                    return true;
+                }
+            }
+            return false;
+        }(children)
+    )
+{
+    // Ensure at least one child
+    GGML_ASSERT(m_children.size() > 0);
+
+    // Ensure layers are not overlapping and are contiguous
+    std::set<size_t> seen_layers;
+    size_t max_layer = 0;
+    for (const auto & cache : children) {
+        for (const auto & layer_id : cache.layer_ids) {
+            GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
+            seen_layers.insert(layer_id);
+            if (layer_id > max_layer) {
+                max_layer = layer_id;
+            }
+        }
+    }
+    // layer IDs are 0-based, hence the +1
+    GGML_ASSERT(max_layer + 1 == seen_layers.size());
+}
+
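+// Hypothetical usage sketch (illustration only: the aggregate layout of
+// child_cache and all names below are assumptions, not part of this change):
+//
+//     llama_kv_cache_recurrent rs  (...); // e.g. SSM/recurrent layers
+//     llama_kv_cache_unified   attn(...); // e.g. attention layers
+//     std::vector<llama_kv_cache_hybrid::child_cache> children = {
+//         { &rs,   {0, 1, 2} },
+//         { &attn, {3, 4, 5} },
+//     };
+//     llama_kv_cache_hybrid kv(hparams, children);
+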
+void llama_kv_cache_hybrid::clear() {
+    for (const auto & cache : m_children) {
+        cache->clear();
+    }
+}
+
+bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // TODO: Will it cause problems if some caches are able to remove the seq
+    //       but others aren't?
+    bool removed = true;
+    for (const auto & cache : m_children) {
+        removed = cache->seq_rm(seq_id, p0, p1) && removed;
+    }
+    return removed;
+}
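+// NOTE: on a mixed result, children that succeeded have already dropped the
+// range and are not rolled back here; a false return only signals that at
+// least one child refused the removal.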
+
+void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    for (const auto & cache : m_children) {
+        cache->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) {
+    for (const auto & cache : m_children) {
+        cache->seq_keep(seq_id);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    for (const auto & cache : m_children) {
+        cache->seq_add(seq_id, p0, p1, delta);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    for (const auto & cache : m_children) {
+        cache->seq_div(seq_id, p0, p1, d);
+    }
+}
+
+llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const {
+    llama_pos max_pos = 0;
+    for (const auto & cache : m_children) {
+        max_pos = std::max(max_pos, cache->seq_pos_max(seq_id));
+    }
+    return max_pos;
+}
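+// NOTE: every child sees the same ubatches, so for a given sequence the
+// children should agree on the max position; taking std::max is a
+// conservative merge rather than a semantic requirement.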
+
+void llama_kv_cache_hybrid::restore() {
+    for (const auto & cache : m_children) {
+        cache->restore();
+    }
+}
+
+void llama_kv_cache_hybrid::commit() {
+    for (const auto & cache : m_children) {
+        cache->commit();
+    }
+}
+
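+// NOTE: `cache->update(ctx) || updated` below is ordered so that the call is
+// never short-circuited away: update() runs on every child even after one of
+// them reports a change. find_slot() uses the same pattern with `&&`.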
+bool llama_kv_cache_hybrid::update(llama_context & ctx) {
+    bool updated = false;
+    for (const auto & cache : m_children) {
+        updated = cache->update(ctx) || updated;
+    }
+    return updated;
+}
+
+void llama_kv_cache_hybrid::defrag_sched(float thold) {
+    for (const auto & cache : m_children) {
+        cache->defrag_sched(thold);
+    }
+}
+
+void llama_kv_cache_hybrid::set_full() {
+    for (const auto & cache : m_children) {
+        cache->set_full();
+    }
+}
+
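+// NOTE (hedged, based on how llama_sbatch splits are used elsewhere in
+// llama.cpp): split_simple() consumes tokens in batch order, split_seq()
+// emits one sequence at a time (pooled embeddings must stay within a single
+// ubatch), and split_equal() builds ubatches with an equal number of tokens
+// per sequence.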
+llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
+    // If any of the caches are recurrent, require simple split
+    return llama_sbatch(batch, m_hparams.n_embd, m_has_recurrent, logits_all);
+}
+
+llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
+    if (m_has_recurrent) {
+        return sbatch.split_simple(n_ubatch);
+    }
+    if (embd_pooled) {
+        // Pooled embeddings cannot be split across ubatches (yet)
+        return sbatch.split_seq(n_ubatch);
+    }
+    return sbatch.split_equal(n_ubatch);
+}
+
+bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) {
+    bool found = true;
+    for (const auto & cache : m_children) {
+        found = cache->find_slot(batch) && found;
+    }
+    return found;
+}
+
+int32_t llama_kv_cache_hybrid::get_n_tokens() const {
+    // The number of tokens should be the same across all child caches
+    int32_t n_tokens = -1;
+    for (const auto & cache : m_children) {
+        const auto cache_n_tokens = cache->get_n_tokens();
+        GGML_ASSERT(n_tokens == -1 || cache_n_tokens == n_tokens);
+        n_tokens = cache_n_tokens;
+    }
+    return n_tokens;
+}
+
+int32_t llama_kv_cache_hybrid::get_used_cells() const {
+    // TODO: Is this correct?
+    // Return the largest number of used cells
+    int32_t used_cells = -1;
+    for (const auto & cache : m_children) {
+        used_cells = std::max(used_cells, cache->get_used_cells());
+    }
+    return used_cells;
+}
+
+llama_pos llama_kv_cache_hybrid::get_pos_max() const {
+    llama_pos pos_max = -1;
+    for (const auto & cache : m_children) {
+        pos_max = std::max(pos_max, cache->get_pos_max());
+    }
+    return pos_max;
+}
+
+bool llama_kv_cache_hybrid::get_can_shift() const {
+    // TODO: Is this correct?
+    // If any children can shift, return true
+    for (const auto & cache : m_children) {
+        if (cache->get_can_shift()) {
+            return true;
+        }
+    }
+    return false;
+}
+
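+// NOTE: the ordering comments below assume m_children iterates from the
+// child owning the lowest layer upward; since m_children is a std::set of
+// pointers (address order), that assumption may not hold in practice (see
+// the note in the constructor).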
+void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    // Write each cache state in order. Note that order is guaranteed at
+    // initialization by using an ordered set sorted by lowest layer ID
+    for (const auto & cache : m_children) {
+        cache->state_write(io, seq_id);
+    }
+}
+
+void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    // Read each cache state in order. Note that order is guaranteed at
+    // initialization by using an ordered set sorted by lowest layer ID
+    for (const auto & cache : m_children) {
+        cache->state_read(io, seq_id);
+    }
+}
+
//
// kv cache view
//