
Commit 6253c7c

feat: Overhaul build_recurrent_state / build_inp_s_copy to match attention pattern
https://github.com/ggml-org/llama.cpp/pull/13979/files#r2141701738

This is a big overhaul to bring consistency between how inputs and per-layer components are created for attention layers and recurrent layers. The main changes are:

- Rename class llm_graph_input_s_copy -> llm_graph_input_rs
- Add a corresponding llm_graph_input_rs_hybrid_recurrent
- Rename build_inp_s_copy -> build_rs_inp_recurrent
- Add a corresponding build_rs_inp_hybrid_recurrent
- Rename build_recurrent_state -> build_rs to match build_attn w/ llm_graph_input_rs * as the first input
- Add a corresponding overload of build_rs w/ llm_graph_input_rs_hybrid_recurrent * as the first input
- Add a llm_graph_input_attn_kv_hybrid_recurrent analogous to llm_graph_input_attn_kv_unified
- Add a build_attn overload that takes llm_graph_input_attn_kv_hybrid_recurrent * as the first input

This makes the two paradigms fully consistent. The main drawback is the code duplication in the build_attn and build_rs implementations, where the only difference between implementations is how they cast the memory state.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent f3e34bb commit 6253c7c
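
For orientation, here is a minimal caller-side sketch of what the rename implies inside a recurrent model's graph builder (a subclass of llm_graph_context). It mirrors the build_rwkv_token_shift_load change in the diff below; the "before" lines are shown as comments, and variables such as gf, token_shift_all and n_seqs are taken from that context, not new API.

```cpp
// Before this commit (old API), the s_copy tensor was created separately and
// threaded through every call:
//   ggml_tensor * state_copy  = build_inp_s_copy();
//   ggml_tensor * token_shift = build_recurrent_state(
//           gf, token_shift_all, state_copy,
//           hparams.n_embd_k_s(), n_seqs);

// After this commit: a per-graph input object is created once and passed as the
// first argument, mirroring build_attn_inp_kv_unified() + build_attn(inp, ...).
llm_graph_input_rs * rs_inp = build_rs_inp_recurrent();

ggml_tensor * token_shift = build_rs(
        rs_inp, gf, token_shift_all,
        hparams.n_embd_k_s(), n_seqs);
```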

File tree

3 files changed: +240 −100 lines


src/llama-graph.cpp

Lines changed: 157 additions & 46 deletions

@@ -235,7 +235,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
     const int64_t n_kv = kv_state->get_n_kv();
@@ -251,6 +251,11 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_rs_hybrid_recurrent::llm_graph_input_rs_hybrid_recurrent(
+        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+    llm_graph_input_rs(kv_state->get_state_recurrent()) {
+}
+
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -354,6 +359,13 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
+        const llama_hparams & hparams,
+        const llama_cparams & cparams,
+        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
+}
+
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
         kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
@@ -955,25 +967,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const {
-    if (kv_state == nullptr) {
-        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-    }
-
-    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
 ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
 
@@ -1255,9 +1248,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    // NOTE: For hybrid caches, this may be a child of mstate, so we use the one
-    // encapsulated in inp
-    const auto * kv_state = inp->kv_state;
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
 
     // store to KV cache
     {
@@ -1289,15 +1280,14 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state->get_state_attn());
+llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(
+        hparams, cparams, static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = kv_state->get_state_attn()->get_n_kv();
+        const auto n_kv = inp->kv_state->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1306,7 +1296,57 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_re
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
-    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_attn();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
@@ -1448,19 +1488,90 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_recurrent_state(
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-        ggml_tensor * state_copy,
-        int32_t state_size,
-        int32_t n_seqs,
-        bool avoid_copies,
-        const llama_kv_cache_recurrent_state * kv_state) const {
+llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+
+    const auto n_kv = kv_state->get_n_kv();
+
+    auto & cur = inp->s_copy;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    ggml_set_input(cur);
+
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+
+    const auto n_kv = kv_state->get_n_kv();
+    const auto kv_head = kv_state->get_head();
+    const auto rs_zero = kv_state->get_rs_z();
+
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
+
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+
+    ggml_tensor * output_states;
 
-    if (kv_state == nullptr) {
-        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // {state_size, kv_size} -> {state_size, n_seqs}
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
+        ggml_build_forward_expand(gf, output_states);
+    } else {
+        // FIXME: make the gathering operation happen before the copy below
+        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+        output_states = states;
     }
 
+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0,
+            states_extra,
+            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+
+    return output_states;
+}
+
+llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const {
+    auto inp = std::make_unique<llm_graph_input_rs_hybrid_recurrent>(
+        static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
+
+    const auto n_kv = inp->kv_state->get_n_kv();
+
+    auto & cur = inp->s_copy;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    ggml_set_input(cur);
+
+    return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
+
     const auto n_kv = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
     const auto rs_zero = kv_state->get_rs_z();
@@ -1478,7 +1589,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
         // copy states
         // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
         // {state_size, kv_size} -> {state_size, n_seqs}
-        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
         ggml_build_forward_expand(gf, output_states);
     } else {
         // FIXME: make the gathering operation happen before the copy below
@@ -1487,7 +1598,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
     }
 
     // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
             states_extra,
@@ -1497,9 +1608,9 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
-        ggml_cgraph * gf,
-        ggml_tensor * state_copy,
-        const llama_ubatch & ubatch,
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        const llama_ubatch & ubatch,
         int il) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
@@ -1509,8 +1620,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
 
     ggml_tensor * token_shift_all = kv_state->get_k_l(il);
 
-    ggml_tensor * token_shift = build_recurrent_state(
-            gf, token_shift_all, state_copy,
+    ggml_tensor * token_shift = build_rs(
+            inp, gf, token_shift_all,
             hparams.n_embd_k_s(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
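
To show how the two new hybrid input types are intended to be consumed together, here is a hedged usage sketch of a hybrid model's graph builder. Only the four functions added in this diff are real API; wo, wo_b, Qcur, Kcur, Vcur, kq_scale, ssm_states_all, state_size and n_seqs are placeholders standing in for per-layer tensors and sizes, not identifiers from this commit.

```cpp
// Both per-graph inputs are created once, from the same
// llama_kv_cache_hybrid_recurrent_state held in mstate.
auto * inp_attn = build_attn_inp_kv_hybrid_recurrent();
auto * inp_rs   = build_rs_inp_hybrid_recurrent();

// An attention layer (index il) routes through the unified-KV child state,
// using the same argument order as the other build_attn overloads.
ggml_tensor * attn_out = build_attn(inp_attn, gf, wo, wo_b,
        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);

// A recurrent layer gathers its states through the recurrent child state;
// passing avoid_copies = false takes the normal copy path.
ggml_tensor * states = build_rs(inp_rs, gf, ssm_states_all,
        state_size, n_seqs, false);
```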
