Commit fe39803

feat: Construct hybrid recurrent cache for hybrid recurrent models
This includes a refactor of the create_memory logic so that the arch enum
no longer needs to be referenced explicitly unless a model requires cache
instantiation logic beyond the standard recurrent, hybrid, unified, and
iswa paths.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 6a822b7 commit fe39803
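Editor's note: the two architecture predicates this commit switches on, llm_arch_is_recurrent and llm_arch_is_hybrid_recurrent, are defined outside this diff. Below is a minimal sketch of their likely shape, assuming they switch over llm_arch the same way the case labels removed in the diff did; the recurrent list mirrors those removed labels, while the hybrid list is a placeholder assumption, not taken from the commit.

    // Sketch only: the real definitions live elsewhere in the tree.
    // The recurrent list mirrors the case labels this commit removes from
    // create_memory; the hybrid body is a placeholder assumption.
    bool llm_arch_is_recurrent(const llm_arch & arch) {
        switch (arch) {
            case LLM_ARCH_MAMBA:
            case LLM_ARCH_RWKV6:
            case LLM_ARCH_RWKV6QWEN2:
            case LLM_ARCH_RWKV7:
            case LLM_ARCH_ARWKV7:
                return true;
            default:
                return false;
        }
    }

    bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) {
        switch (arch) {
            // hybrid attention + recurrent architectures would be
            // enumerated here
            default:
                return false;
        }
    }

Centralizing the classification this way is what lets create_memory drop its per-architecture cases: adding a new recurrent or hybrid model then only requires updating the predicate, not the cache-construction switch.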

File tree

1 file changed (+63, -47 lines)

src/llama-model.cpp

Lines changed: 63 additions & 47 deletions
@@ -9,6 +9,7 @@
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
 #include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-hybrid-recurrent.h"
 
 #include "ggml-cpp.h"
 
@@ -13211,6 +13212,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13219,58 +13222,71 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        nullptr,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            params.swa_full,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            cparams.n_ubatch,
-                            padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified(
+                if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                             *this,
                             nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                            std::max((uint32_t) 1, cparams.n_seq_max),
+                            cparams.n_seq_max);
+                } else if (llm_arch_is_hybrid_recurrent(arch)) {
+                    res = new llama_kv_cache_hybrid_recurrent(
+                        /* model             */ *this,
+                        /* attn_type_k       */ params.type_k,
+                        /* attn_type_v       */ params.type_v,
+                        /* attn_v_trans      */ !cparams.flash_attn,
+                        /* attn_kv_size      */ cparams.n_ctx,
+                        /* attn_n_pad        */ llama_kv_cache_unified::get_padding(cparams),
+                        /* attn_n_swa        */ hparams.n_swa,
+                        /* attn_swa_type     */ hparams.swa_type,
+                        /* recurrent_type_k  */ GGML_TYPE_F32,
+                        /* recurrent_type_v  */ GGML_TYPE_F32,
+                        /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
+                        /* n_seq_max         */ cparams.n_seq_max,
+                        /* offload           */ cparams.offload_kqv);
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                        GGML_ASSERT(hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                params.swa_full,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                cparams.n_ubatch,
+                                padding);
+                    } else {
+                        GGML_ASSERT(!hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
+                }
             }
     }
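Editor's note: the net effect is that the default branch now resolves the cache type from the architecture's properties instead of explicit enum cases. A sketch of the resulting dispatch from a caller's point of view, assuming a model's arch reaches the default branch (summary inferred from the diff above, not verbatim from the commit):

    // Resulting dispatch order in the default branch:
    //
    //   llm_arch_is_recurrent(arch)             -> llama_kv_cache_recurrent
    //   llm_arch_is_hybrid_recurrent(arch)      -> llama_kv_cache_hybrid_recurrent
    //   hparams.swa_type != LLAMA_SWA_TYPE_NONE -> llama_kv_cache_unified_iswa
    //   otherwise                               -> llama_kv_cache_unified
    //
    // Note that cparams.n_ctx is padded in place only on the non-recurrent,
    // non-hybrid path; the hybrid path passes the padding to the cache
    // constructor instead.
    llama_memory_i * mem = model.create_memory(params, cparams);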
