@@ -7,6 +7,7 @@
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
 #include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-hybrid-recurrent.h"
 
 #include <cassert>
 #include <cmath>
@@ -412,6 +413,13 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
+        const llama_hparams & hparams,
+        const llama_cparams & cparams,
+        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
+}
+
 //
 // llm_graph_context
 //
@@ -969,8 +977,10 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const {
+    if (kv_state == nullptr) {
+        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    }
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
 
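Note on the `build_inp_s_copy()` change above: the recurrent state is now an optional argument, so existing callers keep the old behaviour (the state is recovered from `mstate`) while hybrid callers can hand in the recurrent sub-state explicitly. A minimal sketch of the two call forms, assuming the declaration in `llama-graph.h` gives the parameter a `nullptr` default and that the hybrid state exposes a recurrent-side accessor (named `get_state_recurrent()` here purely for illustration; only `get_state_attn()` appears in this diff):

    // purely recurrent models: unchanged call, state is taken from mstate
    ggml_tensor * s_copy = build_inp_s_copy();

    // hybrid models: pass the recurrent sub-state explicitly
    // (get_state_recurrent() is a hypothetical accessor, not part of this diff)
    ggml_tensor * s_copy_hybrid = build_inp_s_copy(kv_state->get_state_recurrent());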
@@ -1316,6 +1326,44 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
+llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(hparams, cparams, kv_state);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
+
+        const auto n_kv = kv_state->get_state_attn()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    return build_attn(
+        static_cast<llm_graph_input_attn_kv_unified *>(inp),
+        gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il
+    );
+}
+
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
     const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
 
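Usage sketch: together, `build_attn_inp_kv_hybrid_recurrent()` and the new `build_attn()` overload let a hybrid model's graph builder drive the attention half of the cache through the existing unified-KV path. The excerpt below is a hypothetical illustration, not taken from this change; `Qcur`, `Kcur`, `Vcur`, `cur`, `kq_scale`, and the `model.layers[il]` weights stand in for whatever the model actually computes per layer:

    // build the shared attention input once per graph
    auto * inp_attn = build_attn_inp_kv_hybrid_recurrent();

    for (int il = 0; il < n_layer; ++il) {
        // ... compute Qcur, Kcur, Vcur for this layer ...

        // the overload simply casts inp_attn to the unified input type and
        // forwards to the existing unified-KV build_attn()
        cur = build_attn(inp_attn, gf,
                model.layers[il].wo, model.layers[il].bo,
                Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
    }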