#include "llama-kv-cache-unified.h"
#include "llama-kv-cache-unified-iswa.h"
#include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-hybrid-recurrent.h"

#include <cassert>
#include <cmath>
@@ -412,6 +413,13 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
    }
}

+llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
+        const llama_hparams & hparams,
+        const llama_cparams & cparams,
+        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
+}
+
//
// llm_graph_context
//
@@ -955,8 +963,10 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
    return cur;
}

-ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const {
+    if (kv_state == nullptr) {
+        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    }

    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);

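Note on the `build_inp_s_copy` change above: the new parameter lets a caller hand in a specific recurrent state instead of relying on the graph's `mstate`, while the `nullptr` check preserves the old behavior for existing recurrent-only models. Below is a minimal sketch of a hypothetical caller; the accessor `get_state_recurrent()` is assumed by analogy with `get_state_attn()` and is not part of this diff:

    // Illustrative only: a hybrid model's graph builder fetching the recurrent
    // half of the hybrid cache state and passing it to build_inp_s_copy().
    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);

    // drives the s_copy input for the recurrent/SSM layers; passing nullptr
    // (or omitting the argument) would fall back to casting mstate as before
    ggml_tensor * state_copy = build_inp_s_copy(kv_state->get_state_recurrent());
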
@@ -1302,6 +1312,44 @@ ggml_tensor * llm_graph_context::build_attn(
    return cur;
}

+llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(hparams, cparams, kv_state);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
+
+        const auto n_kv = kv_state->get_state_attn()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    return build_attn(
+        static_cast<llm_graph_input_attn_kv_unified *>(inp),
+        gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il
+    );
+}
+
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);

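For context, here is a minimal sketch of how a hybrid model's attention layers might consume the new input builder; the layer loop, the `model.layers[il]` weight names, and the scaling factor are assumptions for illustration, not part of this change:

    // Illustrative only: create the hybrid attention input once per graph and
    // reuse it in every attention layer.
    auto * inp_attn = build_attn_inp_kv_hybrid_recurrent();

    for (int il = 0; il < n_layer; ++il) {
        // ... compute Qcur, Kcur, Vcur for layer il ...

        // the new build_attn overload simply forwards to the
        // llm_graph_input_attn_kv_unified path, so the attention half of the
        // hybrid cache is handled exactly like a unified KV cache
        cur = build_attn(inp_attn, gf,
                model.layers[il].wo, model.layers[il].bo,
                Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
    }

The delegating constructor and the forwarding `build_attn` overload exist so the unified-cache attention code path can be reused unchanged for the attention portion of the hybrid cache.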