Commit 8f62004

feat: Construct hybrid recurrent cache for hybrid recurrent models
This includes a refactor of the create_memory logic so that the arch enum only needs to be referenced explicitly when a model requires cache instantiation logic beyond the standard handling for recurrent, hybrid, unified, and iswa caches.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent a48dad4 commit 8f62004

File tree

1 file changed: +63 -47 lines changed

src/llama-model.cpp

Lines changed: 63 additions & 47 deletions
@@ -9,6 +9,7 @@
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
 #include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-hybrid-recurrent.h"
 
 #include "ggml-cpp.h"
 
@@ -13197,6 +13198,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13205,58 +13208,71 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        nullptr,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            params.swa_full,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            cparams.n_ubatch,
-                            padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified(
+                if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                             *this,
                             nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                            std::max((uint32_t) 1, cparams.n_seq_max),
+                            cparams.n_seq_max);
+                } else if (llm_arch_is_hybrid_recurrent(arch)) {
+                    res = new llama_kv_cache_hybrid_recurrent(
+                            /* model              */ *this,
+                            /* attn_type_k        */ params.type_k,
+                            /* attn_type_v        */ params.type_v,
+                            /* attn_v_trans       */ !cparams.flash_attn,
+                            /* attn_kv_size       */ cparams.n_ctx,
+                            /* attn_n_pad         */ llama_kv_cache_unified::get_padding(cparams),
+                            /* attn_n_swa         */ hparams.n_swa,
+                            /* attn_swa_type      */ hparams.swa_type,
+                            /* recurrent_type_k   */ GGML_TYPE_F32,
+                            /* recurrent_type_v   */ GGML_TYPE_F32,
+                            /* recurrent_kv_size  */ std::max((uint32_t) 1, cparams.n_seq_max),
+                            /* n_seq_max          */ cparams.n_seq_max,
+                            /* offload            */ cparams.offload_kqv);
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                        GGML_ASSERT(hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                params.swa_full,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                cparams.n_ubatch,
+                                padding);
+                    } else {
+                        GGML_ASSERT(!hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
                 }
            }
    }
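
The hunk above replaces the explicit LLM_ARCH_MAMBA/RWKV* cases with capability checks, so a new recurrent or hybrid-recurrent architecture only needs to be recognized by llm_arch_is_recurrent() / llm_arch_is_hybrid_recurrent() rather than added to the switch. Below is a minimal, self-contained sketch of that dispatch pattern; the enum values, predicates, and cache structs are illustrative placeholders, not llama.cpp's real types or signatures.

// Stand-alone sketch of the capability-based dispatch used in the commit.
// All names here are simplified placeholders for illustration only.
#include <cstdio>
#include <memory>

enum class arch { bert, mamba, hybrid_model, llama };

struct cache { virtual ~cache() = default; };
struct recurrent_cache        : cache {};  // stand-in for llama_kv_cache_recurrent
struct hybrid_recurrent_cache : cache {};  // stand-in for llama_kv_cache_hybrid_recurrent
struct unified_cache          : cache {};  // stand-in for llama_kv_cache_unified

// Hypothetical capability predicates; in the commit these roles are played by
// llm_arch_is_recurrent() and llm_arch_is_hybrid_recurrent().
static bool is_recurrent(arch a)        { return a == arch::mamba; }
static bool is_hybrid_recurrent(arch a) { return a == arch::hybrid_model; }

// Only models that need special instantiation appear as explicit cases;
// everything else is classified by the capability checks in the default branch.
static std::unique_ptr<cache> create_memory(arch a) {
    switch (a) {
        case arch::bert:
            return nullptr;  // embedding-only: no KV cache
        default:
            if (is_recurrent(a))        return std::make_unique<recurrent_cache>();
            if (is_hybrid_recurrent(a)) return std::make_unique<hybrid_recurrent_cache>();
            return std::make_unique<unified_cache>();  // attention-only models
    }
}

int main() {
    std::printf("mamba -> cache created: %s\n", create_memory(arch::mamba) ? "yes" : "no");
    std::printf("llama -> cache created: %s\n", create_memory(arch::llama) ? "yes" : "no");
    std::printf("bert  -> cache created: %s\n", create_memory(arch::bert)  ? "yes" : "no");
    return 0;
}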
