Commit 0796726

fix: Use llm_graph_context_mamba in llm_build_granite_hybrid

Branch: GraniteFourWithJamba
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent: 4b5f673
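
The diff below boils down to a hoist-to-base refactor: `llm_build_granite_hybrid` previously derived from `llm_graph_context` and carried its own ~130-line copy of `build_mamba2_layer`; after this commit it derives from `llm_graph_context_mamba` and inherits that method instead. As a minimal standalone sketch of the pattern (not the real llama.cpp code; every type here is a simplified stand-in for the ones in src/llama-model.cpp):

// Minimal sketch of the hoist-to-base refactor this commit applies.
// All types are simplified stand-ins for the real ones in src/llama-model.cpp.
struct ggml_tensor;         // opaque stand-in for the real ggml tensor type
struct llm_graph_params {}; // stand-in; the real struct carries graph-build state

struct llm_graph_context {
    explicit llm_graph_context(const llm_graph_params &) {}
    virtual ~llm_graph_context() = default;
};

// Intermediate base that owns the Mamba layer builder once, so graph builders
// for Mamba-bearing architectures stop carrying duplicate copies of it.
struct llm_graph_context_mamba : llm_graph_context {
    explicit llm_graph_context_mamba(const llm_graph_params & params)
        : llm_graph_context(params) {}

    // In llama.cpp this builds the conv + SSM-scan subgraph; stubbed here.
    ggml_tensor * build_mamba2_layer(/* inp, gf, cur, model, ubatch, il */) const {
        return nullptr;
    }
};

// After this commit the hybrid builder derives from llm_graph_context_mamba
// and inherits build_mamba2_layer instead of redefining its own copy.
struct llm_build_granite_hybrid : llm_graph_context_mamba {
    explicit llm_build_granite_hybrid(const llm_graph_params & params)
        : llm_graph_context_mamba(params) {}
};

int main() {
    llm_build_granite_hybrid builder{llm_graph_params{}};
    (void) builder.build_mamba2_layer(); // resolves to the shared base implementation
    return 0;
}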

1 file changed: src/llama-model.cpp (2 additions, 131 deletions)
@@ -14011,14 +14011,14 @@ struct llm_build_granite : public llm_graph_context {
     }
 };
 
-struct llm_build_granite_hybrid : public llm_graph_context {
+struct llm_build_granite_hybrid : public llm_graph_context_mamba {
 
     llm_build_granite_hybrid(
         const llama_model & model,
         const llm_graph_params & params,
         ggml_cgraph * gf,
         const bool use_rope = true) :
-        llm_graph_context(params) {
+        llm_graph_context_mamba(params) {
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -14095,135 +14095,6 @@ struct llm_build_granite_hybrid : public llm_graph_context {
         ggml_build_forward_expand(gf, cur);
     }
 
-    ggml_tensor * build_mamba2_layer(
-        llm_graph_input_rs * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * cur,
-        const llama_model & model,
-        const llama_ubatch & ubatch,
-        int il) const {
-        const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
-
-        const auto kv_head = mctx_cur->get_head();
-
-        const int64_t d_conv = hparams.ssm_d_conv;
-        const int64_t d_inner = hparams.ssm_d_inner;
-        const int64_t d_state = hparams.ssm_d_state;
-        const int64_t n_head = hparams.ssm_dt_rank;
-        const int64_t head_dim = d_inner / n_head;
-        const int64_t n_group = hparams.ssm_n_group;
-        const int64_t n_seqs = ubatch.n_seqs;
-
-        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-
-        GGML_ASSERT(n_seqs != 0);
-        GGML_ASSERT(ubatch.equal_seqs);
-        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
-
-        ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
-        ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
-
-        ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
-        conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
-
-        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
-        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
-
-        // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
-
-        // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
-        ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
-
-        // split the above in three
-        ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
-        ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
-        ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));
-
-        // conv
-        {
-            // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
-            ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
-
-            // copy last (d_conv - 1) columns back into the state cache
-            ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
-
-            ggml_build_forward_expand(gf,
-                ggml_cpy(ctx0, last_conv,
-                    ggml_view_1d(ctx0, conv_states_all,
-                        (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
-                        kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
-
-            // 1D convolution
-            // The equivalent is to make a self-overlapping view of conv_x
-            // over d_conv columns at each stride in the 3rd dimension,
-            // then element-wise multiply that with the conv1d weight,
-            // then sum the elements of each row,
-            // (the last two steps are a dot product over rows (also doable with mul_mat))
-            // then permute away the ne[0] dimension,
-            // and then you're left with the resulting x tensor.
-            // For simultaneous sequences, all sequences need to have the same length.
-            xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
-
-            // bias
-            xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
-
-            xBC = ggml_silu(ctx0, xBC);
-        }
-
-        // ssm
-        {
-            // These correspond to V K Q in SSM/attention duality
-            ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
-            ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC));
-            ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
-
-            // {n_head, n_seq_tokens, n_seqs}
-            dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
-
-            ggml_tensor * A = model.layers[il].ssm_a;
-
-            // use the states and the indices provided by build_rs
-            // (this is necessary in order to properly use the states before they are overwritten,
-            // while avoiding to make unnecessary copies of the states)
-            auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
-                ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
-
-                // TODO: use semistructured matrices to implement state-space duality
-                // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-                return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
-            };
-
-            ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
-
-            // store last states
-            ggml_build_forward_expand(gf,
-                ggml_cpy(ctx0,
-                    ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]),
-                    ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
-
-            ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
-
-            // TODO: skip computing output earlier for unused tokens
-
-            y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
-            y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
-
-            // grouped RMS norm
-            y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
-            y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
-            y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
-
-            // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
-            cur = build_lora_mm(model.layers[il].ssm_out, y);
-        }
-
-        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
-        // cb(cur, "mamba_out", il);
-
-        return cur;
-    }
-
     ggml_tensor * build_granite_attention_layer(
         ggml_cgraph * gf,
         ggml_tensor * cur,
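
If this reading of the diff is right, the ~130-line `build_mamba2_layer` body now lives in one place, `llm_graph_context_mamba`, and the Granite hybrid builder picks it up by inheritance. That keeps the hybrid path from drifting out of sync with the other Mamba-based graph builders on this branch (presumably including the Jamba work, given the branch name GraniteFourWithJamba), while `build_granite_attention_layer` stays defined locally since only the hybrid builder needs it.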
