@@ -335,6 +335,11 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    inp_attn->set_input(ubatch);
+    inp_rs->set_input(ubatch);
+}
+
 void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
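Aside: the new llm_graph_input_mem_hybrid input does nothing beyond forwarding set_input() to its two owned sub-inputs. Below is a minimal, self-contained sketch of that composition pattern; all type names (graph_input_i, attn_input_stub, rs_input_stub, mem_hybrid_input_stub) are hypothetical stand-ins, not the actual llama.cpp classes.

#include <memory>

struct llama_ubatch_stub {};  // stand-in for llama_ubatch

// base interface for a graph input that is filled from the current micro-batch
struct graph_input_i {
    virtual ~graph_input_i() = default;
    virtual void set_input(const llama_ubatch_stub * ubatch) = 0;
};

struct attn_input_stub : graph_input_i {
    void set_input(const llama_ubatch_stub * /*ubatch*/) override {
        // the real input would fill the KQ mask and KV cell indices here
    }
};

struct rs_input_stub : graph_input_i {
    void set_input(const llama_ubatch_stub * /*ubatch*/) override {
        // the real input would fill the recurrent-state copy indices here
    }
};

// hybrid input: owns one of each and forwards the call to both
struct mem_hybrid_input_stub : graph_input_i {
    std::unique_ptr<attn_input_stub> inp_attn;
    std::unique_ptr<rs_input_stub>   inp_rs;

    mem_hybrid_input_stub(std::unique_ptr<attn_input_stub> a, std::unique_ptr<rs_input_stub> r)
        : inp_attn(std::move(a)), inp_rs(std::move(r)) {}

    void set_input(const llama_ubatch_stub * ubatch) override {
        inp_attn->set_input(ubatch);
        inp_rs  ->set_input(ubatch);
    }
};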
@@ -1147,17 +1152,20 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
-    }
+static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
+        ggml_context * ctx0,
+        const llama_ubatch & ubatch,
+        const llama_hparams & hparams,
+        const llama_cparams & cparams,
+        const llama_kv_cache_unified_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
         const auto n_kv = mctx_cur->get_n_kv();
+        const auto n_tokens = ubatch.n_tokens;
 
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
@@ -1168,6 +1176,14 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(c
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
+    return inp;
+}
+
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+
+    auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
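The shape of this refactor: the old member function took an optional memory-context pointer and fell back to a static_cast of mctx, whereas now a static _impl free function receives every dependency (ctx0, ubatch, hparams, cparams, memory context) explicitly and returns the input by unique_ptr, and the public member function becomes a thin wrapper. A rough, self-contained sketch of that pattern follows, using placeholder types (ggml_ctx_stub, kv_ctx_stub, etc.), not the real llama.cpp API.

#include <memory>

struct ggml_ctx_stub {};   // stand-in for ggml_context
struct ubatch_stub   {};   // stand-in for llama_ubatch
struct kv_ctx_stub   {};   // stand-in for llama_kv_cache_unified_context

struct attn_input_stub {
    explicit attn_input_stub(const kv_ctx_stub * mctx_cur) : mctx(mctx_cur) {}
    const kv_ctx_stub * mctx;
    // tensors for the KQ mask and KV indices would live here
};

// inner implementation: stateless, every dependency is a parameter,
// so other builders (e.g. a hybrid-memory one) can reuse it directly
static std::unique_ptr<attn_input_stub> build_attn_inp_impl(
        ggml_ctx_stub     * ctx0,
        const ubatch_stub & ubatch,
        const kv_ctx_stub * mctx_cur) {
    (void) ctx0; (void) ubatch;   // the real code allocates the mask/index tensors here
    return std::make_unique<attn_input_stub>(mctx_cur);
}

struct graph_ctx_stub {
    ggml_ctx_stub     * ctx0 = nullptr;
    ubatch_stub         ubatch;
    const kv_ctx_stub * mctx = nullptr;

    // thin member wrapper: resolve the memory context and delegate to the impl
    std::unique_ptr<attn_input_stub> build_attn_inp() const {
        const kv_ctx_stub * mctx_cur = mctx;   // the real code does a static_cast here
        return build_attn_inp_impl(ctx0, ubatch, mctx_cur);
    }
};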
@@ -1346,10 +1362,11 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa(const llama_kv_cache_unified_iswa_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
-    }
+// TODO: maybe separate the inner implementation into a separate function
+//       like with the non-sliding window equivalent
+//       once sliding-window hybrid caches are a thing.
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
@@ -1417,10 +1434,9 @@ ggml_tensor * llm_graph_context::build_rs(
     return output_states;
 }
 
-llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurrent_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
-    }
+static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
+        ggml_context * ctx0,
+        const llama_memory_recurrent_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
@@ -1429,6 +1445,14 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurren
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
+    return inp;
+}
+
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+    auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+
     return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
 
@@ -1486,6 +1510,17 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     );
 }
 
+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+    auto inp_rs   = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+    auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 void llm_graph_context::build_pooling(
         ggml_cgraph * gf,
         ggml_tensor * cls,
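The hybrid builder above depends on the hybrid memory context exposing its two halves via get_attn() and get_recr(), so each existing _impl helper can be handed the sub-context it understands. A minimal sketch of that composition, again with placeholder types rather than the real llama.cpp ones:

#include <memory>
#include <utility>

struct kv_ctx_stub   {};   // stand-in for llama_kv_cache_unified_context
struct recr_ctx_stub {};   // stand-in for llama_memory_recurrent_context

// stand-in for llama_memory_hybrid_context: a view over the two sub-contexts
struct hybrid_ctx_stub {
    kv_ctx_stub   attn;
    recr_ctx_stub recr;

    const kv_ctx_stub   * get_attn() const { return &attn; }
    const recr_ctx_stub * get_recr() const { return &recr; }
};

struct attn_input_stub { explicit attn_input_stub(const kv_ctx_stub *) {} };
struct rs_input_stub   { explicit rs_input_stub(const recr_ctx_stub *) {} };

struct mem_hybrid_input_stub {
    std::unique_ptr<attn_input_stub> inp_attn;
    std::unique_ptr<rs_input_stub>   inp_rs;
};

// mirrors the idea of build_inp_mem_hybrid(): build each half with the helper
// for its sub-context, then bundle both into a single hybrid input
static std::unique_ptr<mem_hybrid_input_stub> build_inp_mem_hybrid_stub(const hybrid_ctx_stub * mctx_cur) {
    auto inp_attn = std::make_unique<attn_input_stub>(mctx_cur->get_attn());
    auto inp_rs   = std::make_unique<rs_input_stub>(mctx_cur->get_recr());

    auto inp = std::make_unique<mem_hybrid_input_stub>();
    inp->inp_attn = std::move(inp_attn);
    inp->inp_rs   = std::move(inp_rs);
    return inp;
}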