2 files changed: +12 -3 lines changed
llama.h

@@ -508,6 +508,9 @@ extern "C" {
     // to the decoder to start generating output sequence. For other models, it returns -1.
     LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
 
+    // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
+    LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
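The new predicate sits alongside the other model introspection helpers in the C API. Below is a minimal caller-side sketch; the load/free boilerplate and the `model.gguf` path are illustrative assumptions, not part of this change, and only `llama_model_is_recurrent` comes from this diff:

```cpp
#include <stdio.h>
#include "llama.h"

int main(void) {
    llama_backend_init();

    // Illustrative boilerplate: load a model from a placeholder path.
    struct llama_model_params mparams = llama_model_default_params();
    struct llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // The new predicate: true for recurrent architectures such as Mamba.
    if (llama_model_is_recurrent(model)) {
        // Recurrent models keep a constant number of state cells per
        // sequence, so cache sizing follows n_seq_max rather than n_ctx.
        printf("model is recurrent\n");
    } else {
        printf("model uses a regular attention KV cache\n");
    }

    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```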
llama.cpp

@@ -3292,8 +3292,7 @@ static bool llama_kv_cache_init(
 
     cache.has_shift = false;
 
-    // TODO: find a nicer way to add other recurrent model architectures
-    cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+    cache.recurrent = llama_model_is_recurrent(&model);
     cache.v_trans = !cache.recurrent && !cparams.flash_attn;
 
     cache.head = 0;
@@ -17235,7 +17234,7 @@ struct llama_context * llama_new_context_with_model(
     ggml_type type_v = params.type_v;
 
     // Mamba only needs a constant number of KV cache cells per sequence
-    if (model->arch == LLM_ARCH_MAMBA) {
+    if (llama_model_is_recurrent(model)) {
         // Mamba needs at least as many KV cells as there are sequences kept at any time
         kv_size = std::max((uint32_t) 1, params.n_seq_max);
         // it's probably best to keep as much precision as possible for the states
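Stitched together from the context lines above, the sizing logic in `llama_new_context_with_model` reads roughly as follows. The default `kv_size` initialization and the F32 type overrides are reconstructed assumptions (the hunk cuts off before them), so treat this as a sketch rather than the file's exact contents:

```cpp
// Approximate reconstruction; only the if-condition is changed by this diff.
uint32_t  kv_size = cparams.n_ctx;  // assumed default: one KV cell per context token
ggml_type type_k  = params.type_k;
ggml_type type_v  = params.type_v;

// Mamba only needs a constant number of KV cache cells per sequence
if (llama_model_is_recurrent(model)) {
    // Mamba needs at least as many KV cells as there are sequences kept at any time
    kv_size = std::max((uint32_t) 1, params.n_seq_max);
    // it's probably best to keep as much precision as possible for the states
    type_k = GGML_TYPE_F32; // assumed: keep recurrent states in full precision
    type_v = GGML_TYPE_F32; // assumed
}
```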
@@ -17709,6 +17708,13 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
     return model->hparams.dec_start_token_id;
 }
 
+bool llama_model_is_recurrent(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_MAMBA: return true;
+        default:             return false;
+    }
+}
+
 uint32_t llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
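The switch gives future recurrent architectures a single registration point, replacing scattered `model.arch == LLM_ARCH_MAMBA` checks like the two removed above. A hypothetical extension is sketched below; the `LLM_ARCH_RWKV6` enumerator is assumed for illustration and is not part of this diff:

```cpp
// Hypothetical future shape of the predicate. Only LLM_ARCH_MAMBA is
// handled in this diff; the RWKV case below is an assumed illustration.
bool llama_model_is_recurrent(const struct llama_model * model) {
    switch (model->arch) {
        case LLM_ARCH_MAMBA: return true;
        case LLM_ARCH_RWKV6: return true; // assumed enumerator
        default:             return false;
    }
}
```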