Commit 165794d

feat: Initial implementation of llama_kv_cache_hybrid
Condensed from the initial version at https://github.com/gabe-l-hart/llama.cpp/tree/ec08571. The only difference is the removal of m_layer_cache_map, which was unused and unnecessary now that child caches are instantiated with their own filters.

Branch: HybridCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

1 parent d78ca19 commit 165794d
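
For orientation, here is a minimal sketch of how a caller might assemble a hybrid cache from per-type children once this commit lands. The make_attn_cache and make_recurrent_cache helpers and the layer assignments are illustrative assumptions; the real child-cache constructors take arguments (model, GGML types, sizes, layer filters) that are not part of this diff.

    #include <memory>
    #include <vector>
    // #include "llama-kv-cache.h"  // provides llama_kv_cache_hybrid (this commit)

    // Hypothetical factories for the two child caches; the actual
    // constructors live elsewhere in llama.cpp and are not shown here.
    std::unique_ptr<llama_kv_cache> make_attn_cache();
    std::unique_ptr<llama_kv_cache> make_recurrent_cache();

    std::unique_ptr<llama_kv_cache_hybrid> make_hybrid(const llama_hparams & hparams) {
        std::vector<llama_kv_cache_hybrid::child_cache> children;
        // Each layer is assigned to exactly one child: attention layers 0/2/4,
        // recurrent layers 1/3/5. Together they must cover 0..max with no gaps,
        // which the constructor asserts.
        children.emplace_back(make_attn_cache(),      std::vector<size_t>{0, 2, 4});
        children.emplace_back(make_recurrent_cache(), std::vector<size_t>{1, 3, 5});
        return std::make_unique<llama_kv_cache_hybrid>(hparams, std::move(children));
    }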

File tree

3 files changed, +403 -0 lines changed

src/llama-kv-cache.cpp

Lines changed: 237 additions & 0 deletions
@@ -2933,3 +2933,240 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     return true;
 }
+
+//
+// llama_kv_cache_hybrid
+//
+llama_kv_cache_hybrid::llama_kv_cache_hybrid(
+    const llama_hparams & hparams,
+    std::vector<child_cache> children) :
+    m_hparams(hparams),
+    m_children(
+        [](std::vector<child_cache>& caches) -> std::set<std::unique_ptr<llama_kv_cache>> {
+            // Sort the caches by the lowest layer ID so the order is repeatable
+            for (auto & cache : caches) {
+                GGML_ASSERT(cache.layer_ids.size() > 0);
+                std::sort(cache.layer_ids.begin(), cache.layer_ids.end());
+            }
+            std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
+                return a.layer_ids[0] < b.layer_ids[0];
+            });
+            std::set<std::unique_ptr<llama_kv_cache>> unique_caches;
+            for (auto & cache : caches) {
+                unique_caches.emplace(cache.child.release());
+            }
+            return unique_caches;
+        }(children)
+    ),
+    m_has_recurrent(
+        [](const std::set<std::unique_ptr<llama_kv_cache>> & caches) -> bool {
+            for (const auto & cache : caches) {
+                if (dynamic_cast<llama_kv_cache_recurrent *>(cache.get())) {
+                    return true;
+                }
+            }
+            return false;
+        }(m_children)
+    )
+{
+    // Ensure at least one child
+    GGML_ASSERT(m_children.size() > 0);
+
+    // Ensure layers are not overlapping and are contiguous
+    std::set<size_t> seen_layers;
+    size_t max_layer = 0;
+    for (const auto & cache : children) {
+        for (const auto & layer_id : cache.layer_ids) {
+            GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
+            seen_layers.insert(layer_id);
+            if (layer_id > max_layer) {
+                max_layer = layer_id;
+            }
+        }
+    }
+    LLAMA_LOG_DEBUG("max_layer=%zu, seen_layers.size()=%zu\n", max_layer, seen_layers.size());
+    GGML_ASSERT(max_layer + 1 == seen_layers.size());
+}
+
+void llama_kv_cache_hybrid::clear() {
+    for (const auto & cache : m_children) {
+        cache->clear();
+    }
+}
+
+bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // First check if we can do this removal. This checks all children so that
+    // no mutation happens before we know if it's possible
+    if (!can_seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+
+    // Do the removal from each child which should never fail
+    for (const auto & cache : m_children) {
+        const bool failed = !cache->seq_rm(seq_id, p0, p1);
+        GGML_ASSERT(!failed);
+    }
+    return true;
+}
+
+void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    for (const auto & cache : m_children) {
+        cache->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) {
+    for (const auto & cache : m_children) {
+        cache->seq_keep(seq_id);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    for (const auto & cache : m_children) {
+        cache->seq_add(seq_id, p0, p1, delta);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    for (const auto & cache : m_children) {
+        cache->seq_div(seq_id, p0, p1, d);
+    }
+}
+
+llama_pos llama_kv_cache_hybrid::seq_pos_min(llama_seq_id seq_id) const {
+    llama_pos min_pos = -1;
+    for (const auto & cache : m_children) {
+        const auto child_min_pos = cache->seq_pos_min(seq_id);
+        min_pos = min_pos == -1 ? child_min_pos : std::min(min_pos, child_min_pos);
+    }
+    return min_pos;
+}
+
+llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const {
+    llama_pos max_pos = 0;
+    for (const auto & cache : m_children) {
+        max_pos = std::max(max_pos, cache->seq_pos_max(seq_id));
+    }
+    return max_pos;
+}
+
+void llama_kv_cache_hybrid::restore() {
+    for (const auto & cache : m_children) {
+        cache->restore();
+    }
+}
+
+void llama_kv_cache_hybrid::commit() {
+    for (const auto & cache : m_children) {
+        cache->commit();
+    }
+}
+
+bool llama_kv_cache_hybrid::update(llama_context & ctx) {
+    bool updated = false;
+    for (const auto & cache : m_children) {
+        updated = cache->update(ctx) || updated;
+    }
+    return updated;
+}
+
+void llama_kv_cache_hybrid::defrag_sched(float thold) {
+    for (const auto & cache : m_children) {
+        cache->defrag_sched(thold);
+    }
+}
+
+void llama_kv_cache_hybrid::set_full() {
+    for (const auto & cache : m_children) {
+        cache->set_full();
+    }
+}
+
+bool llama_kv_cache_hybrid::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
+    for (const auto & cache : m_children) {
+        if (!cache->can_seq_rm(seq_id, p0, p1)) {
+            return false;
+        }
+    }
+    return true;
+}
+
+llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
+    // If any of the caches are recurrent, require equal split
+    return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all);
+}
+
+llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
+    if (embd_pooled) {
+        // Pooled embeddings cannot be split across ubatches (yet)
+        return sbatch.split_seq(n_ubatch);
+    }
+    if (m_has_recurrent) {
+        return sbatch.split_equal(n_ubatch);
+    }
+    return sbatch.split_simple(n_ubatch);
+}
+
+bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) {
+    bool found = true;
+    for (const auto & cache : m_children) {
+        found = cache->find_slot(batch) && found;
+    }
+    return found;
+}
+
+int32_t llama_kv_cache_hybrid::get_n_tokens() const {
+    // The number of tokens should be the same across all child caches
+    int32_t n_tokens = -1;
+    for (const auto & cache : m_children) {
+        const auto cache_n_tokens = cache->get_n_tokens();
+        GGML_ASSERT(n_tokens == -1 || cache_n_tokens == n_tokens);
+        n_tokens = cache_n_tokens;
+    }
+    return n_tokens;
+}
+
+int32_t llama_kv_cache_hybrid::get_used_cells() const {
+    // TODO: Is this correct?
+    // Return the largest number of used cells
+    int32_t used_cells = -1;
+    for (const auto & cache : m_children) {
+        used_cells = std::max(used_cells, cache->get_used_cells());
+    }
+    return used_cells;
+}
+
+llama_pos llama_kv_cache_hybrid::get_pos_max() const {
+    llama_pos pos_max = -1;
+    for (const auto & cache : m_children) {
+        pos_max = std::max(pos_max, cache->get_pos_max());
+    }
+    return pos_max;
+}
+
+bool llama_kv_cache_hybrid::get_can_shift() const {
+    // TODO: Is this correct?
+    // If any children can shift, return true
+    for (const auto & cache : m_children) {
+        if (cache->get_can_shift()) {
+            return true;
+        }
+    }
+    return false;
+}
+
+void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    // Write each cache state in order. Note that order is guaranteed at
+    // initialization by using an ordered set sorted by lowest layer ID
+    for (const auto & cache : m_children) {
+        cache->state_write(io, seq_id);
+    }
+}
+
+void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    // Read each cache state in order. Note that order is guaranteed at
+    // initialization by using an ordered set sorted by lowest layer ID
+    for (const auto & cache : m_children) {
+        cache->state_read(io, seq_id);
+    }
+}
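
A note on the constructor above: because m_children and m_has_recurrent are const members, they are initialized with immediately invoked lambdas in the initializer list. The following is a self-contained sketch of that C++ idiom, independent of llama.cpp:

    #include <cstdio>
    #include <set>
    #include <vector>

    struct holder {
        // A const member must be fully initialized in the initializer list;
        // an immediately invoked lambda lets arbitrary setup code run there.
        const std::set<int> m_sorted_unique;
        const bool          m_has_negative;

        explicit holder(std::vector<int> values) :
            m_sorted_unique(
                [](const std::vector<int> & v) -> std::set<int> {
                    return std::set<int>(v.begin(), v.end());
                }(values)
            ),
            m_has_negative(
                // Members are initialized in declaration order, so this lambda
                // may safely read m_sorted_unique -- the same guarantee that
                // lets m_has_recurrent be computed from m_children above.
                [](const std::set<int> & s) -> bool {
                    return !s.empty() && *s.begin() < 0;
                }(m_sorted_unique)
            )
        {}
    };

    int main() {
        holder h({3, -1, 3, 2});
        std::printf("size=%zu has_negative=%d\n", h.m_sorted_unique.size(), (int) h.m_has_negative);
        return 0;
    }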

src/llama-kv-cache.h

Lines changed: 98 additions & 0 deletions
@@ -549,3 +549,101 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
+
+//
+// llama_kv_cache_hybrid
+//
+
+// utilizes multiple different cache types with each layer assigned to exactly
+// one cache. This is typically used for hybrid attention / recurrent caching
+
+class llama_kv_cache_hybrid : public llama_kv_cache {
+public:
+
+    struct child_cache {
+        std::unique_ptr<llama_kv_cache> child;
+        std::vector<size_t>             layer_ids;
+
+        child_cache(std::unique_ptr<llama_kv_cache> child_, std::vector<size_t> layer_ids_)
+            : child(std::move(child_)), layer_ids(std::move(layer_ids_)) {}
+    };
+
+    llama_kv_cache_hybrid(
+        const llama_hparams & hparams,
+        std::vector<child_cache> children);
+
+    virtual ~llama_kv_cache_hybrid() = default;
+
+    // getters for specific child cache type
+    // NOTE: This will fail if there are multiple of the given type
+    template<typename child_t>
+    const child_t * get_child_cache() const {
+        const child_t * child = nullptr;
+        for (const auto & child_cache : m_children) {
+            const child_t * child_cast = dynamic_cast<const child_t *>(child_cache.get());
+            if (child_cast) {
+                GGML_ASSERT(!child);
+                child = child_cast;
+            }
+        }
+        return child;
+    }
+
+    //
+    // llama_memory_i
+    //
+
+    void clear() override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    //
+    // llama_kv_cache
+    //
+
+    void restore() override;
+    void commit()  override;
+
+    bool update(llama_context & ctx) override;
+
+    void defrag_sched(float thold) override;
+
+    void set_full() override;
+
+    bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
+
+    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
+
+    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
+
+    // updates the cache head
+    // Note: On success, it's important that cache.head points
+    // to the first cell of the slot.
+    bool find_slot(const llama_ubatch & batch) override;
+
+    int32_t get_n_tokens()   const override;
+    int32_t get_used_cells() const override;
+
+    // TODO: better data structures to reduce the cost of this operation
+    llama_pos get_pos_max() const override;
+
+    bool get_can_shift() const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;
+
+private:
+
+    const llama_hparams & m_hparams;
+    const std::set<std::unique_ptr<llama_kv_cache>> m_children; // Ordered for state IO
+    const bool m_has_recurrent;
+};
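
Because the children are held as base-class pointers, get_child_cache<T>() is the typed accessor. A hedged usage sketch, where kv is assumed to be a llama_kv_cache_hybrid built as in the earlier example:

    // Fetch the unique recurrent child, if any. nullptr means no child of the
    // requested type exists; the GGML_ASSERT inside the getter fires if there
    // happen to be two children of that type.
    const llama_kv_cache_recurrent * rec = kv.get_child_cache<llama_kv_cache_recurrent>();
    if (rec != nullptr) {
        // recurrent-specific handling, e.g. when building the compute graph
    }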
