
Commit f7fa1b1

refactor: Refactor mamba2/granite/jamba/granite_hybrid relationships as mixins
The key is for the mixin classes (llm_graph_context_mamba, llm_graph_context_granite) to use virtual inheritance from llm_graph_context. This allows the common members to exist only once in the class hierarchy. The downside is that llm_graph_context will be re-initialized once for each parent (i.e. 2x for a single mixin, 3x for two mixins, etc.).

Branch: GraniteFourWithJamba

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 0796726 commit f7fa1b1
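For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the virtual-inheritance mixin approach the commit message describes. All names below (graph_params, graph_context, context_mamba, context_granite, build_hybrid) are hypothetical stand-ins for illustration, not the actual llama.cpp types:

```cpp
#include <cstdio>

// Hypothetical stand-ins for llm_graph_params / llm_graph_context.
struct graph_params { int n_layer; };

struct graph_context {
    explicit graph_context(const graph_params & params) : n_layer(params.n_layer) {}
    int n_layer;  // shared state used by every mixin
};

// Mixins inherit *virtually*, so a class deriving from several of them
// still contains exactly one graph_context subobject.
struct context_mamba : public virtual graph_context {
    explicit context_mamba(const graph_params & params) : graph_context(params) {}
    void build_mamba_layer() const { std::printf("mamba layer (n_layer = %d)\n", n_layer); }
};

struct context_granite : public virtual graph_context {
    explicit context_granite(const graph_params & params) : graph_context(params) {}
    void build_attention_layer() const { std::printf("attention layer (n_layer = %d)\n", n_layer); }
};

// A hybrid builder mixes both. With a virtual base, the most-derived class
// must name graph_context in its own initializer list; the initializers
// written in the mixins are only used when a mixin is itself the
// most-derived type being constructed.
struct build_hybrid : public context_mamba, public context_granite {
    explicit build_hybrid(const graph_params & params)
        : graph_context(params), context_mamba(params), context_granite(params) {
        build_mamba_layer();      // unambiguous: both mixins see the same base
        build_attention_layer();
    }
};

int main() {
    build_hybrid hybrid(graph_params{32});
    return 0;
}
```

Without the `virtual` keyword, `build_hybrid` would carry two independent `graph_context` subobjects and any access to `n_layer` would be ambiguous; with it, the common members exist once, at the price of every constructor in the chain repeating the `graph_context(params)` initializer, which is the downside the commit message calls out.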

File tree

1 file changed: +88 -222 lines


src/llama-model.cpp

Lines changed: 88 additions & 222 deletions
@@ -10024,7 +10024,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
     }
 };
 
-struct llm_graph_context_mamba : public llm_graph_context {
+struct llm_graph_context_mamba : public virtual llm_graph_context {
     llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {}
 
     ggml_tensor * build_mamba_layer(
@@ -10298,7 +10298,8 @@ struct llm_graph_context_mamba : public llm_graph_context {
 };
 
 struct llm_build_mamba : public llm_graph_context_mamba {
-    llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {
+    llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
+        : llm_graph_context(params), llm_graph_context_mamba(params) {
         ggml_tensor * cur;
         ggml_tensor * inpL;
 
@@ -10355,7 +10356,8 @@ struct llm_build_mamba : public llm_graph_context_mamba {
 };
 
 struct llm_build_jamba : public llm_graph_context_mamba {
-    llm_build_jamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {
+    llm_build_jamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
+        : llm_graph_context(params), llm_graph_context_mamba(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         ggml_tensor * cur;
@@ -13794,81 +13796,10 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
     }
 };
 
-struct llm_build_granite : public llm_graph_context {
-    llm_build_granite(
-        const llama_model & model,
-        const llm_graph_params & params,
-        ggml_cgraph * gf,
-        const bool use_rope = true)
-        : llm_graph_context(params) {
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        ggml_tensor * cur;
-        ggml_tensor * inpL;
-
-        inpL = build_inp_embd(model.tok_embd);
-
-        // inp_pos - built only if rope enabled
-        ggml_tensor * inp_pos = nullptr;
-        if (use_rope) {
-            inp_pos = build_inp_pos();
-        }
-
-        auto * inp_attn = build_attn_inp_kv_unified();
-
-        ggml_tensor * inp_out_ids = build_inp_out_ids();
-
-        for (int il = 0; il < n_layer; ++il) {
-            ggml_tensor * inpSA = inpL;
+struct llm_graph_context_granite : public virtual llm_graph_context {
+    llm_graph_context_granite(const llm_graph_params & params) : llm_graph_context(params) {}
 
-            // norm
-            cur = build_norm(inpL,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            cur = build_granite_attention_layer(
-                gf, cur, inp_pos, inp_attn,
-                model, n_embd_head, use_rope, il);
-
-            if (il == n_layer - 1 && inp_out_ids) {
-                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            // ffn
-            cur = build_layer_ffn(cur, inpSA, model, il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = build_norm(cur,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, -1);
-
-        cb(cur, "result_norm", -1);
-        res->t_embd = cur;
-
-        // lm_head
-        cur = build_lora_mm(model.output, cur);
-
-        // For Granite architectures - scale logits
-        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
-        cb(cur, "result_output", -1);
-        res->t_logits = cur;
-
-        ggml_build_forward_expand(gf, cur);
-    }
-
-    ggml_tensor * build_granite_attention_layer(
+    ggml_tensor * build_attention_layer(
         ggml_cgraph * gf,
         ggml_tensor * cur,
         ggml_tensor * inp_pos,
@@ -14011,14 +13942,91 @@ struct llm_build_granite : public llm_graph_context {
     }
 };
 
-struct llm_build_granite_hybrid : public llm_graph_context_mamba {
+struct llm_build_granite : public llm_graph_context_granite {
+    llm_build_granite(
+        const llama_model & model,
+        const llm_graph_params & params,
+        ggml_cgraph * gf,
+        const bool use_rope = true)
+        : llm_graph_context(params), llm_graph_context_granite(params) {
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - built only if rope enabled
+        ggml_tensor * inp_pos = nullptr;
+        if (use_rope) {
+            inp_pos = build_inp_pos();
+        }
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            cur = build_attention_layer(
+                gf, cur, inp_pos, inp_attn,
+                model, n_embd_head, use_rope, il);
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // ffn
+            cur = build_layer_ffn(cur, inpSA, model, il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        // For Granite architectures - scale logits
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_granite_hybrid : public llm_graph_context_mamba, public llm_graph_context_granite {
 
     llm_build_granite_hybrid(
         const llama_model & model,
         const llm_graph_params & params,
         ggml_cgraph * gf,
         const bool use_rope = true) :
-        llm_graph_context_mamba(params) {
+        llm_graph_context(params),
+        llm_graph_context_mamba(params),
+        llm_graph_context_granite(params) {
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -14056,7 +14064,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
                 cur = build_mamba2_layer(inp_rs, gf, cur, model, ubatch, il);
             } else {
                 // attention layer //
-                cur = build_granite_attention_layer(
+                cur = build_attention_layer(
                     gf, cur, inp_pos, inp_attn, model,
                     n_embd_head, use_rope, il);
             }
@@ -14094,148 +14102,6 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
 
         ggml_build_forward_expand(gf, cur);
     }
-
-    ggml_tensor * build_granite_attention_layer(
-        ggml_cgraph * gf,
-        ggml_tensor * cur,
-        ggml_tensor * inp_pos,
-        llm_graph_input_attn_kv_unified * inp,
-        const llama_model & model,
-        const int64_t n_embd_head,
-        const bool use_rope,
-        const int il) {
-
-        // compute Q and K and (optionally) RoPE them
-        ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
-        cb(Qcur, "Qcur", il);
-        if (model.layers[il].bq) {
-            Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-            cb(Qcur, "Qcur", il);
-        }
-
-        ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-        cb(Kcur, "Kcur", il);
-        if (model.layers[il].bk) {
-            Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-            cb(Kcur, "Kcur", il);
-        }
-
-        ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
-        cb(Vcur, "Vcur", il);
-        if (model.layers[il].bv) {
-            Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-            cb(Vcur, "Vcur", il);
-        }
-
-        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens);
-        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
-        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
-
-        if (use_rope) {
-            ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
-            Qcur = ggml_rope_ext(
-                ctx0, Qcur, inp_pos, rope_factors,
-                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow
-            );
-
-            Kcur = ggml_rope_ext(
-                ctx0, Kcur, inp_pos, rope_factors,
-                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow
-            );
-        }
-
-        cb(Qcur, "Qcur", il);
-        cb(Kcur, "Kcur", il);
-        cb(Vcur, "Vcur", il);
-
-        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
-        cur = build_attn(inp, gf,
-                model.layers[il].wo, model.layers[il].bo,
-                Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
-        cb(cur, "attn_out", il);
-        return cur;
-    }
-
-    ggml_tensor * build_layer_ffn(
-        ggml_tensor * cur,
-        ggml_tensor * inpSA,
-        const llama_model & model,
-        const int il) {
-
-        // For Granite architectures - scale residual
-        if (hparams.f_residual_scale) {
-            cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-        }
-        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-        cb(ffn_inp, "ffn_inp", il);
-
-        // feed-forward network (non-MoE)
-        if (model.layers[il].ffn_gate_inp == nullptr) {
-
-            cur = build_norm(ffn_inp,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = build_ffn(cur,
-                    model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
-                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
-                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, il);
-            cb(cur, "ffn_out", il);
-
-        } else {
-            // MoE branch
-            cur = build_norm(ffn_inp,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, il);
-            cb(cur, "ffn_norm", il);
-
-            ggml_tensor * moe_out = build_moe_ffn(cur,
-                    model.layers[il].ffn_gate_inp,
-                    model.layers[il].ffn_up_exps,
-                    model.layers[il].ffn_gate_exps,
-                    model.layers[il].ffn_down_exps,
-                    nullptr,
-                    n_expert, n_expert_used,
-                    LLM_FFN_SILU, true,
-                    false, 0.0,
-                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                    il);
-            cb(moe_out, "ffn_moe_out", il);
-
-            // For Granite MoE Shared
-            if (hparams.n_ff_shexp > 0) {
-                ggml_tensor * ffn_shexp = build_ffn(cur,
-                    model.layers[il].ffn_up_shexp, NULL, NULL,
-                    model.layers[il].ffn_gate_shexp, NULL, NULL,
-                    model.layers[il].ffn_down_shexp, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, il);
-                cb(ffn_shexp, "ffn_shexp", il);
-
-                cur = ggml_add(ctx0, moe_out, ffn_shexp);
-                cb(cur, "ffn_out", il);
-            } else {
-                cur = moe_out;
-            }
-        }
-
-        // For Granite architectures - scale residual
-        if (hparams.f_residual_scale) {
-            cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-        }
-        cur = ggml_add(ctx0, cur, ffn_inp);
-        cb(cur, "ffn_out", il);
-
-        cur = build_cvec(cur, il);
-        cb(cur, "l_out", il);
-
-        return cur;
-    }
 };
 
 // ref: https://github.com/facebookresearch/chameleon
