Skip to content

Commit b8567ce

Browse files
committed
feat: First pass at llama_kv_cache_hybrid
This implementation covers both `llama_memory_i` and `llama_kv_cache` interfaces, but they could very well not be correct. Branch: HybridCache Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 5364ae4 commit b8567ce

File tree

2 files changed

+299
-0
lines changed

2 files changed

+299
-0
lines changed

src/llama-kv-cache.cpp

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2392,6 +2392,231 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
23922392
return true;
23932393
}
23942394

2395+
//
2396+
// llama_kv_cache_hybrid
2397+
//
2398+
llama_kv_cache_hybrid::llama_kv_cache_hybrid(
2399+
const llama_hparams & hparams,
2400+
const std::vector<child_cache> & children) :
2401+
m_hparams(hparams),
2402+
m_layer_cache_map(
2403+
[](const std::vector<child_cache>& caches) -> std::unordered_map<size_t, llama_kv_cache*> {
2404+
std::unordered_map<size_t, llama_kv_cache*> map;
2405+
for (const auto & cache : caches) {
2406+
for (size_t layer_id : cache.layer_ids) {
2407+
map[layer_id] = cache.child;
2408+
}
2409+
}
2410+
2411+
return map;
2412+
}(children)
2413+
),
2414+
m_children(
2415+
[](std::vector<child_cache> caches) -> std::set<llama_kv_cache*> {
2416+
// Sort the caches by the lowest layer ID so the order is repeatable
2417+
for (auto & cache : caches) {
2418+
GGML_ASSERT(cache.layer_ids.size() > 0);
2419+
std::sort(cache.layer_ids.begin(), cache.layer_ids.end());
2420+
}
2421+
std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
2422+
return a.layer_ids[0] < b.layer_ids[0];
2423+
});
2424+
std::set<llama_kv_cache*> unique_caches;
2425+
for (const auto & cache : caches) {
2426+
unique_caches.insert(cache.child);
2427+
}
2428+
return unique_caches;
2429+
}(children)
2430+
),
2431+
m_has_recurrent(
2432+
[](const std::vector<child_cache>& caches) -> bool {
2433+
for (const auto & cache : caches) {
2434+
if (dynamic_cast<llama_kv_cache_recurrent *>(cache.child)) {
2435+
return true;
2436+
}
2437+
}
2438+
return false;
2439+
}(children)
2440+
)
2441+
{
2442+
// Ensure at least one child
2443+
GGML_ASSERT(m_children.size() > 0);
2444+
2445+
// Ensure layers are not overlapping and are concurrent
2446+
std::set<size_t> seen_layers;
2447+
size_t max_layer = 0;
2448+
for (const auto & cache : children) {
2449+
for (const auto & layer_id : cache.layer_ids) {
2450+
GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
2451+
seen_layers.insert(layer_id);
2452+
if (layer_id > max_layer) {
2453+
max_layer = layer_id;
2454+
}
2455+
}
2456+
}
2457+
GGML_ASSERT(max_layer == seen_layers.size());
2458+
}
2459+
2460+
void llama_kv_cache_hybrid::clear() {
2461+
for (const auto & cache : m_children) {
2462+
cache->clear();
2463+
}
2464+
}
2465+
2466+
// Remove a sequence range from every child. Returns true only if all
// children removed it; every child is attempted even after a failure.
// TODO: Will it cause problems if some caches are able to remove the seq
//       but others aren't?
bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    bool all_removed = true;
    for (llama_kv_cache * child : m_children) {
        if (!child->seq_rm(seq_id, p0, p1)) {
            all_removed = false;
        }
    }
    return all_removed;
}
2475+
2476+
// Copy sequence seq_id_src into seq_id_dst over the given position range in
// every child.
void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    for (llama_kv_cache * child : m_children) {
        child->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    }
}
2481+
2482+
// Keep only the given sequence in every child.
void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) {
    for (llama_kv_cache * child : m_children) {
        child->seq_keep(seq_id);
    }
}
2487+
2488+
// Shift the positions of a sequence range by delta in every child.
void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
    for (llama_kv_cache * child : m_children) {
        child->seq_add(seq_id, p0, p1, delta);
    }
}
2493+
2494+
// Divide the positions of a sequence range by d in every child.
void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    for (llama_kv_cache * child : m_children) {
        child->seq_div(seq_id, p0, p1, d);
    }
}
2499+
2500+
// Maximum position of the sequence across all children (0 if no child
// reports a larger value, matching the original accumulator seed).
llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const {
    llama_pos result = 0;
    for (llama_kv_cache * child : m_children) {
        const llama_pos child_max = child->seq_pos_max(seq_id);
        if (child_max > result) {
            result = child_max;
        }
    }
    return result;
}
2507+
2508+
void llama_kv_cache_hybrid::restore() {
2509+
for (const auto & cache : m_children) {
2510+
cache->restore();
2511+
}
2512+
}
2513+
2514+
void llama_kv_cache_hybrid::commit() {
2515+
for (const auto & cache : m_children) {
2516+
cache->commit();
2517+
}
2518+
}
2519+
2520+
// Update every child unconditionally (no short-circuiting) and report
// whether any of them made changes.
bool llama_kv_cache_hybrid::update(llama_context & ctx) {
    bool any_updated = false;
    for (llama_kv_cache * child : m_children) {
        if (child->update(ctx)) {
            any_updated = true;
        }
    }
    return any_updated;
}
2527+
2528+
void llama_kv_cache_hybrid::defrag_sched(float thold) {
2529+
for (const auto & cache : m_children) {
2530+
cache->defrag_sched(thold);
2531+
}
2532+
}
2533+
2534+
void llama_kv_cache_hybrid::set_full() {
2535+
for (const auto & cache : m_children) {
2536+
cache->set_full();
2537+
}
2538+
}
2539+
2540+
// Initialize a sequence batch for this cache.
llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
    // A recurrent child forces the simple split for the whole hybrid cache
    const bool simple_split = m_has_recurrent;
    return llama_sbatch(batch, m_hparams.n_embd, simple_split, logits_all);
}
2544+
2545+
// Pull the next micro-batch from the sequence batch, choosing the split
// strategy based on the children and pooling mode.
llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
    // With a recurrent child in the mix, the whole batch uses the simple split.
    // NOTE(review): confirm this is consistent with
    // llama_kv_cache_recurrent::ubatch_next, which selects seq/equal splits
    // for its own batches — the precedence here may be inverted.
    if (m_has_recurrent) {
        return sbatch.split_simple(n_ubatch);
    }
    // Pooled embeddings cannot be split across ubatches (yet), so they use
    // the per-sequence split; everything else uses the equal split.
    return embd_pooled ? sbatch.split_seq(n_ubatch) : sbatch.split_equal(n_ubatch);
}
2555+
2556+
bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) {
2557+
bool found = true;
2558+
for (const auto & cache : m_children) {
2559+
found = cache->find_slot(batch) && found;
2560+
}
2561+
return found;
2562+
}
2563+
2564+
int32_t llama_kv_cache_hybrid::get_n_tokens() const {
2565+
// The number of tokens should be the same across all child caches
2566+
int32_t n_tokens = -1;
2567+
for (const auto & cache : m_children) {
2568+
const auto cache_n_tokens = cache->get_n_tokens();
2569+
GGML_ASSERT(n_tokens == -1 || cache_n_tokens == n_tokens);
2570+
n_tokens = cache_n_tokens;
2571+
}
2572+
return n_tokens;
2573+
}
2574+
2575+
int32_t llama_kv_cache_hybrid::get_used_cells() const {
2576+
// TODO: Is this correct?
2577+
// Return the largetst number of used cells
2578+
int32_t used_cells = -1;
2579+
for (const auto & cache : m_children) {
2580+
used_cells = std::max(used_cells, cache->get_used_cells());
2581+
}
2582+
return used_cells;
2583+
}
2584+
2585+
// Largest position stored in any child (-1 when all children are empty,
// matching the original accumulator seed).
llama_pos llama_kv_cache_hybrid::get_pos_max() const {
    llama_pos result = -1;
    for (llama_kv_cache * child : m_children) {
        const llama_pos child_pos_max = child->get_pos_max();
        if (child_pos_max > result) {
            result = child_pos_max;
        }
    }
    return result;
}
2592+
2593+
bool llama_kv_cache_hybrid::get_can_shift() const {
2594+
// TODO: Is this correct?
2595+
// If any children can shift, return true
2596+
for (const auto & cache : m_children) {
2597+
if (cache->get_can_shift()) {
2598+
return true;
2599+
}
2600+
}
2601+
return false;
2602+
}
2603+
2604+
void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
2605+
// Write each cache state in order. Note that order is guaranteed at
2606+
// initialization by using an ordered set sorted by lowest layer ID
2607+
for (const auto & cache : m_children) {
2608+
cache->state_write(io, seq_id);
2609+
}
2610+
}
2611+
2612+
void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
2613+
// Read each cache state in order. Note that order is guaranteed at
2614+
// initialization by using an ordered set sorted by lowest layer ID
2615+
for (const auto & cache : m_children) {
2616+
cache->state_read(io, seq_id);
2617+
}
2618+
}
2619+
23952620
//
23962621
// kv cache view
23972622
//

src/llama-kv-cache.h

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <set>
1111
#include <vector>
12+
#include <unordered_map>
1213

1314
struct llama_cparams;
1415
struct llama_hparams;
@@ -389,6 +390,79 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
389390
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
390391
};
391392

393+
//
394+
// llama_kv_cache_hybrid
395+
//
396+
397+
// A composite KV cache that splits the model's layers across multiple child
// caches (e.g. attention layers in one cache, recurrent layers in another).
// Each llama_memory_i / llama_kv_cache operation fans out to every child;
// the per-operation aggregation rules live in llama-kv-cache.cpp.
class llama_kv_cache_hybrid : public llama_kv_cache {
public:

    // Binds one child cache to the model layers it is responsible for.
    struct child_cache {
        llama_kv_cache * child;         // raw pointer, not owned here; must outlive the hybrid cache
        std::vector<size_t> layer_ids;  // indices of the layers served by `child`
    };

    // The same child may appear in several entries, but each layer must be
    // claimed by exactly one child (asserted in the constructor). `hparams`
    // is held by reference and must outlive this object.
    llama_kv_cache_hybrid(
        const llama_hparams & hparams,
        const std::vector<child_cache> & children);

    //
    // llama_memory_i
    //
    // All sequence operations are forwarded to every child cache.

    void clear() override;

    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    void restore() override;
    void commit() override;

    bool update(llama_context & ctx) override;

    void defrag_sched(float thold) override;

    void set_full() override;

    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;

    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;

    // updates the cache head
    // Note: On success, it's important that cache.head points
    // to the first cell of the slot.
    bool find_slot(const llama_ubatch & batch) override;

    int32_t get_n_tokens() const override;
    int32_t get_used_cells() const override;

    // TODO: better data structures to reduce the cost of this operation
    llama_pos get_pos_max() const override;

    bool get_can_shift() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;

private:

    // borrowed reference; the owner must outlive this cache
    const llama_hparams & m_hparams;
    // layer index -> the child cache serving that layer
    const std::unordered_map<size_t, llama_kv_cache *> m_layer_cache_map;
    // the unique child caches
    // NOTE(review): a std::set of raw pointers iterates in pointer order,
    // not by layer ID, so this is NOT a repeatable ordering for state IO
    // across processes — confirm intent before relying on it
    const std::set<llama_kv_cache *> m_children;
    // true if any child is a llama_kv_cache_recurrent (affects batch splits)
    const bool m_has_recurrent;
};
465+
392466

393467
//
394468
// kv cache view

0 commit comments

Comments
 (0)