
Commit 43cbf38

kv-cache : replace struct callbacks with llama_model &
ggml-ci
1 parent 92e626b commit 43cbf38
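The change is mechanical: instead of handing the KV cache a llama_kv_cache::callbacks struct of std::function hooks, the cache now stores a const llama_model & (plus an explicit offload flag) and queries the model directly. A minimal standalone sketch of the pattern, using hypothetical model/cache_* types rather than the real llama.cpp classes; the diffs below show the same change applied to llama_kv_cache_unified and llama_kv_cache_recurrent:

#include <cstdio>
#include <functional>
#include <utility>

// hypothetical stand-ins for llama_model / llama_kv_cache_unified,
// illustrating the refactor only -- not the real llama.cpp types
struct model {
    int rope_factors(int il) const { return 10 + il; } // data the cache needs from the model
};

// before: the cache stored a struct of std::function callbacks
struct cache_with_callbacks {
    struct callbacks {
        std::function<int (int il)> get_rope_factors;
    };

    callbacks cbs;

    explicit cache_with_callbacks(callbacks cbs) : cbs(std::move(cbs)) {}

    int query(int il) const { return cbs.get_rope_factors(il); }
};

// after: the cache holds a reference to the model and calls it directly
struct cache_with_model_ref {
    const model & mdl;

    explicit cache_with_model_ref(const model & mdl) : mdl(mdl) {}

    int query(int il) const { return mdl.rope_factors(il); }
};

int main() {
    model m;

    cache_with_callbacks before({ /*.get_rope_factors =*/ [&m](int il) { return m.rope_factors(il); } });
    cache_with_model_ref after(m);

    std::printf("before: %d, after: %d\n", before.query(3), after.query(3));
    return 0;
}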

File tree: 5 files changed (+80, -95 lines)


src/llama-context.cpp

Lines changed: 0 additions & 1 deletion
@@ -440,7 +440,6 @@ void llama_context::kv_self_update() {
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
     need_reserve = kv_self->update({
-        /*.arch     =*/ model.arch,
         /*.cparams  =*/ cparams,
         /*.sched    =*/ sched.get(),
         /*.backends =*/ backends,

src/llama-kv-cache.cpp

Lines changed: 38 additions & 17 deletions
@@ -22,13 +22,13 @@ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
 }
 
 llama_kv_cache_unified::llama_kv_cache_unified(
-        const llama_hparams & hparams,
-                  callbacks   cbs,
-                  ggml_type   type_k,
-                  ggml_type   type_v,
-                       bool   v_trans,
-                   uint32_t   kv_size,
-                   uint32_t   padding) : cbs(std::move(cbs)), hparams(hparams), v_trans(v_trans), padding(padding) {
+        const llama_model &  model,
+                ggml_type    type_k,
+                ggml_type    type_v,
+                     bool    v_trans,
+                     bool    offload,
+                 uint32_t    kv_size,
+                 uint32_t    padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
     const int32_t n_layer = hparams.n_layer;
 
     has_shift = false;
@@ -81,7 +81,18 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
-        ggml_backend_buffer_type_t buft = this->cbs.get_buft(i);
+        const char * dev_name = "CPU";
+
+        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
+
+        if (offload) {
+            auto * dev = model.dev_layer(i);
+            buft = ggml_backend_dev_buffer_type(dev);
+
+            dev_name = ggml_backend_dev_name(dev);
+        }
+
+        LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", i, dev_name);
 
         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
@@ -588,7 +599,6 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
                       float   freq_base,
                       float   freq_scale,
         ggml_backend_buffer * bbuf) const {
-    const auto & arch     = params.arch;
     const auto & cparams  = params.cparams;
     const auto & backends = params.backends;
     const auto & sched    = params.sched;
@@ -604,7 +614,7 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
 
     // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor = arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
+    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
 
     ggml_tensor * tmp;
 
@@ -697,7 +707,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
         const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
         const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
 
-        ggml_tensor * rope_factors = cbs.get_rope_factors(n_ctx_per_seq, il);
+        ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
         ggml_tensor * k =
             ggml_view_3d(ctx, k_l[il],
@@ -1377,11 +1387,11 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 //
 
 llama_kv_cache_recurrent::llama_kv_cache_recurrent(
-        const llama_hparams & hparams,
-                  callbacks   cbs,
-                  ggml_type   type_k,
-                  ggml_type   type_v,
-                   uint32_t   kv_size) : cbs(std::move(cbs)), hparams(hparams) {
+        const llama_model &  model,
+                ggml_type    type_k,
+                ggml_type    type_v,
+                     bool    offload,
+                 uint32_t    kv_size) : hparams(model.hparams) {
     const int32_t n_layer = hparams.n_layer;
 
     LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
@@ -1429,7 +1439,18 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
-        ggml_backend_buffer_type_t buft = this->cbs.get_buft(i);
+        const char * dev_name = "CPU";
+
+        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
+
+        if (offload) {
+            auto * dev = model.dev_layer(i);
+            buft = ggml_backend_dev_buffer_type(dev);
+
+            dev_name = ggml_backend_dev_name(dev);
+        }
+
+        LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", i, dev_name);
 
         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {

src/llama-kv-cache.h

Lines changed: 15 additions & 26 deletions
@@ -15,19 +15,10 @@ struct llama_cparams;
 struct llama_hparams;
 struct llama_ubatch;
 struct llama_sbatch;
+struct llama_model;
 
 struct llama_kv_cache : public llama_memory_i {
-    // can be used to query data from the model if needed
-    struct callbacks {
-        std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
-
-        // get the buffer type of layer il, can be used to offload KV cache layers to a different device
-        std::function<ggml_backend_buffer_type_t (int il)> get_buft;
-    };
-
     struct graph_params {
-        const llm_arch arch;
-
         const llama_cparams & cparams;
 
         const ggml_backend_sched_t & sched;
@@ -139,13 +130,13 @@ class llama_kv_cache_unified : public llama_kv_cache {
     static uint32_t get_padding(const llama_cparams & cparams);
 
     llama_kv_cache_unified(
-            const llama_hparams & hparams,
-                      callbacks   cbs,
-                      ggml_type   type_k,
-                      ggml_type   type_v,
-                           bool   v_trans,
-                       uint32_t   kv_size,
-                       uint32_t   padding);
+            const llama_model &  model,
+                    ggml_type    type_k,
+                    ggml_type    type_v,
+                         bool    v_trans,
+                         bool    offload,
+                     uint32_t    kv_size,
+                     uint32_t    padding);
 
     ~llama_kv_cache_unified() = default;
 
@@ -208,14 +199,13 @@ class llama_kv_cache_unified : public llama_kv_cache {
     // computed before each graph build
     uint32_t n = 0;
 
-    callbacks cbs;
-
     std::vector<kv_cell> cells;
 
     std::vector<ggml_tensor *> k_l; // per layer
     std::vector<ggml_tensor *> v_l;
 
 private:
+    const llama_model & model;
     const llama_hparams & hparams;
 
     bool has_shift = false;
@@ -312,11 +302,11 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     };
 
     llama_kv_cache_recurrent(
-            const llama_hparams & hparams,
-                      callbacks   cbs,
-                      ggml_type   type_k,
-                      ggml_type   type_v,
-                       uint32_t   kv_size);
+            const llama_model &  model,
+                    ggml_type    type_k,
+                    ggml_type    type_v,
+                         bool    offload,
+                     uint32_t    kv_size);
 
     ~llama_kv_cache_recurrent() = default;
 
@@ -370,8 +360,6 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
     void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
 
-    callbacks cbs;
-
     // Note: The value of head isn't only used to optimize searching
     // for a free KV slot. llama_decode_impl also uses it, so it
     // cannot be freely changed after a slot has been allocated.
@@ -388,6 +376,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     std::vector<ggml_tensor *> v_l;
 
 private:
+    //const llama_model & model;
     const llama_hparams & hparams;
 
     // commit/restore cache

src/llama-model.cpp

Lines changed: 25 additions & 51 deletions
@@ -4445,6 +4445,19 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const {
     return it->second;
 }
 
+ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const {
+    // choose long/short freq factors based on the context size
+    if (layers[il].rope_freqs != nullptr) {
+        return layers[il].rope_freqs;
+    }
+
+    if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
+        return layers[il].rope_long;
+    }
+
+    return layers[il].rope_short;
+}
+
 struct llm_build_llama : public llm_graph_context {
     llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -4485,7 +4498,7 @@ struct llm_build_llama : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -4710,7 +4723,7 @@ struct llm_build_deci : public llm_graph_context {
             } else if (n_head > 0) {
                 // self-attention
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -7192,7 +7205,7 @@ struct llm_build_phi3 : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for 128k context
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 ggml_tensor* attn_norm_output = build_norm(inpL,
                     model.layers[il].attn_norm,
@@ -7944,7 +7957,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
-            ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+            ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
             // norm
             cur = build_norm(inpL,
@@ -9012,7 +9025,7 @@ struct llm_build_cohere2 : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for 128k context
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -9950,7 +9963,7 @@ struct llm_build_deepseek : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -11314,7 +11327,7 @@ struct llm_build_exaone : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -12695,7 +12708,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -12818,28 +12831,6 @@ struct llm_build_bailingmoe : public llm_graph_context {
 llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
     llama_memory_i * res;
 
-    const bool offload = cparams.offload_kqv;
-
-    auto get_buft = [this, offload](int il) {
-        const char * dev_name = "CPU";
-
-        ggml_backend_buffer_type_t buft;
-        if (offload) {
-            auto * dev = dev_layer(il);
-            buft = ggml_backend_dev_buffer_type(dev);
-
-            dev_name = ggml_backend_dev_name(dev);
-        } else {
-            buft = ggml_backend_cpu_buffer_type();
-        }
-
-        LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", il, dev_name);
-
-        return buft;
-    };
-
-    LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
     switch (arch) {
         case LLM_ARCH_MAMBA:
         case LLM_ARCH_RWKV6:
@@ -12848,13 +12839,10 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_ARWKV7:
            {
                res = new llama_kv_cache_recurrent(
-                        hparams,
-                        {
-                            /*.get_rope_factors =*/ nullptr,
-                            /*.get_buft         =*/ get_buft,
-                        },
+                        *this,
                         GGML_TYPE_F32,
                         GGML_TYPE_F32,
+                        cparams.offload_kqv,
                         std::max((uint32_t) 1, cparams.n_seq_max));
            } break;
        default:
@@ -12866,25 +12854,11 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
                res = new llama_kv_cache_unified(
-                        hparams,
-                        {
-                            /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) {
-                                // choose long/short freq factors based on the context size
-                                if (layers[il].rope_freqs != nullptr) {
-                                    return layers[il].rope_freqs;
-                                }
-
-                                if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
-                                    return layers[il].rope_long;
-                                }
-
-                                return layers[il].rope_short;
-                            },
-                            /*.get_buft =*/ get_buft,
-                        },
+                        *this,
                         params.type_k,
                         params.type_v,
                         !cparams.flash_attn,
+                        cparams.offload_kqv,
                         cparams.n_ctx,
                         padding);
            }

src/llama-model.h

Lines changed: 2 additions & 0 deletions
@@ -395,6 +395,8 @@ struct llama_model {
 
     const struct ggml_tensor * get_tensor(const char * name) const;
 
+    ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const;
+
     // note: can mutate `cparams`
     // TODO: move this to new llm_arch_model_i interface
     llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
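For reference, the long/short frequency-factor selection that llama_model::get_rope_factors now centralizes can be sketched in isolation; the types below are simplified placeholders, not the actual llama_model or ggml_tensor definitions:

#include <cstdint>
#include <cstdio>

// simplified placeholder -- the real method returns a ggml_tensor * from llama_model
struct layer_rope {
    const float * rope_freqs; // optional per-layer override (e.g. llama3 freq factors)
    const float * rope_long;  // long-context factors
    const float * rope_short; // short-context factors
};

const float * get_rope_factors(const layer_rope & layer, uint32_t n_ctx_per_seq, uint32_t n_ctx_orig_yarn) {
    // a per-layer override always wins
    if (layer.rope_freqs != nullptr) {
        return layer.rope_freqs;
    }
    // otherwise choose long/short freq factors based on the context size
    if (n_ctx_per_seq > n_ctx_orig_yarn) {
        return layer.rope_long;
    }
    return layer.rope_short;
}

int main() {
    const float longf[]  = {1.0f};
    const float shortf[] = {2.0f};
    layer_rope layer{nullptr, longf, shortf};

    // 8k trained context, 32k requested per sequence -> long factors
    std::printf("%s\n", get_rope_factors(layer, 32768, 8192) == longf ? "long" : "short");
    return 0;
}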
