
Commit efbd26c

context : hide kv cache details in implementation

ggml-ci

1 parent a0af785

File tree

5 files changed: +415 -363 lines changed
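
Note: the diffs below move the K-shift and defrag graph-building helpers out of llama_context. Instead of building those graphs itself, the context now hands the KV cache a scheduler, the backends and a few graph callbacks through a single kv_self->update(...) call. The following is a minimal, self-contained sketch of that callback-struct pattern; all names here (cache_update_params, kv_cache, context) are invented stand-ins, not the actual llama.cpp types.

// Illustration only: invented stand-in types that mirror the shape of the refactor,
// not the actual llama.cpp definitions.
#include <cstdio>
#include <functional>

struct graph {};                           // stand-in for ggml_cgraph

struct cache_update_params {               // stand-in for the struct passed to kv_self->update()
    int n_max_nodes;
    std::function<graph *()>     graph_init;    // start a new graph
    std::function<void(graph *)> graph_compute; // schedule and run it
};

struct kv_cache {
    bool has_shift = false;

    // the cache decides whether any maintenance work is needed and drives it
    // through the callbacks, without ever seeing the context type
    bool update(const cache_update_params & params) {
        bool need_reserve = false;

        if (has_shift) {
            graph * gf = params.graph_init();
            // ... the real code appends the K-shift (RoPE) ops to gf here ...
            params.graph_compute(gf);

            has_shift    = false;
            need_reserve = true;
        }

        return need_reserve;
    }
};

struct context {
    kv_cache kv;
    graph    g;

    void kv_self_update() {
        const bool need_reserve = kv.update({
            /*.n_max_nodes   =*/ 1024,
            /*.graph_init    =*/ [this]() { return &g; },
            /*.graph_compute =*/ [this](graph * gf) { (void) gf; /* run the scheduler */ },
        });

        if (need_reserve) {
            std::puts("reserve a worst-case graph");
        }
    }
};

int main() {
    context ctx;
    ctx.kv.has_shift = true; // pretend some sequences were shifted
    ctx.kv_self_update();
}

The real call site (at the end of the llama-context.cpp diff) passes more state, including the model arch, cparams, the ggml scheduler and the backends, but the shape of the interaction is the same.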

src/llama-context.cpp

Lines changed: 10 additions & 338 deletions
@@ -436,349 +436,21 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }

-ggml_tensor * llama_context::build_rope_shift(
-        ggml_context * ctx0,
-        ggml_tensor * cur,
-        ggml_tensor * shift,
-        ggml_tensor * factors,
-              float   freq_base,
-              float   freq_scale,
-        ggml_backend_buffer * bbuf) const {
-    const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
-
-    const auto & yarn_ext_factor = cparams.yarn_ext_factor;
-    const auto & yarn_beta_fast  = cparams.yarn_beta_fast;
-    const auto & yarn_beta_slow  = cparams.yarn_beta_slow;
-
-    const auto & hparams = model.hparams;
-
-    const auto & n_rot     = hparams.n_rot;
-    const auto & rope_type = hparams.rope_type;
-
-    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
-    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
-
-    ggml_tensor * tmp;
-
-    if (ggml_is_quantized(cur->type)) {
-        // dequantize to f32 -> RoPE -> quantize back
-        tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
-
-        if (bbuf) {
-            for (const auto & backend : backends) {
-                // Figure out which backend KV cache belongs to
-                if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) {
-                    ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
-                    break;
-                }
-            }
-        }
-
-        tmp = ggml_rope_ext_inplace(ctx0, tmp,
-                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
-
-        tmp = ggml_cpy(ctx0, tmp, cur);
-    } else {
-        // we rotate only the first n_rot dimensions
-        tmp = ggml_rope_ext_inplace(ctx0, cur,
-                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
-    }
-
-    return tmp;
-}
-
-class llm_graph_input_k_shift : public llm_graph_input_i {
-public:
-    llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
-    virtual ~llm_graph_input_k_shift() = default;
-
-    void set_input(const llama_ubatch * ubatch) override;
-
-    ggml_tensor * k_shift; // I32 [kv_size]
-
-    const llama_kv_cache_unified * kv_self;
-};
-
-void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-
-    if (k_shift) {
-        assert(ggml_backend_buffer_is_host(k_shift->buffer));
-
-        int32_t * data = (int32_t *) k_shift->data;
-
-        for (uint32_t i = 0; i < kv_self->size; ++i) {
-            data[i] = kv_self->cells[i].delta;
-        }
-    }
-}
-
-llm_graph_result_ptr llama_context::build_kv_self_shift(
-        ggml_context * ctx0,
-        ggml_cgraph * gf) const {
-    auto res = std::make_unique<llm_graph_result>();
-
-    const auto & hparams = model.hparams;
-
-    const auto & n_layer = hparams.n_layer;
-
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-  //const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    //GGML_ASSERT(kv_self->size == n_ctx);
-
-    const auto * kv = static_cast<const llama_kv_cache_unified *>(memory.get());
-
-    auto inp = std::make_unique<llm_graph_input_k_shift>(kv);
-
-    inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
-    ggml_set_input(inp->k_shift);
-
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const int64_t n_head_kv    = hparams.n_head_kv(il);
-        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-
-        const bool is_swa = hparams.is_swa(il);
-
-        // note: the swa rope params could become part of the cparams in the future
-        //       if we decide to make them configurable, like the non-sliding ones
-        const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
-        const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
-
-        ggml_tensor * rope_factors = kv->cbs.get_rope_factors(n_ctx_per_seq(), il);
-
-        ggml_tensor * k =
-            ggml_view_3d(ctx0, kv->k_l[il],
-                n_embd_head_k, n_head_kv, kv->size,
-                ggml_row_size(kv->k_l[il]->type, n_embd_head_k),
-                ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
-                0);

-        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv->k_l[il]->buffer);
-
-        ggml_build_forward_expand(gf, cur);
-    }
-
-    res->add_input(std::move(inp));
-
-    return res;
-}
-
-llm_graph_result_ptr llama_context::build_kv_self_defrag(
-        ggml_context * ctx0,
-        ggml_cgraph * gf) const {
-    auto res = std::make_unique<llm_graph_result>();
-
-    auto * kv = static_cast<llama_kv_cache_unified *>(memory.get());
-
-    const auto & hparams = model.hparams;
-
-    const auto & ids = kv->defrag_info.ids;
-
-#if 0
-    // CPU defrag
-    //
-    // TODO: optimizations are possible:
-    //       - multiple threads
-    //       - avoid copying to the host memory when already there
-    //
-    // likely not worth the effort, as we have ggml_graph based defrag
-    //
-
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-
-    const uint32_t kv_size = size;
-
-    std::vector<uint8_t> buf_k;
-    std::vector<uint8_t> buf_v;
-
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
-        const size_t k_size     = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
-
-        const size_t v_size_el = ggml_type_size(v_l[il]->type);
-        const size_t v_size    = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
-
-        buf_k.resize(k_size);
-        buf_v.resize(v_size);
-
-        ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
-
-        // batch move [i, i+nm) to [id, id+nm)
-        // note: cells can move only to a lower index
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            const uint32_t id = ids[i];
-
-            if (i == id || id == n_kv) {
-                continue;
-            }
-
-            uint32_t nm = 1;
-
-            while (i + nm < n_kv && ids[i + nm] == id + nm) {
-                nm++;
-            }
-
-            // move keys
-            {
-                const int64_t os =  i*k_size_row;
-                const int64_t od = id*k_size_row;
-
-                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
-            }
-
-            // move values (note: they are transposed)
-            {
-                const int64_t os =  i;
-                const int64_t od = id;
-
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
-                }
-            }
-
-            i += nm - 1;
-        }
-
-        ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
-    }
-#else
-    for (uint32_t i = 0; i < ids.size(); ++i) {
-        const uint32_t id = ids[i];
-
-        if (i == id || id == ids.size()) {
-            continue;
-        }
-
-        uint32_t nm = 1;
-
-        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
-            nm++;
-        }
-
-        for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
-            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-            ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv->k_l[il],
-                    n_embd_k_gqa, nm,
-                    ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*i));
-
-            ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv->k_l[il],
-                    n_embd_k_gqa, nm,
-                    ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*id));
-
-            ggml_tensor * view_v_src;
-            ggml_tensor * view_v_dst;
-
-            if (cparams.flash_attn) {
-                // NOTE: the V cache is not transposed when using flash attention
-                view_v_src = ggml_view_2d(ctx0, kv->v_l[il],
-                        n_embd_v_gqa, nm,
-                        ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*i));
-
-                view_v_dst = ggml_view_2d(ctx0, kv->v_l[il],
-                        n_embd_v_gqa, nm,
-                        ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*id));
-            } else {
-                view_v_src = ggml_view_2d(ctx0, kv->v_l[il],
-                        nm, n_embd_v_gqa,
-                        ggml_row_size(kv->v_l[il]->type, kv->size),
-                        ggml_row_size(kv->v_l[il]->type, i));
-
-                view_v_dst = ggml_view_2d(ctx0, kv->v_l[il],
-                        nm, n_embd_v_gqa,
-                        ggml_row_size(kv->v_l[il]->type, kv->size),
-                        ggml_row_size(kv->v_l[il]->type, id));
-            }
-
-            ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
-            ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
-        }
-
-        i += nm - 1;
-    }
-
-    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
-#endif
-
-    return res;
-}
-
 void llama_context::kv_self_update() {
     bool need_reserve = false;

     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

-    if (kv_self->get_has_shift()) {
-        if (!kv_self->get_can_shift()) {
-            GGML_ABORT("The current KV cache / model configuration does not support K-shift");
-        }
-
-        LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
-
-        // apply K-shift if needed
-        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
-            ggml_backend_sched_reset(sched.get());
-
-            auto * gf = graph_init();
-
-            auto res = build_kv_self_shift(ctx_compute.get(), gf);
-
-            ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-            res->set_inputs(nullptr);
-
-            graph_compute(gf, false);
-
-            need_reserve = true;
-        }
-
-        {
-            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);
-
-            kv->has_shift = false;
-
-            for (uint32_t i = 0; i < kv->size; ++i) {
-                kv->cells[i].delta = 0;
-            }
-        }
-    }
-
-    // defragment the KV cache if needed
-    if (kv_self->get_do_defrag()) {
-        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
-
-        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);
-
-        if (kv->defrag_prepare(graph_max_nodes())) {
-            ggml_backend_sched_reset(sched.get());
-
-            auto * gf = graph_init();
-
-            auto res = build_kv_self_defrag(ctx_compute.get(), gf);
-
-            ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-            res->set_inputs(nullptr);
-
-            graph_compute(gf, false);
-
-            need_reserve = true;
-        }
-
-        kv->do_defrag = false;
-    }
+    need_reserve = kv_self->update({
+        /*.arch            =*/ model.arch,
+        /*.cparams         =*/ cparams,
+        /*.sched           =*/ sched.get(),
+        /*.backends        =*/ backends,
+        /*.n_max_nodes     =*/ graph_max_nodes(),
+        /*.get_ctx_compute =*/ [this]() { return ctx_compute.get(); },
+        /*.graph_init      =*/ [this]() { return graph_init(); },
+        /*.graph_compute   =*/ [this](ggml_cgraph * gf) { graph_compute(gf, false); },
+    });

     // reserve a worst case graph if needed
     if (need_reserve) {
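
The definition of the parameter struct that this new kv_self->update(...) call fills in is not part of this diff; it lives with the KV cache code. Purely as a hedged guess at its shape: the struct name below is invented, the member types are inferred from the call site, and only the field names are taken from it.

// Guesswork, not the actual llama.cpp definition: invented name, inferred member
// types; compiling it would also require the internal llama.cpp / ggml headers
// that declare llm_arch, llama_cparams, ggml_backend_sched_t, ggml_backend_ptr, etc.
#include <cstdint>
#include <functional>
#include <vector>

struct kv_cache_update_params_guess {
    llm_arch                              arch;            // model.arch
    const llama_cparams &                 cparams;         // rope/yarn settings, n_ctx, flash_attn, ...
    ggml_backend_sched_t                  sched;           // sched.get()
    const std::vector<ggml_backend_ptr> & backends;        // per-device backends
    int32_t                               n_max_nodes;     // graph_max_nodes()
    std::function<ggml_context * ()>      get_ctx_compute; // compute context used to build graphs
    std::function<ggml_cgraph  * ()>      graph_init;      // start a new graph
    std::function<void(ggml_cgraph *)>    graph_compute;   // schedule and run a graph
};

Whatever the exact definition, the effect visible in this diff is that llama_context no longer touches llama_kv_cache_unified internals such as cells, has_shift, do_defrag or defrag_info.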

src/llama-context.h

Lines changed: 0 additions & 18 deletions
@@ -159,24 +159,6 @@ struct llama_context {
 
     llm_graph_cb graph_get_cb() const;
 
-    // used by kv_self_update()
-    ggml_tensor * build_rope_shift(
-            ggml_context * ctx0,
-            ggml_tensor * cur,
-            ggml_tensor * shift,
-            ggml_tensor * factors,
-                  float   freq_base,
-                  float   freq_scale,
-            ggml_backend_buffer * bbuf) const;
-
-    llm_graph_result_ptr build_kv_self_shift(
-            ggml_context * ctx0,
-            ggml_cgraph * gf) const;
-
-    llm_graph_result_ptr build_kv_self_defrag(
-            ggml_context * ctx0,
-            ggml_cgraph * gf) const;
-
     // TODO: read/write lora adapters and cvec
     size_t state_write_data(llama_io_write_i & io);
     size_t state_read_data (llama_io_read_i & io);
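
The header change only drops the declarations of the three helpers removed above. For reference, the core of the move-batching loop they implemented (coalescing consecutive cells that map to consecutive destinations into a single move) can be shown in isolation; this is a simplified, runnable illustration with a made-up ids vector, not the real code path.

// Simplified illustration of the move-batching loop from the removed
// build_kv_self_defrag(): ids[i] is the destination cell for cell i, and
// ids[i] == i or ids[i] == ids.size() means "leave the cell alone".
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical defrag plan for an 8-cell cache
    const std::vector<uint32_t> ids = {0, 1, 8, 8, 2, 3, 4, 8};

    for (uint32_t i = 0; i < ids.size(); ++i) {
        const uint32_t id = ids[i];

        if (i == id || id == ids.size()) {
            continue;
        }

        // extend the run while the following cells map to the following destinations
        uint32_t nm = 1;
        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
            nm++;
        }

        // the real code turns each batch into one ggml_cpy per layer for the K and V views
        std::printf("move cells [%u, %u) -> [%u, %u)\n",
                (unsigned) i, (unsigned) (i + nm), (unsigned) id, (unsigned) (id + nm));

        i += nm - 1;
    }

    return 0;
}

With the sample ids above, cells 4..6 are coalesced into a single 3-cell move to slots 2..4, which is exactly the batching the removed code exploited to emit fewer copy nodes into the graph.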
