Commit fa159cf

Merge remote-tracking branch 'origin/compilade/refactor-kv-cache' into GraniteFour
* origin/compilade/refactor-kv-cache:
  memory : avoid referring to KV in recurrent cache logs
  model : make falcon-h1 use shared mamba2 layer builder

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2 parents 0583d95 + 452207f commit fa159cf

File tree: 3 files changed (+24, -155 lines)

gguf-py/gguf/tensor_mapping.py

Lines changed: 6 additions & 6 deletions
@@ -566,13 +566,13 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_IN: (
             "model.layers.{bid}.in_proj",            # mamba-hf
             "backbone.layers.{bid}.mixer.in_proj",   # mamba
-            "model.layers.{bid}.mamba.in_proj",      # falcon-h1, jamba, bamba
+            "model.layers.{bid}.mamba.in_proj",      # jamba falcon-h1 bamba
         ),

         MODEL_TENSOR.SSM_CONV1D: (
             "model.layers.{bid}.conv1d",             # mamba-hf
             "backbone.layers.{bid}.mixer.conv1d",    # mamba
-            "model.layers.{bid}.mamba.conv1d",       # falcon-h1, jamba, bamba
+            "model.layers.{bid}.mamba.conv1d",       # jamba falcon-h1 bamba
         ),

         MODEL_TENSOR.SSM_X: (
@@ -584,7 +584,7 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_DT: (
             "model.layers.{bid}.dt_proj",            # mamba-hf
             "backbone.layers.{bid}.mixer.dt_proj",   # mamba
-            "model.layers.{bid}.mamba.dt_proj",      # falcon-h1, jamba, bamba
+            "model.layers.{bid}.mamba.dt_proj",      # jamba falcon-h1 bamba
         ),

         MODEL_TENSOR.SSM_DT_NORM: (
@@ -594,7 +594,7 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_A: (
             "model.layers.{bid}.A_log",              # mamba-hf
             "backbone.layers.{bid}.mixer.A_log",     # mamba
-            "model.layers.{bid}.mamba.A_log",        # falcon-h1, jamba, bamba
+            "model.layers.{bid}.mamba.A_log",        # jamba falcon-h1 bamba
         ),

         MODEL_TENSOR.SSM_B_NORM: (
@@ -610,7 +610,7 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_D: (
             "model.layers.{bid}.D",                  # mamba-hf
             "backbone.layers.{bid}.mixer.D",         # mamba
-            "model.layers.{bid}.mamba.D",            # falcon-h1, jamba, bamba
+            "model.layers.{bid}.mamba.D",            # jamba falcon-h1 bamba
         ),

         MODEL_TENSOR.SSM_NORM: (
@@ -622,7 +622,7 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_OUT: (
             "model.layers.{bid}.out_proj",           # mamba-hf
             "backbone.layers.{bid}.mixer.out_proj",  # mamba
-            "model.layers.{bid}.mamba.out_proj",     # falcon-h1, jamba, bamba
+            "model.layers.{bid}.mamba.out_proj",     # jamba falcon-h1 bamba
         ),

         MODEL_TENSOR.TIME_MIX_W0: (

src/llama-memory-recurrent.cpp

Lines changed: 5 additions & 8 deletions
@@ -25,9 +25,6 @@ llama_memory_recurrent::llama_memory_recurrent(
                  uint32_t   n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;

-    LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
-            __func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
-
     head = 0;
     size = mem_size;
     used = 0;
@@ -84,7 +81,7 @@ llama_memory_recurrent::llama_memory_recurrent(

         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
-            throw std::runtime_error("failed to create ggml context for kv cache");
+            throw std::runtime_error("failed to create ggml context for rs cache");
         }

         ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
@@ -102,19 +99,19 @@ llama_memory_recurrent::llama_memory_recurrent(

         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
         if (!buf) {
-            throw std::runtime_error("failed to allocate buffer for kv cache");
+            throw std::runtime_error("failed to allocate buffer for rs cache");
         }
         ggml_backend_buffer_clear(buf, 0);
-        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
+        LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
         bufs.emplace_back(buf);
     }

     {
         const size_t memory_size_r = size_r_bytes();
         const size_t memory_size_s = size_s_bytes();

-        LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
+        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max,
                 ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
                 ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
     }
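With this change the recurrent-state constructor no longer logs mem_size, n_seq_max and the cell types up front; those values are folded into the final size line, and the messages say "RS" instead of "KV". As a rough illustration of what the new size line renders as, here is a minimal standalone sketch using the format string from the patch above with hypothetical values (none of these numbers come from the commit or a real run):

// Hypothetical illustration only: same format string as in the patch above,
// fed with made-up values to show the shape of the resulting log line.
#include <cstdio>

int main() {
    const unsigned mem_size  = 1;     // hypothetical: number of recurrent cells
    const int      n_layer   = 64;    // hypothetical: layer count
    const unsigned n_seq_max = 1;     // hypothetical: max parallel sequences
    const float    size_r    = 0.28f; // hypothetical: R (conv) state size in MiB
    const float    size_s    = 4.50f; // hypothetical: S (ssm) state size in MiB

    std::printf("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n",
                "llama_memory_recurrent", size_r + size_s, mem_size, n_layer, n_seq_max,
                "f32", size_r, "f32", size_s);
    return 0;
}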

src/llama-model.cpp

Lines changed: 13 additions & 141 deletions
@@ -5163,7 +5163,10 @@ void llama_model::print_info() const {
         }
     }

-    if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2 || arch == LLM_ARCH_JAMBA) {
+    if (arch == LLM_ARCH_MAMBA ||
+        arch == LLM_ARCH_MAMBA2 ||
+        arch == LLM_ARCH_JAMBA ||
+        arch == LLM_ARCH_FALCON_H1) {
         LLAMA_LOG_INFO("%s: ssm_d_conv  = %u\n", __func__, hparams.ssm_d_conv);
         LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
         LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
@@ -10436,8 +10439,11 @@ struct llm_graph_context_mamba : public virtual llm_graph_context {
         y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);

         // grouped RMS norm
-        y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
-        y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+        if (model.layers[il].ssm_norm) {
+            y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
+            y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+        }
+
         y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);

         // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
@@ -15180,10 +15186,9 @@ struct llm_build_ernie4_5 : public llm_graph_context {
     }
 };

-struct llm_build_falcon_h1 : public llm_graph_context {
-    const llama_model & model;
-
-    llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
+struct llm_build_falcon_h1 : public llm_graph_context_mamba {
+    llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
+        : llm_graph_context(params), llm_graph_context_mamba(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;

         ggml_tensor * cur;
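This hunk is the core of the refactor: llm_build_falcon_h1 now derives from the shared llm_graph_context_mamba helper (which inherits virtually from llm_graph_context) instead of carrying its own copy of the mamba2 builder. Because the base is virtual, the most-derived class has to initialize it directly, which is why the new constructor lists both llm_graph_context(params) and llm_graph_context_mamba(params). A minimal standalone sketch of that pattern, with illustrative names and members rather than the real llama.cpp definitions:

// Sketch of virtual-base construction (illustrative types, not llama.cpp code):
// the most-derived class constructs the virtual base, so it names both bases.
#include <iostream>

struct graph_params { int n_layer; };

struct graph_context {                                        // stands in for llm_graph_context
    explicit graph_context(const graph_params & p) : n_layer(p.n_layer) {}
    int n_layer;
};

struct graph_context_mamba : public virtual graph_context {   // stands in for the shared mamba helper
    explicit graph_context_mamba(const graph_params & p) : graph_context(p) {}
    int build_mamba2_layer(int il) const { return n_layer - il; }  // stand-in body
};

struct build_falcon_h1 : public graph_context_mamba {
    // the virtual base is constructed here, in the most-derived class,
    // mirroring ": llm_graph_context(params), llm_graph_context_mamba(params)" above
    explicit build_falcon_h1(const graph_params & p)
        : graph_context(p), graph_context_mamba(p) {}
};

int main() {
    build_falcon_h1 b(graph_params{4});
    std::cout << b.build_mamba2_layer(1) << "\n";  // uses the shared helper: prints 3
    return 0;
}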
@@ -15250,7 +15255,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
             // Mamba2 layer
             cb(cur, "ssm_in", il);

-            ggml_tensor * ssm_out = build_mamba2_layer(inp, gf, cur, ubatch, il);
+            ggml_tensor * ssm_out = build_mamba2_layer(inp->get_recr(), gf, cur, model, ubatch, il);
             cb(ssm_out, "ssm_out", il);

             // // Aggregation
@@ -15306,139 +15311,6 @@ struct llm_build_falcon_h1 : public llm_graph_context {

         ggml_build_forward_expand(gf, cur);
     }
-
-    ggml_tensor * build_mamba2_layer(
-        llm_graph_input_mem_hybrid * inp,
-                       ggml_cgraph * gf,
-                       ggml_tensor * cur,
-                const llama_ubatch & ubatch,
-                               int   il) const {
-        const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
-
-        const auto kv_head = kv_state->get_head();
-
-        const int64_t d_conv   = hparams.ssm_d_conv;
-        const int64_t d_inner  = hparams.ssm_d_inner;
-        const int64_t d_state  = hparams.ssm_d_state;
-        const int64_t n_head   = hparams.ssm_dt_rank;
-        const int64_t head_dim = d_inner / n_head;
-        const int64_t n_group  = hparams.ssm_n_group;
-        const int64_t n_seqs   = ubatch.n_seqs;
-
-        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-
-        GGML_ASSERT(n_seqs != 0);
-        GGML_ASSERT(ubatch.equal_seqs);
-        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
-
-        ggml_tensor * conv_states_all = kv_state->get_r_l(il);
-        ggml_tensor * ssm_states_all  = kv_state->get_s_l(il);
-
-        ggml_tensor * conv = build_rs(inp->get_recr(), gf, conv_states_all, hparams.n_embd_r(), n_seqs);
-        conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
-
-        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
-        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
-
-        // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
-
-        // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
-        ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
-        cb(zxBCdt, "zxBCdt", il);
-
-        // split the above in three
-        ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
-        ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
-        ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));
-
-        // conv
-        {
-            // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
-            ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
-
-            // copy last (d_conv - 1) columns back into the state cache
-            ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
-
-            ggml_build_forward_expand(gf,
-                ggml_cpy(ctx0, last_conv,
-                    ggml_view_1d(ctx0, conv_states_all,
-                        (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
-                        kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
-
-            // 1D convolution
-            // The equivalent is to make a self-overlapping view of conv_x
-            // over d_conv columns at each stride in the 3rd dimension,
-            // then element-wise multiply that with the conv1d weight,
-            // then sum the elements of each row,
-            // (the last two steps are a dot product over rows (also doable with mul_mat))
-            // then permute away the ne[0] dimension,
-            // and then you're left with the resulting x tensor.
-            // For simultaneous sequences, all sequences need to have the same length.
-            xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
-
-            // bias
-            xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
-
-            xBC = ggml_silu(ctx0, xBC);
-        }
-
-        // ssm
-        {
-            // These correspond to V K Q in SSM/attention duality
-            ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
-
-            ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC));
-
-            ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
-
-            // {n_head, n_seq_tokens, n_seqs}
-            dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
-
-            ggml_tensor * A = model.layers[il].ssm_a;
-
-            // use the states and the indices provided by build_rs
-            // (this is necessary in order to properly use the states before they are overwritten,
-            // while avoiding to make unnecessary copies of the states)
-            auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
-                ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size());
-
-                // TODO: use semistructured matrices to implement state-space duality
-                // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-                return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
-            };
-
-            ggml_tensor * y_ssm = build_rs(inp->get_recr(), gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
-
-            // store last states
-            ggml_build_forward_expand(gf,
-                ggml_cpy(ctx0,
-                    ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]),
-                    ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
-
-            ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
-
-            // TODO: skip computing output earlier for unused tokens
-
-            y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
-            y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
-
-            // grouped RMS norm
-            if (model.layers[il].ssm_norm) {
-                y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
-                y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
-            }
-
-            y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
-
-            // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
-            cur = build_lora_mm(model.layers[il].ssm_out, y);
-        }
-
-        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
-        cb(cur, "mamba_out", il);
-        return cur;
-    }
 };

 struct llm_build_arcee : public llm_graph_context {
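The removed method above duplicated the shared mamba2 builder; its comments describe the 1D convolution as a dot product over a self-overlapping window of d_conv columns, with the last d_conv - 1 columns written back as recurrent state. A plain-array sketch of that idea for a single channel (conceptual only, with made-up weights and inputs; not how ggml_ssm_conv is actually implemented):

// Conceptual sketch of the depthwise causal conv described in the removed comments:
// each channel keeps the previous (d_conv - 1) inputs as state, and every output is
// a dot product of the conv weights with a d_conv-wide window ending at "now".
#include <cstdio>
#include <vector>

std::vector<float> causal_conv1d(std::vector<float> & state,        // size d_conv - 1 (one channel)
                                 const std::vector<float> & x,      // new inputs for this channel
                                 const std::vector<float> & w) {    // conv weights, size d_conv
    const size_t d_conv = w.size();
    // prepend the cached columns, like ggml_concat(conv, x) in the graph code
    std::vector<float> buf = state;
    buf.insert(buf.end(), x.begin(), x.end());

    std::vector<float> y(x.size());
    for (size_t t = 0; t < x.size(); ++t) {
        float acc = 0.0f;
        for (size_t k = 0; k < d_conv; ++k) {
            acc += w[k] * buf[t + k];            // self-overlapping window of d_conv columns
        }
        y[t] = acc;
    }
    // keep the last (d_conv - 1) columns as the next call's state
    state.assign(buf.end() - (d_conv - 1), buf.end());
    return y;
}

int main() {
    std::vector<float> state = {0.0f, 0.0f, 0.0f};          // d_conv = 4 -> 3 cached columns
    const std::vector<float> w = {0.1f, 0.2f, 0.3f, 0.4f};  // hypothetical weights
    const std::vector<float> x = {1.0f, 2.0f, 3.0f};        // hypothetical new tokens

    for (float v : causal_conv1d(state, x, w)) {
        std::printf("%.2f ", v);                            // prints: 0.40 1.10 2.00
    }
    std::printf("\n");
    return 0;
}

The cached columns play the same role as conv_states_all in the graph code: they let the next ubatch continue the convolution across the sequence boundary without recomputing past tokens.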
