@@ -235,7 +235,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
    }
}

-void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
    GGML_UNUSED(ubatch);

    const int64_t n_kv = kv_state->get_n_kv();
@@ -251,6 +251,11 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
    }
}

+llm_graph_input_rs_hybrid_recurrent::llm_graph_input_rs_hybrid_recurrent(
+    const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+    llm_graph_input_rs(kv_state->get_state_recurrent()) {
+}
+
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
    GGML_UNUSED(ubatch);

@@ -354,6 +359,13 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
    }
}

+llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
+    const llama_hparams & hparams,
+    const llama_cparams & cparams,
+    const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
+}
+
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
    if (self_kq_mask) {
        kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
@@ -955,25 +967,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
    return cur;
}

-ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const {
-    if (kv_state == nullptr) {
-        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-    }
-
-    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
    auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);

@@ -1255,9 +1248,7 @@ ggml_tensor * llm_graph_context::build_attn(
    ggml_build_forward_expand(gf, k_cur);
    ggml_build_forward_expand(gf, v_cur);

-    // NOTE: For hybrid caches, this may be a child of mstate, so we use the one
-    // encapsulated in inp
-    const auto * kv_state = inp->kv_state;
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);

    // store to KV cache
    {
@@ -1289,15 +1280,14 @@ ggml_tensor * llm_graph_context::build_attn(
    return cur;
}

-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state->get_state_attn());
+llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(
+        hparams, cparams, static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));

    {
        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");

-        const auto n_kv = kv_state->get_state_attn()->get_n_kv();
+        const auto n_kv = inp->kv_state->get_n_kv();

        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
        //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1306,7 +1296,57 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_re
        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
    }

-    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_attn();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
}

llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
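For context, a minimal sketch of how a hybrid model's graph builder might use the new typed input and the matching build_attn() overload. The layer loop and the names cur, Qcur/Kcur/Vcur, model.layers[il].wo/bo, and kq_scale are illustrative assumptions, not code from this diff:

    // sketch only: the model tensor names and surrounding loop are assumed
    llm_graph_input_attn_kv_hybrid_recurrent * inp_attn = build_attn_inp_kv_hybrid_recurrent();

    for (int il = 0; il < n_layer; ++il) {
        // ... compute Qcur, Kcur, Vcur for this attention layer (omitted) ...

        // the overload taking the hybrid input stores Kcur/Vcur into the attention
        // child state of the hybrid cache before running the attention itself
        cur = build_attn(inp_attn, gf,
                model.layers[il].wo, model.layers[il].bo,
                Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
    }

As with build_attn_inp_kv_unified(), the input object is created once per graph and reused by every attention layer.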
@@ -1448,19 +1488,90 @@ ggml_tensor * llm_graph_context::build_attn(
    return cur;
}

-ggml_tensor * llm_graph_context::build_recurrent_state(
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-        ggml_tensor * state_copy,
-        int32_t state_size,
-        int32_t n_seqs,
-        bool avoid_copies,
-        const llama_kv_cache_recurrent_state * kv_state) const {
+llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+
+    const auto n_kv = kv_state->get_n_kv();
+
+    auto & cur = inp->s_copy;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    ggml_set_input(cur);
+
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+
+    const auto n_kv    = kv_state->get_n_kv();
+    const auto kv_head = kv_state->get_head();
+    const auto rs_zero = kv_state->get_rs_z();
+
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
+
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+
+    ggml_tensor * output_states;

-    if (kv_state == nullptr) {
-        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // {state_size, kv_size} -> {state_size, n_seqs}
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
+        ggml_build_forward_expand(gf, output_states);
+    } else {
+        // FIXME: make the gathering operation happen before the copy below
+        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+        output_states = states;
    }

+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0,
+            states_extra,
+            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+
+    return output_states;
+}
+
+llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const {
+    auto inp = std::make_unique<llm_graph_input_rs_hybrid_recurrent>(
+        static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
+
+    const auto n_kv = inp->kv_state->get_n_kv();
+
+    auto & cur = inp->s_copy;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    ggml_set_input(cur);
+
+    return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
+
    const auto n_kv    = kv_state->get_n_kv();
    const auto kv_head = kv_state->get_head();
    const auto rs_zero = kv_state->get_rs_z();
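For context, a minimal sketch of how a recurrent layer might use build_rs_inp_recurrent() and build_rs() in place of the removed build_inp_s_copy() / build_recurrent_state() pair. The layer loop is an assumption; get_k_l(), n_embd_k_s(), and n_seqs follow the usage visible elsewhere in this diff:

    // sketch only: one rs input per graph, reused by every recurrent layer
    llm_graph_input_rs * inp_rs = build_rs_inp_recurrent();

    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);

    for (int il = 0; il < n_layer; ++il) {
        // gather this ubatch's per-sequence states; the copy indices are read
        // from inp_rs->s_copy instead of a separately built state_copy tensor
        ggml_tensor * token_shift = build_rs(inp_rs, gf, kv_state->get_k_l(il),
                hparams.n_embd_k_s(), n_seqs);

        // ... use token_shift in the layer (omitted) ...
    }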
@@ -1478,7 +1589,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
        // copy states
        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
        // {state_size, kv_size} -> {state_size, n_seqs}
-        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
        ggml_build_forward_expand(gf, output_states);
    } else {
        // FIXME: make the gathering operation happen before the copy below
@@ -1487,7 +1598,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
    }

    // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
    ggml_build_forward_expand(gf,
        ggml_cpy(ctx0,
            states_extra,
@@ -1497,9 +1608,9 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
}

ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
-        ggml_cgraph * gf,
-        ggml_tensor * state_copy,
-        const llama_ubatch & ubatch,
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        const llama_ubatch & ubatch,
        int il) const {
    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);

@@ -1509,8 +1620,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(

    ggml_tensor * token_shift_all = kv_state->get_k_l(il);

-    ggml_tensor * token_shift = build_recurrent_state(
-            gf, token_shift_all, state_copy,
+    ggml_tensor * token_shift = build_rs(
+            inp, gf, token_shift_all,
            hparams.n_embd_k_s(), n_seqs);

    token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
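Call sites of build_rwkv_token_shift_load() change accordingly: instead of threading a separately built state_copy tensor through, the caller builds the rs input once and passes it along. A minimal sketch of the updated call, with the surrounding builder code assumed:

    // sketch only
    llm_graph_input_rs * rs_inp = build_rs_inp_recurrent();

    // previously: build_rwkv_token_shift_load(gf, state_copy, ubatch, il)
    ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);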