Skip to content

kv-cache : prepare K/V buffers for separation #14517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/llama-hparams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
return n_embd_head_v * n_head_kv;
}

bool llama_hparams::is_n_embd_k_gqa_variable() const {
const uint32_t val = n_embd_k_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
if (val != n_embd_k_gqa(il)) {
return true;
}
}

return false;
}

bool llama_hparams::is_n_embd_v_gqa_variable() const {
const uint32_t val = n_embd_v_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
if (val != n_embd_v_gqa(il)) {
return true;
}
}

return false;
}

uint32_t llama_hparams::n_embd_k_gqa_max() const {
uint32_t val = n_embd_k_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
val = std::max(val, n_embd_k_gqa(il));
}

return val;
}

uint32_t llama_hparams::n_embd_v_gqa_max() const {
uint32_t val = n_embd_v_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
val = std::max(val, n_embd_v_gqa(il));
}

return val;
}

uint32_t llama_hparams::n_embd_r() const {
if (wkv_head_size != 0) {
// for RWKV models
Expand Down
8 changes: 8 additions & 0 deletions src/llama-hparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,14 @@ struct llama_hparams {
// dimension of value embeddings across all k-v heads
uint32_t n_embd_v_gqa(uint32_t il = 0) const;

// true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
bool is_n_embd_k_gqa_variable() const;
bool is_n_embd_v_gqa_variable() const;

// return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
uint32_t n_embd_k_gqa_max() const;
uint32_t n_embd_v_gqa_max() const;

// dimension of the rolling state embeddings
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
uint32_t n_embd_r() const;
Expand Down
115 changes: 79 additions & 36 deletions src/llama-kv-cache-unified.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,21 @@ llama_kv_cache_unified::llama_kv_cache_unified(

cells.resize(kv_size);

// [TAG_V_CACHE_VARIABLE]
if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
__func__, hparams.n_embd_v_gqa_max());
}

for (uint32_t il = 0; il < n_layer_cache; il++) {
if (filter && !filter(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
continue;
}

const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
// [TAG_V_CACHE_VARIABLE]
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();

const char * dev_name = "CPU";

Expand All @@ -98,8 +105,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
ggml_tensor * k;
ggml_tensor * v;

k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, 1);
v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, 1);

ggml_format_name(k, "cache_k_l%d", il);
ggml_format_name(v, "cache_v_l%d", il);
Expand Down Expand Up @@ -780,33 +787,47 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint

auto * k = layers[ikv].k;

return ggml_view_3d(ctx, k,
hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
const uint64_t kv_size = get_size();
const uint64_t n_embd_k_gqa = k->ne[0];

assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));

return ggml_view_4d(ctx, k,
hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, 1,
ggml_row_size(k->type, hparams.n_embd_head_k),
ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
0);
ggml_row_size(k->type, n_embd_k_gqa),
ggml_row_size(k->type, n_embd_k_gqa*kv_size),
ggml_row_size(k->type, n_embd_k_gqa*kv_size)*0);
}

ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
const int32_t ikv = map_layer_ids.at(il);

auto * v = layers[ikv].v;

const uint64_t kv_size = get_size();
const uint64_t n_embd_v_gqa = v->ne[0];

// [TAG_V_CACHE_VARIABLE]
assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));

if (!v_trans) {
// note: v->nb[1] <= v->nb[2]
return ggml_view_3d(ctx, v,
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
0);
return ggml_view_4d(ctx, v,
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, 1,
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
ggml_row_size(v->type, n_embd_v_gqa*kv_size)*0);
}

// note: v->nb[1] > v->nb[2]
return ggml_view_3d(ctx, v,
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
ggml_row_size(v->type, v->ne[1]), // v->nb[2]
0);
return ggml_view_4d(ctx, v,
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, 1,
ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
ggml_row_size(v->type, kv_size), // v->nb[2]
ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
ggml_row_size(v->type, kv_size*n_embd_v_gqa)*0);
}

ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
Expand All @@ -820,6 +841,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);

if (k_idxs && supports_set_rows) {
if (k->ne[2] > 1) {
k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
}

return ggml_set_rows(ctx, k, k_cur, k_idxs);
}

Expand All @@ -838,31 +863,30 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_

auto * v = layers[ikv].v;

const int64_t n_embd_v_gqa = v->ne[0];
const int64_t n_tokens = v_cur->ne[2];
const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
const int64_t n_tokens = v_cur->ne[2];

v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);

if (v_idxs && supports_set_rows) {
if (!v_trans) {
if (v->ne[2] > 1) {
v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
}

return ggml_set_rows(ctx, v, v_cur, v_idxs);
}

// the row becomes a single element
ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
// [TAG_V_CACHE_VARIABLE]
if (n_embd_v_gqa < v->ne[0]) {
v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
}

// note: the V cache is transposed when not using flash attention
v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
// the row becomes a single element
ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);

// note: we can be more explicit here at the cost of extra cont
// however, above we take advantage that a row of single element is always continuous regardless of the row stride
//v_cur = ggml_transpose(ctx, v_cur);
//v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);

// we broadcast the KV indices n_embd_v_gqa times
// v [1, n_kv, n_embd_v_gqa]
// v_cur [1, n_tokens, n_embd_v_gqa]
// v_idxs [n_tokens, 1, 1]
return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
}

Expand Down Expand Up @@ -899,7 +923,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
const uint32_t n_tokens = ubatch.n_tokens;

ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
ggml_tensor * v_idxs;

if (!v_trans) {
v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
} else {
v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
}

ggml_set_input(v_idxs);

Expand All @@ -916,7 +946,7 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
int64_t * data = (int64_t *) dst->data;

for (int64_t i = 0; i < n_tokens; ++i) {
for (uint32_t i = 0; i < n_tokens; ++i) {
data[i] = sinfo.idxs.at(i);
}
}
Expand All @@ -931,8 +961,21 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
int64_t * data = (int64_t *) dst->data;

for (int64_t i = 0; i < n_tokens; ++i) {
data[i] = sinfo.idxs.at(i);
if (!v_trans) {
for (uint32_t i = 0; i < n_tokens; ++i) {
data[i] = sinfo.idxs.at(i);
}
} else {
// note: the V cache is transposed when not using flash attention
const int64_t kv_size = get_size();

const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();

for (uint32_t i = 0; i < n_tokens; ++i) {
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
data[i*n_embd_v_gqa + j] = j*kv_size + sinfo.idxs.at(i);
}
}
}
}

Expand Down
Loading