@@ -9,6 +9,7 @@
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
 #include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-hybrid-recurrent.h"
 
 #include "ggml-cpp.h"
 
@@ -13197,6 +13198,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13205,58 +13208,71 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        nullptr,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            params.swa_full,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            cparams.n_ubatch,
-                            padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified(
+                if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                             *this,
                             nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                            std::max((uint32_t) 1, cparams.n_seq_max),
+                            cparams.n_seq_max);
+                } else if (llm_arch_is_hybrid_recurrent(arch)) {
+                    res = new llama_kv_cache_hybrid_recurrent(
+                        /* model              */ *this,
+                        /* attn_type_k        */ params.type_k,
+                        /* attn_type_v        */ params.type_v,
+                        /* attn_v_trans       */ !cparams.flash_attn,
+                        /* attn_kv_size       */ cparams.n_ctx,
+                        /* attn_n_pad         */ llama_kv_cache_unified::get_padding(cparams),
+                        /* attn_n_swa         */ hparams.n_swa,
+                        /* attn_swa_type      */ hparams.swa_type,
+                        /* recurrent_type_k   */ GGML_TYPE_F32,
+                        /* recurrent_type_v   */ GGML_TYPE_F32,
+                        /* recurrent_kv_size  */ std::max((uint32_t) 1, cparams.n_seq_max),
+                        /* n_seq_max          */ cparams.n_seq_max,
+                        /* offload            */ cparams.offload_kqv);
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                        GGML_ASSERT(hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                params.swa_full,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                cparams.n_ubatch,
+                                padding);
+                    } else {
+                        GGML_ASSERT(!hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
                 }
             }
     }
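
The new dispatch in the default branch keys off two arch predicates, llm_arch_is_recurrent() and llm_arch_is_hybrid_recurrent(), instead of enumerating architectures in the switch. A minimal sketch of how these predicates are assumed to behave, inferred only from the cases this patch removes (the real definitions live alongside the other llm_arch helpers and may differ):

// Sketch only; assumes the llm_arch enum from llama-arch.h.
bool llm_arch_is_recurrent(const llm_arch & arch) {
    // presumably covers exactly the architectures removed from the switch above
    switch (arch) {
        case LLM_ARCH_MAMBA:
        case LLM_ARCH_RWKV6:
        case LLM_ARCH_RWKV6QWEN2:
        case LLM_ARCH_RWKV7:
        case LLM_ARCH_ARWKV7:
            return true;
        default:
            return false;
    }
}

bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) {
    // hypothetical body: the hybrid attention + recurrent architectures served
    // by llama_kv_cache_hybrid_recurrent would be enumerated here
    switch (arch) {
        default:
            return false;
    }
}

With predicates like these in place, a newly added recurrent or hybrid architecture picks up the right cache type from the default branch without needing another explicit switch case.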