
Commit bdfdbbf

feat: Overhaul build_recurrent_state / build_inp_s_copy to match attention pattern
https://github.com/ggml-org/llama.cpp/pull/13979/files#r2141701738

This is a big overhaul to bring consistency between how inputs and per-layer components are created for attention layers and recurrent layers. The main changes are:

- Rename class llm_graph_input_s_copy -> llm_graph_input_rs
- Add a corresponding llm_graph_input_rs_hybrid_recurrent
- Rename build_inp_s_copy -> build_rs_inp_recurrent
- Add a corresponding build_rs_inp_hybrid_recurrent
- Rename build_recurrent_state -> build_rs to match build_attn w/ llm_graph_input_rs as the first input
- Add a corresponding overload of build_rs w/ llm_graph_input_rs_hybrid_recurrent as the first input
- Add a llm_graph_input_attn_kv_hybrid_recurrent analogous to llm_graph_input_attn_kv_unified
- Add a build_attn override that takes llm_graph_input_attn_kv_hybrid_recurrent as the first input

This makes the two paradigms fully consistent. The main drawback is the code duplication in the build_attn and build_rs implementations, where the only difference between implementations is how they cast the memory state.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
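For orientation, here is a hedged sketch of the caller-side pattern this establishes: one typed input object per memory type is built once per graph, then passed as the first argument to every per-layer build call, for attention and recurrent layers alike. Everything outside the build_* calls (the layer loop, is_recurrent_layer, ssm_states, build_mamba_mixer, the weight names, and the omitted head reshapes) is illustrative scaffolding, not code from this commit:

// Sketch of a hybrid-recurrent model builder running inside an
// llm_graph_context-derived class (gf, cur, n_seqs, kq_scale, model, hparams
// are assumed members/locals; names marked "placeholder" do not exist in this commit).
auto * inp_attn = build_attn_inp_kv_hybrid_recurrent(); // attention child of the hybrid cache
auto * inp_rs   = build_rs_inp_hybrid_recurrent();      // recurrent child of the hybrid cache

for (int il = 0; il < n_layer; ++il) {
    if (is_recurrent_layer(il)) { // placeholder predicate
        // gather this layer's recurrent state via the shared s_copy input
        ggml_tensor * ssm = build_rs(inp_rs, gf, ssm_states[il],     // ssm_states: placeholder
                                     hparams.n_embd_k_s(), n_seqs);
        cur = build_mamba_mixer(ssm, cur, il);                       // placeholder mixer
    } else {
        // attention layer writing into the unified KV child state
        // (head reshapes of q/k/v omitted for brevity)
        ggml_tensor * q_cur = build_lora_mm(model.layers[il].wq, cur);
        ggml_tensor * k_cur = build_lora_mm(model.layers[il].wk, cur);
        ggml_tensor * v_cur = build_lora_mm(model.layers[il].wv, cur);
        cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, nullptr,
                         q_cur, k_cur, v_cur, nullptr, nullptr, kq_scale, il);
    }
}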
1 parent 8af833b commit bdfdbbf

3 files changed: +240 -100 lines changed

src/llama-graph.cpp

Lines changed: 157 additions & 46 deletions
@@ -239,7 +239,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
     const int64_t n_kv = kv_state->get_n_kv();
@@ -255,6 +255,11 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_rs_hybrid_recurrent::llm_graph_input_rs_hybrid_recurrent(
+    const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+    llm_graph_input_rs(kv_state->get_state_recurrent()) {
+}
+
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -360,6 +365,13 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
+    const llama_hparams & hparams,
+    const llama_cparams & cparams,
+    const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
+}
+
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
         kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
@@ -962,25 +974,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const {
-    if (kv_state == nullptr) {
-        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-    }
-
-    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
 ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
 
@@ -1262,9 +1255,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    // NOTE: For hybrid caches, this may be a child of mstate, so we use the one
-    // encapsulated in inp
-    const auto * kv_state = inp->kv_state;
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
 
     // store to KV cache
     {
@@ -1296,15 +1287,14 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state->get_state_attn());
+llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(
+        hparams, cparams, static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = kv_state->get_state_attn()->get_n_kv();
+        const auto n_kv = inp->kv_state->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1313,7 +1303,57 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_re
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
-    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_attn();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
@@ -1455,19 +1495,90 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_recurrent_state(
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-        ggml_tensor * state_copy,
-        int32_t state_size,
-        int32_t n_seqs,
-        bool avoid_copies,
-        const llama_kv_cache_recurrent_state * kv_state) const {
+llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+
+    const auto n_kv = kv_state->get_n_kv();
+
+    auto & cur = inp->s_copy;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    ggml_set_input(cur);
+
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+
+    const auto n_kv = kv_state->get_n_kv();
+    const auto kv_head = kv_state->get_head();
+    const auto rs_zero = kv_state->get_rs_z();
+
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
+
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+
+    ggml_tensor * output_states;
 
-    if (kv_state == nullptr) {
-        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // {state_size, kv_size} -> {state_size, n_seqs}
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
+        ggml_build_forward_expand(gf, output_states);
+    } else {
+        // FIXME: make the gathering operation happen before the copy below
+        // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+        output_states = states;
     }
 
+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0,
+            states_extra,
+            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+
+    return output_states;
+}
+
+llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const {
+    auto inp = std::make_unique<llm_graph_input_rs_hybrid_recurrent>(
+        static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
+
+    const auto n_kv = inp->kv_state->get_n_kv();
+
+    auto & cur = inp->s_copy;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    ggml_set_input(cur);
+
+    return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
+
     const auto n_kv = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
     const auto rs_zero = kv_state->get_rs_z();
@@ -1485,7 +1596,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
         // copy states
         // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
         // {state_size, kv_size} -> {state_size, n_seqs}
-        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
         ggml_build_forward_expand(gf, output_states);
     } else {
         // FIXME: make the gathering operation happen before the copy below
@@ -1494,7 +1605,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
     }
 
     // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
             states_extra,
@@ -1504,9 +1615,9 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
-        ggml_cgraph * gf,
-        ggml_tensor * state_copy,
-        const llama_ubatch & ubatch,
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        const llama_ubatch & ubatch,
         int il) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
@@ -1516,8 +1627,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
 
     ggml_tensor * token_shift_all = kv_state->get_k_l(il);
 
-    ggml_tensor * token_shift = build_recurrent_state(
-            gf, token_shift_all, state_copy,
+    ggml_tensor * token_shift = build_rs(
+            inp, gf, token_shift_all,
             hparams.n_embd_k_s(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
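
Call sites that previously created the raw s_copy tensor migrate the same way; a minimal before/after sketch based on the old and new signatures in this diff (the surrounding RWKV builder code is abridged and not part of this commit):

// before: the caller built a bare s_copy tensor and threaded it through
// ggml_tensor * state_copy  = build_inp_s_copy();
// ggml_tensor * token_shift = build_rwkv_token_shift_load(gf, state_copy, ubatch, il);

// after: the caller builds the typed recurrent-state input once and passes it first
llm_graph_input_rs * inp    = build_rs_inp_recurrent();
ggml_tensor * token_shift   = build_rwkv_token_shift_load(inp, gf, ubatch, il);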
