
Commit 7f3955a

model : make falcon-h1 use shared mamba2 layer builder

1 parent a60a24b

1 file changed: +13 -142 lines

src/llama-model.cpp

@@ -5021,7 +5021,10 @@ void llama_model::print_info() const {
         }
     }
 
-    if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2 || arch == LLM_ARCH_JAMBA) {
+    if (arch == LLM_ARCH_MAMBA ||
+        arch == LLM_ARCH_MAMBA2 ||
+        arch == LLM_ARCH_JAMBA ||
+        arch == LLM_ARCH_FALCON_H1) {
         LLAMA_LOG_INFO("%s: ssm_d_conv  = %u\n", __func__, hparams.ssm_d_conv);
         LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
         LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
@@ -10292,8 +10295,11 @@ struct llm_graph_context_mamba : public llm_graph_context {
             y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
 
             // grouped RMS norm
-            y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
-            y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+            if (model.layers[il].ssm_norm) {
+                y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
+                y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+            }
+
             y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
 
             // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
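
Note: the guard added here is what makes the shared builder reusable by Falcon-H1; its private copy, removed further down, carries the identical if (model.layers[il].ssm_norm) check. For reference, a minimal scalar sketch of what the guarded reshape + norm computes on one token; the helper is hypothetical, with eps standing in for the model's RMS-norm epsilon:

#include <cmath>
#include <vector>

// Split y (d_inner elements) into n_group contiguous groups, normalize each
// group by its own root-mean-square, then apply the learned scale w.
static void grouped_rms_norm(std::vector<float> & y, const std::vector<float> & w,
                             int n_group, float eps = 1e-6f) {
    const int group_size = (int) y.size() / n_group;
    for (int g = 0; g < n_group; ++g) {
        float sum_sq = 0.0f;
        for (int i = 0; i < group_size; ++i) {
            const float v = y[g*group_size + i];
            sum_sq += v*v;
        }
        const float inv_rms = 1.0f/std::sqrt(sum_sq/group_size + eps);
        for (int i = 0; i < group_size; ++i) {
            y[g*group_size + i] *= inv_rms * w[g*group_size + i];
        }
    }
}
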
@@ -14919,10 +14925,8 @@ struct llm_build_ernie4_5 : public llm_graph_context {
     }
 };
 
-struct llm_build_falcon_h1 : public llm_graph_context {
-    const llama_model & model;
-
-    llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
+struct llm_build_falcon_h1 : public llm_graph_context_mamba {
+    llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         ggml_tensor * cur;
@@ -14978,7 +14982,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
             cb(Kcur, "Kcur-post-rope", il);
             cb(Vcur, "Vcur-post-rope", il);
 
-            ggml_tensor * attn_out = build_attn(inp, gf,
+            ggml_tensor * attn_out = build_attn(inp->get_attn(), gf,
                     model.layers[il].wo, NULL,
                     Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
             cb(attn_out, "attn_out", il);
@@ -14989,7 +14993,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
             // Mamba2 layer
             cb(cur, "ssm_in", il);
 
-            ggml_tensor * ssm_out = build_mamba2_layer(inp, gf, cur, ubatch, il);
+            ggml_tensor * ssm_out = build_mamba2_layer(inp->get_recr(), gf, cur, model, ubatch, il);
             cb(ssm_out, "ssm_out", il);
 
             // // Aggregation
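
Note: both call-site edits follow from the refactor. The hybrid input is no longer passed whole: build_attn() receives its attention half and the shared build_mamba2_layer() its recurrent half, plus the model itself now that llm_build_falcon_h1 no longer keeps a model member. The declarations behind these calls are not part of this diff, so the following is an inferred sketch rather than the actual llama.cpp API:

// Inferred sketch only -- the type names below are assumptions based on the
// call sites above, not the real llama.cpp declarations.
struct ggml_cgraph;
struct ggml_tensor;
struct llama_model;
struct llama_ubatch;
struct llm_graph_input_attn;   // attention half of the hybrid input (assumed name)
struct llm_graph_input_rs;     // recurrent half of the hybrid input (assumed name)

struct llm_graph_input_mem_hybrid {
    llm_graph_input_attn * get_attn() const;  // consumed by build_attn(...)
    llm_graph_input_rs   * get_recr() const;  // consumed by build_mamba2_layer(...)
};

struct llm_graph_context_mamba {
    // Shared Mamba-2 layer builder; the model is now an explicit parameter,
    // replacing the member that llm_build_falcon_h1 dropped.
    ggml_tensor * build_mamba2_layer(
          llm_graph_input_rs * inp,
                 ggml_cgraph * gf,
                 ggml_tensor * cur,
           const llama_model & model,
          const llama_ubatch & ubatch,
                          int il) const;
};
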
@@ -15045,139 +15049,6 @@ struct llm_build_falcon_h1 : public llm_graph_context {
 
         ggml_build_forward_expand(gf, cur);
     }
-
-    ggml_tensor * build_mamba2_layer(
-        llm_graph_input_mem_hybrid * inp,
-                     ggml_cgraph * gf,
-                     ggml_tensor * cur,
-              const llama_ubatch & ubatch,
-                              int il) const {
-        const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
-
-        const auto kv_head = kv_state->get_head();
-
-        const int64_t d_conv   = hparams.ssm_d_conv;
-        const int64_t d_inner  = hparams.ssm_d_inner;
-        const int64_t d_state  = hparams.ssm_d_state;
-        const int64_t n_head   = hparams.ssm_dt_rank;
-        const int64_t head_dim = d_inner / n_head;
-        const int64_t n_group  = hparams.ssm_n_group;
-        const int64_t n_seqs   = ubatch.n_seqs;
-
-        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-
-        GGML_ASSERT(n_seqs != 0);
-        GGML_ASSERT(ubatch.equal_seqs);
-        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
-
-        ggml_tensor * conv_states_all = kv_state->get_r_l(il);
-        ggml_tensor * ssm_states_all  = kv_state->get_s_l(il);
-
-        ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
-        conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
-
-        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
-        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
-
-        // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
-
-        // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
-        ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
-        cb(zxBCdt, "zxBCdt", il);
-
-        // split the above in three
-        ggml_tensor * z   = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
-        ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
-        ggml_tensor * dt  = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));
-
-        // conv
-        {
-            // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
-            ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
-
-            // copy last (d_conv - 1) columns back into the state cache
-            ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
-
-            ggml_build_forward_expand(gf,
-                ggml_cpy(ctx0, last_conv,
-                    ggml_view_1d(ctx0, conv_states_all,
-                        (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
-                        kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
-
-            // 1D convolution
-            // The equivalent is to make a self-overlapping view of conv_x
-            // over d_conv columns at each stride in the 3rd dimension,
-            // then element-wise multiply that with the conv1d weight,
-            // then sum the elements of each row,
-            // (the last two steps are a dot product over rows (also doable with mul_mat))
-            // then permute away the ne[0] dimension,
-            // and then you're left with the resulting x tensor.
-            // For simultaneous sequences, all sequences need to have the same length.
-            xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
-
-            // bias
-            xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
-
-            xBC = ggml_silu(ctx0, xBC);
-        }
-
-        // ssm
-        {
-            // These correspond to V K Q in SSM/attention duality
-            ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
-
-            ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC));
-
-            ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
-
-            // {n_head, n_seq_tokens, n_seqs}
-            dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
-
-            ggml_tensor * A = model.layers[il].ssm_a;
-
-            // use the states and the indices provided by build_rs
-            // (this is necessary in order to properly use the states before they are overwritten,
-            //  while avoiding to make unnecessary copies of the states)
-            auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
-                ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size());
-
-                // TODO: use semistructured matrices to implement state-space duality
-                // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-                return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
-            };
-
-            ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
-
-            // store last states
-            ggml_build_forward_expand(gf,
-                ggml_cpy(ctx0,
-                    ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]),
-                    ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
-
-            ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
-
-            // TODO: skip computing output earlier for unused tokens
-
-            y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
-            y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
-
-            // grouped RMS norm
-            if (model.layers[il].ssm_norm) {
-                y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
-                y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
-            }
-
-            y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
-
-            // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
-            cur = build_lora_mm(model.layers[il].ssm_out, y);
-        }
-
-        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
-        cb(cur, "mamba_out", il);
-        return cur;
-    }
 };
 
 struct llm_build_arcee : public llm_graph_context {
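
Note: the method deleted above is the logic that now lives once in llm_graph_context_mamba, with the ssm_norm guard from the second hunk carried over. Two parts of it read densely in diff form and are worth restating.

First, the "1D convolution" comment describes a causal depthwise convolution: per channel, the last d_conv - 1 inputs are kept as state, a d_conv-tap window slides over [state | new tokens], and the final d_conv - 1 columns are written back as the next state. A scalar sketch of one channel, illustrative rather than the ggml_ssm_conv implementation:

#include <vector>

// Mirrors the ggml_concat / ggml_ssm_conv / state-copy sequence, one channel.
static std::vector<float> causal_conv1d(std::vector<float> & state,   // d_conv - 1 entries
                                        const std::vector<float> & x, // n_seq_tokens entries
                                        const std::vector<float> & w) // d_conv taps
{
    const size_t d_conv = w.size();
    std::vector<float> buf(state);                      // conv_x = [state | x]
    buf.insert(buf.end(), x.begin(), x.end());

    std::vector<float> y(x.size(), 0.0f);
    for (size_t t = 0; t < x.size(); ++t) {
        for (size_t k = 0; k < d_conv; ++k) {           // dot product over the window
            y[t] += w[k]*buf[t + k];
        }
    }
    state.assign(buf.end() - (d_conv - 1), buf.end());  // last columns become new state
    return y;
}

Second, the ggml_ssm_scan call performs the selective state-space update. Per head, with a scalar decay a, the usual Mamba-2 formulation is, roughly (stated as a reading aid, not a spec of the kernel):

    h_t = exp(dt_t * a) * h_{t-1} + dt_t * B_t * x_t
    y_t = C_t^T * h_t

with the D*x skip connection and SwiGLU gating applied afterwards, as the removed code does via ggml_mul(x, ssm_d) and ggml_swiglu_split.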
