@@ -239,7 +239,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
     const int64_t n_kv = kv_state->get_n_kv();
@@ -255,6 +255,11 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_rs_hybrid_recurrent::llm_graph_input_rs_hybrid_recurrent(
+    const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+    llm_graph_input_rs(kv_state->get_state_recurrent()) {
+}
+
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -360,6 +365,13 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
+    const llama_hparams & hparams,
+    const llama_cparams & cparams,
+    const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
+}
+
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
         kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
@@ -962,25 +974,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const {
-    if (kv_state == nullptr) {
-        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-    }
-
-    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
 ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
 
@@ -1262,9 +1255,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    // NOTE: For hybrid caches, this may be a child of mstate, so we use the one
-    // encapsulated in inp
-    const auto * kv_state = inp->kv_state;
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
 
     // store to KV cache
     {
@@ -1296,15 +1287,14 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state->get_state_attn());
+llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(
+        hparams, cparams, static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = kv_state->get_state_attn()->get_n_kv();
+        const auto n_kv = inp->kv_state->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1313,7 +1303,57 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_re
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
-    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_attn();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
@@ -1455,19 +1495,90 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_recurrent_state(
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-        ggml_tensor * state_copy,
-        int32_t state_size,
-        int32_t n_seqs,
-        bool avoid_copies,
-        const llama_kv_cache_recurrent_state * kv_state) const {
+llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+
+    const auto n_kv = kv_state->get_n_kv();
+
+    auto & cur = inp->s_copy;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    ggml_set_input(cur);
+
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+
+    const auto n_kv = kv_state->get_n_kv();
+    const auto kv_head = kv_state->get_head();
+    const auto rs_zero = kv_state->get_rs_z();
+
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
+
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+
+    ggml_tensor * output_states;
 
-    if (kv_state == nullptr) {
-        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // {state_size, kv_size} -> {state_size, n_seqs}
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
+        ggml_build_forward_expand(gf, output_states);
+    } else {
+        // FIXME: make the gathering operation happen before the copy below
+        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+        output_states = states;
     }
 
+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0,
+            states_extra,
+            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+
+    return output_states;
+}
+
+llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const {
+    auto inp = std::make_unique<llm_graph_input_rs_hybrid_recurrent>(
+        static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
+
+    const auto n_kv = inp->kv_state->get_n_kv();
+
+    auto & cur = inp->s_copy;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    ggml_set_input(cur);
+
+    return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
+
     const auto n_kv = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
     const auto rs_zero = kv_state->get_rs_z();
@@ -1485,7 +1596,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
         // copy states
         // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
        // {state_size, kv_size} -> {state_size, n_seqs}
-        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
         ggml_build_forward_expand(gf, output_states);
     } else {
         // FIXME: make the gathering operation happen before the copy below
@@ -1494,7 +1605,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
     }
 
     // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
             states_extra,
@@ -1504,9 +1615,9 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
-        ggml_cgraph * gf,
-        ggml_tensor * state_copy,
-        const llama_ubatch & ubatch,
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        const llama_ubatch & ubatch,
         int il) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
@@ -1516,8 +1627,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
 
     ggml_tensor * token_shift_all = kv_state->get_k_l(il);
 
-    ggml_tensor * token_shift = build_recurrent_state(
-            gf, token_shift_all, state_copy,
+    ggml_tensor * token_shift = build_rs(
+            inp, gf, token_shift_all,
             hparams.n_embd_k_s(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
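
For orientation, a minimal usage sketch (not part of the diff) of the new input-object API from a model's graph-build method. The surrounding context (gf, ubatch, il) and the placeholder per-layer state tensor s with its state_size/n_seqs are assumptions following existing llama-graph.cpp conventions, not taken from this patch; only entry points added or updated by this diff are called.

    // build the recurrent-state input once, instead of threading a bare s_copy tensor
    // through every helper call
    llm_graph_input_rs * inp = build_rs_inp_recurrent();
    // hybrid models use build_rs_inp_hybrid_recurrent(); the matching build_rs() overload
    // resolves the recurrent child of the hybrid cache state internally

    // the RWKV token-shift load, whose signature this diff changes to take the input object
    ggml_tensor * token_shift = build_rwkv_token_shift_load(inp, gf, ubatch, il);

    // any other per-layer state tensor `s` (hypothetical here) is gathered the same way;
    // build_rs() reads inp->s_copy to select the states for the current sequences
    ggml_tensor * state = build_rs(inp, gf, s, state_size, n_seqs, /*avoid_copies=*/false);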