@@ -14028,7 +14028,11 @@ struct llm_build_granite_hybrid : public llm_graph_context {

         inpL = build_inp_embd(model.tok_embd);

-        auto * inp = build_inp_mem_hybrid();
+        const auto * mctx_hyb = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+        auto * inp_rs = build_rs_inp(mctx_hyb->get_recr());
+
+        auto * inp_attn = build_attn_inp_kv_unified(mctx_hyb->get_attn());

         ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -14049,11 +14053,11 @@ struct llm_build_granite_hybrid : public llm_graph_context {

             if (hparams.is_recurrent(il)) {
                 // ssm layer //
-                cur = build_mamba2_layer(inp, gf, cur, model, ubatch, il);
+                cur = build_mamba2_layer(inp_rs, gf, cur, model, ubatch, il);
             } else {
                 // attention layer //
                 cur = build_granite_attention_layer(
-                    gf, cur, inp_pos, inp, model,
+                    gf, cur, inp_pos, inp_attn, model,
                     n_embd_head, use_rope, il);
             }

@@ -14092,12 +14096,12 @@ struct llm_build_granite_hybrid : public llm_graph_context {
     }

     ggml_tensor * build_mamba2_layer(
-            llm_graph_input_mem_hybrid * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * cur,
-            const llama_model & model,
-            const llama_ubatch & ubatch,
-            int il) const {
+            llm_graph_input_rs * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            const llama_model & model,
+            const llama_ubatch & ubatch,
+            int il) const {
         const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();

         const auto kv_head = mctx_cur->get_head();
@@ -14221,14 +14225,14 @@ struct llm_build_granite_hybrid : public llm_graph_context {
     }

     ggml_tensor * build_granite_attention_layer(
-            ggml_cgraph * gf,
-            ggml_tensor * cur,
-            ggml_tensor * inp_pos,
-            llm_graph_input_mem_hybrid * inp,
-            const llama_model & model,
-            const int64_t n_embd_head,
-            const bool use_rope,
-            const int il) {
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            ggml_tensor * inp_pos,
+            llm_graph_input_attn_kv_unified * inp,
+            const llama_model & model,
+            const int64_t n_embd_head,
+            const bool use_rope,
+            const int il) {

        // compute Q and K and (optionally) RoPE them
        ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);