
Commit 4b5f673

fix: Fix input initialization in granite_hybrid after removal of hybrid inputs
Branch: GraniteFourWithJamba
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent e100153 commit 4b5f673
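
For context, this change replaces the single hybrid input with two inputs drawn from the hybrid memory context: one for the recurrent (SSM) layers and one for the attention layers. Below is a minimal sketch of the resulting wiring inside the graph builder, condensed from the hunks that follow; the per-layer loop is the usual llm_graph_context builder pattern and is assumed here, it is not part of this diff.

    // Sketch only: split the hybrid memory context into its recurrent and
    // attention halves, then hand each layer the matching input.
    const auto * mctx_hyb = static_cast<const llama_memory_hybrid_context *>(mctx);

    auto * inp_rs   = build_rs_inp(mctx_hyb->get_recr());              // recurrent-state input (SSM layers)
    auto * inp_attn = build_attn_inp_kv_unified(mctx_hyb->get_attn()); // unified-KV attention input

    for (int il = 0; il < n_layer; ++il) {   // assumed layer loop, as in the other builders
        if (hparams.is_recurrent(il)) {
            cur = build_mamba2_layer(inp_rs, gf, cur, model, ubatch, il);
        } else {
            cur = build_granite_attention_layer(gf, cur, inp_pos, inp_attn, model,
                                                n_embd_head, use_rope, il);
        }
    }

The member-function signatures in the later hunks change accordingly: build_mamba2_layer now takes llm_graph_input_rs *, and build_granite_attention_layer takes llm_graph_input_attn_kv_unified *.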

File tree

1 file changed (+21, -17)

src/llama-model.cpp

Lines changed: 21 additions & 17 deletions
@@ -14028,7 +14028,11 @@ struct llm_build_granite_hybrid : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp = build_inp_mem_hybrid();
+        const auto * mctx_hyb = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+        auto * inp_rs = build_rs_inp(mctx_hyb->get_recr());
+
+        auto * inp_attn = build_attn_inp_kv_unified(mctx_hyb->get_attn());
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 

@@ -14049,11 +14053,11 @@ struct llm_build_granite_hybrid : public llm_graph_context {
 
             if (hparams.is_recurrent(il)) {
                 // ssm layer //
-                cur = build_mamba2_layer(inp, gf, cur, model, ubatch, il);
+                cur = build_mamba2_layer(inp_rs, gf, cur, model, ubatch, il);
             } else {
                 // attention layer //
                 cur = build_granite_attention_layer(
-                    gf, cur, inp_pos, inp, model,
+                    gf, cur, inp_pos, inp_attn, model,
                     n_embd_head, use_rope, il);
             }
 

@@ -14092,12 +14096,12 @@ struct llm_build_granite_hybrid : public llm_graph_context {
     }
 
     ggml_tensor * build_mamba2_layer(
-            llm_graph_input_mem_hybrid * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * cur,
-            const llama_model & model,
-            const llama_ubatch & ubatch,
-            int il) const {
+            llm_graph_input_rs * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            const llama_model & model,
+            const llama_ubatch & ubatch,
+            int il) const {
         const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
 
         const auto kv_head = mctx_cur->get_head();
@@ -14221,14 +14225,14 @@ struct llm_build_granite_hybrid : public llm_graph_context {
     }
 
     ggml_tensor * build_granite_attention_layer(
-            ggml_cgraph * gf,
-            ggml_tensor * cur,
-            ggml_tensor * inp_pos,
-            llm_graph_input_mem_hybrid * inp,
-            const llama_model & model,
-            const int64_t n_embd_head,
-            const bool use_rope,
-            const int il) {
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            ggml_tensor * inp_pos,
+            llm_graph_input_attn_kv_unified * inp,
+            const llama_model & model,
+            const int64_t n_embd_head,
+            const bool use_rope,
+            const int il) {
 
         // compute Q and K and (optionally) RoPE them
         ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
