@@ -9,6 +9,7 @@
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
 #include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-hybrid-recurrent.h"
 
 #include "ggml-cpp.h"
 
@@ -13211,6 +13212,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13219,58 +13222,71 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        nullptr,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            params.swa_full,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            cparams.n_ubatch,
-                            padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified(
+                if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                             *this,
                             nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                            std::max((uint32_t) 1, cparams.n_seq_max),
+                            cparams.n_seq_max);
+                } else if (llm_arch_is_hybrid_recurrent(arch)) {
+                    res = new llama_kv_cache_hybrid_recurrent(
+                        /* model              */ *this,
+                        /* attn_type_k        */ params.type_k,
+                        /* attn_type_v        */ params.type_v,
+                        /* attn_v_trans       */ !cparams.flash_attn,
+                        /* attn_kv_size       */ cparams.n_ctx,
+                        /* attn_n_pad         */ llama_kv_cache_unified::get_padding(cparams),
+                        /* attn_n_swa         */ hparams.n_swa,
+                        /* attn_swa_type      */ hparams.swa_type,
+                        /* recurrent_type_k   */ GGML_TYPE_F32,
+                        /* recurrent_type_v   */ GGML_TYPE_F32,
+                        /* recurrent_kv_size  */ std::max((uint32_t) 1, cparams.n_seq_max),
+                        /* n_seq_max          */ cparams.n_seq_max,
+                        /* offload            */ cparams.offload_kqv);
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                        GGML_ASSERT(hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                params.swa_full,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                cparams.n_ubatch,
+                                padding);
+                    } else {
+                        GGML_ASSERT(!hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
                 }
             } break;
     }
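Note on the removed case labels: LLM_ARCH_MAMBA, LLM_ARCH_RWKV6, LLM_ARCH_RWKV6QWEN2, LLM_ARCH_RWKV7 and LLM_ARCH_ARWKV7 no longer get an explicit `case` here; the `default` branch now routes them through `llm_arch_is_recurrent(arch)`, with `llm_arch_is_hybrid_recurrent(arch)` as the analogous check for hybrid models. The sketch below is only a rough illustration of the kind of predicate that check implies, assuming the `llm_arch` enum from `llama-arch.h` and using a hypothetical name so it is not mistaken for the real helper:

```cpp
#include "llama-arch.h"  // assumed location of the llm_arch enum and LLM_ARCH_* values

// Illustrative sketch only -- not the actual llm_arch_is_recurrent()
// implementation, which may cover additional architectures. It simply maps
// the arch values that previously had an explicit recurrent-cache case in
// create_memory() to a single predicate.
static bool llm_arch_is_recurrent_sketch(llm_arch arch) {
    switch (arch) {
        case LLM_ARCH_MAMBA:
        case LLM_ARCH_RWKV6:
        case LLM_ARCH_RWKV6QWEN2:
        case LLM_ARCH_RWKV7:
        case LLM_ARCH_ARWKV7:
            return true;
        default:
            return false;
    }
}
```

Keeping the dispatch behind a predicate means new recurrent or hybrid architectures only need to be registered in one place instead of extending this switch.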