Commit 18534e5

Merge branch 'HybridCache' into GraniteFour
* HybridCache:
  fix: Give ownership of child caches to the hybrid cache
  fix: Mark unused params correctly
  fix: Split up seq_rm interface into immutable can_seq_rm and mutating seq_rm
  fix: Fix confusion on simple vs equal splitting
  feat: First pass at llama_kv_cache_hybrid
2 parents 9c0264d + 01c3555 commit 18534e5
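
The central interface change is the split of sequence removal into a const can_seq_rm check and the mutating seq_rm. A caller that owns several caches can therefore validate the whole removal before touching any state. The sketch below is a hypothetical caller-side helper (not code from this commit); it mirrors what llama_kv_cache_hybrid::seq_rm in the diff does across its children.

// Hypothetical helper: validate against every cache first (const, no mutation),
// then apply, relying on the contract that seq_rm cannot fail once every
// can_seq_rm has returned true.
static bool remove_range(const std::vector<llama_kv_cache *> & caches,
                         llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    for (const llama_kv_cache * kv : caches) {
        if (!kv->can_seq_rm(seq_id, p0, p1)) {
            return false; // rejected up front; nothing has been modified
        }
    }
    for (llama_kv_cache * kv : caches) {
        const bool ok = kv->seq_rm(seq_id, p0, p1);
        GGML_ASSERT(ok); // guaranteed by the checks above
    }
    return true;
}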

File tree

2 files changed: +373 -15 lines changed

src/llama-kv-cache.cpp

Lines changed: 285 additions & 15 deletions
@@ -443,6 +443,14 @@ void llama_kv_cache_unified::set_full() {
     n = size;
 }
 
+bool llama_kv_cache_unified::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
+    GGML_UNUSED(seq_id);
+    GGML_UNUSED(p0);
+    GGML_UNUSED(p1);
+    // Unified attention cache can always do a sequence removal
+    return true;
+}
+
 llama_sbatch llama_kv_cache_unified::sbatch_init(
         const llama_batch & batch,
         bool logits_all) {
@@ -1479,39 +1487,33 @@ void llama_kv_cache_recurrent::clear() {
 }
 
 bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    uint32_t new_head = size;
+    if (!can_seq_rm(seq_id, p0, p1)) {
+        // could be fatal
+        return false;
+    }
 
+    uint32_t new_head = size;
     if (p0 < 0) {
         p0 = 0;
     }
-
     if (p1 < 0) {
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    // models like Mamba or RWKV can't have a state partially erased
-    if (seq_id >= (int64_t) size) {
-        // could be fatal
-        return false;
-    }
     if (0 <= seq_id) {
         int32_t & tail_id = cells[seq_id].tail;
         if (tail_id >= 0) {
             const kv_cell & cell = cells[tail_id];
-            // partial intersection is invalid
-            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
-                return false;
-            }
+            // already validated in can_seq_rm
+            GGML_ASSERT(!((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)));
             // invalidate tails which will be cleared
             if (p0 <= cell.pos && cell.pos < p1) {
                 tail_id = -1;
             }
         }
     } else {
-        // seq_id is negative, then the range should include everything or nothing
-        if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
-            return false;
-        }
+        // already validated in can_seq_rm
+        GGML_ASSERT(!(p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())));
     }
 
     for (uint32_t i = 0; i < size; ++i) {
@@ -1712,6 +1714,35 @@ void llama_kv_cache_recurrent::set_full() {
     n = size;
 }
 
+bool llama_kv_cache_recurrent::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+    // models like Mamba or RWKV can't have a state partially erased
+    if (seq_id >= (int64_t) size) {
+        // could be fatal
+        return false;
+    }
+    if (0 <= seq_id) {
+        const int32_t & tail_id = cells[seq_id].tail;
+        if (tail_id >= 0) {
+            const kv_cell & cell = cells[tail_id];
+            // partial intersection is invalid
+            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+                return false;
+            }
+        }
+    // seq_id is negative, then the range should include everything or nothing
+    } else if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+        return false;
+    }
+    return true;
+}
+
 llama_sbatch llama_kv_cache_recurrent::sbatch_init(
         const llama_batch & batch,
         bool logits_all) {
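
To make the partial-intersection rule above concrete, here is a small worked fragment with hypothetical values; the predicate is the same one used in can_seq_rm.

// Hypothetical values: the sequence's tail cell sits at pos = 9. A removal
// [p0, p1) is rejected when a bound falls inside (0, pos], i.e. it would cut
// the rolled-up recurrent state in the middle instead of dropping it entirely.
const llama_pos pos = 9;
auto rejected = [&](llama_pos p0, llama_pos p1) {
    return (0 < p0 && p0 <= pos) || (0 < p1 && p1 <= pos);
};
GGML_ASSERT(!rejected(0, std::numeric_limits<llama_pos>::max())); // whole sequence: allowed
GGML_ASSERT(!rejected(0, 10)); // covers the tail completely: allowed
GGML_ASSERT( rejected(4,  8)); // would erase only part of the state: rejected
GGML_ASSERT(!rejected(10, 20)); // entirely past the tail, erases nothing: allowed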
@@ -2355,6 +2386,245 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     return true;
 }
 
+//
+// llama_kv_cache_hybrid
+//
+llama_kv_cache_hybrid::llama_kv_cache_hybrid(
+    const llama_hparams & hparams,
+    std::vector<child_cache> children) :
+    m_hparams(hparams),
+    m_layer_cache_map(
+        [](const std::vector<child_cache>& caches) -> std::unordered_map<size_t, llama_kv_cache*> {
+            std::unordered_map<size_t, llama_kv_cache*> map;
+            for (const auto & cache : caches) {
+                for (size_t layer_id : cache.layer_ids) {
+                    map[layer_id] = cache.child.get();
+                }
+            }
+
+            return map;
+        }(children)
+    ),
+    m_children(
+        [](std::vector<child_cache>& caches) -> std::set<std::unique_ptr<llama_kv_cache>> {
+            // Sort the caches by the lowest layer ID so the order is repeatable
+            for (auto & cache : caches) {
+                GGML_ASSERT(cache.layer_ids.size() > 0);
+                std::sort(cache.layer_ids.begin(), cache.layer_ids.end());
+            }
+            std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
+                return a.layer_ids[0] < b.layer_ids[0];
+            });
+            std::set<std::unique_ptr<llama_kv_cache>> unique_caches;
+            for (auto & cache : caches) {
+                unique_caches.emplace(cache.child.release());
+            }
+            return unique_caches;
+        }(children)
+    ),
+    m_has_recurrent(
+        [](const std::set<std::unique_ptr<llama_kv_cache>> & caches) -> bool {
+            for (const auto & cache : caches) {
+                if (dynamic_cast<llama_kv_cache_recurrent *>(cache.get())) {
+                    return true;
+                }
+            }
+            return false;
+        }(m_children)
+    )
+{
+    // Ensure at least one child
+    GGML_ASSERT(m_children.size() > 0);
+
+    // Ensure layers are not overlapping and are contiguous
+    std::set<size_t> seen_layers;
+    size_t max_layer = 0;
+    for (const auto & cache : children) {
+        for (const auto & layer_id : cache.layer_ids) {
+            GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
+            seen_layers.insert(layer_id);
+            if (layer_id > max_layer) {
+                max_layer = layer_id;
+            }
+        }
+    }
+    GGML_ASSERT(max_layer == seen_layers.size());
+}
+
+void llama_kv_cache_hybrid::clear() {
+    for (const auto & cache : m_children) {
+        cache->clear();
+    }
+}
+
+bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // First check if we can do this removal. This checks all children so that
+    // no mutation happens before we know if it's possible
+    if (!can_seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+
+    // Do the removal from each child which should never fail
+    for (const auto & cache : m_children) {
+        const bool ok = cache->seq_rm(seq_id, p0, p1);
+        GGML_ASSERT(ok);
+    }
+    return true;
+}
+
+void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    for (const auto & cache : m_children) {
+        cache->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) {
+    for (const auto & cache : m_children) {
+        cache->seq_keep(seq_id);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    for (const auto & cache : m_children) {
+        cache->seq_add(seq_id, p0, p1, delta);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    for (const auto & cache : m_children) {
+        cache->seq_div(seq_id, p0, p1, d);
+    }
+}
+
+llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const {
+    llama_pos max_pos = 0;
+    for (const auto & cache : m_children) {
+        max_pos = std::max(max_pos, cache->seq_pos_max(seq_id));
+    }
+    return max_pos;
+}
+
+void llama_kv_cache_hybrid::restore() {
+    for (const auto & cache : m_children) {
+        cache->restore();
+    }
+}
+
+void llama_kv_cache_hybrid::commit() {
+    for (const auto & cache : m_children) {
+        cache->commit();
+    }
+}
+
+bool llama_kv_cache_hybrid::update(llama_context & ctx) {
+    bool updated = false;
+    for (const auto & cache : m_children) {
+        updated = cache->update(ctx) || updated;
+    }
+    return updated;
+}
+
+void llama_kv_cache_hybrid::defrag_sched(float thold) {
+    for (const auto & cache : m_children) {
+        cache->defrag_sched(thold);
+    }
+}
+
+void llama_kv_cache_hybrid::set_full() {
+    for (const auto & cache : m_children) {
+        cache->set_full();
+    }
+}
+
+bool llama_kv_cache_hybrid::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
+    for (const auto & cache : m_children) {
+        if (!cache->can_seq_rm(seq_id, p0, p1)) {
+            return false;
+        }
+    }
+    return true;
+}
+
+llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
+    // If any of the caches are recurrent, require equal split
+    return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all);
+}
+
+llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
+    if (embd_pooled) {
+        // Pooled embeddings cannot be split across ubatches (yet)
+        return sbatch.split_seq(n_ubatch);
+    }
+    if (m_has_recurrent) {
+        return sbatch.split_equal(n_ubatch);
+    }
+    return sbatch.split_simple(n_ubatch);
+}
+
+bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) {
+    bool found = true;
+    for (const auto & cache : m_children) {
+        found = cache->find_slot(batch) && found;
+    }
+    return found;
+}
+
+int32_t llama_kv_cache_hybrid::get_n_tokens() const {
+    // The number of tokens should be the same across all child caches
+    int32_t n_tokens = -1;
+    for (const auto & cache : m_children) {
+        const auto cache_n_tokens = cache->get_n_tokens();
+        GGML_ASSERT(n_tokens == -1 || cache_n_tokens == n_tokens);
+        n_tokens = cache_n_tokens;
+    }
+    return n_tokens;
+}
+
+int32_t llama_kv_cache_hybrid::get_used_cells() const {
+    // TODO: Is this correct?
+    // Return the largest number of used cells
+    int32_t used_cells = -1;
+    for (const auto & cache : m_children) {
+        used_cells = std::max(used_cells, cache->get_used_cells());
+    }
+    return used_cells;
+}
+
+llama_pos llama_kv_cache_hybrid::get_pos_max() const {
+    llama_pos pos_max = -1;
+    for (const auto & cache : m_children) {
+        pos_max = std::max(pos_max, cache->get_pos_max());
+    }
+    return pos_max;
+}
+
+bool llama_kv_cache_hybrid::get_can_shift() const {
+    // TODO: Is this correct?
+    // If any children can shift, return true
+    for (const auto & cache : m_children) {
+        if (cache->get_can_shift()) {
+            return true;
+        }
+    }
+    return false;
+}
+
+void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    // Write each cache state in order. Note that order is guaranteed at
+    // initialization by using an ordered set sorted by lowest layer ID
+    for (const auto & cache : m_children) {
+        cache->state_write(io, seq_id);
+    }
+}
+
+void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    // Read each cache state in order. Note that order is guaranteed at
+    // initialization by using an ordered set sorted by lowest layer ID
+    for (const auto & cache : m_children) {
+        cache->state_read(io, seq_id);
+    }
+}
+
 //
 // kv cache view
 //
