Commit 1be5ea7

llama : add llama_model_is_recurrent to simplify figuring that out

This will make it easier to more cleanly support RWKV-v6 and Mamba-2.

1 parent b264edd

2 files changed: 12 additions and 3 deletions

include/llama.h (3 additions, 0 deletions)

@@ -508,6 +508,9 @@ extern "C" {
     // to the decoder to start generating output sequence. For other models, it returns -1.
     LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
 
+    // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
+    LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,

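For context, a minimal caller-side sketch (not part of this commit) of how the new predicate might be used through the public C API; the model-path argument and the printed sizing notes are illustrative assumptions:

#include "llama.h"
#include <stdio.h>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]);
        return 1;
    }

    llama_backend_init();

    struct llama_model_params mparams = llama_model_default_params();
    struct llama_model * model = llama_load_model_from_file(argv[1], mparams);
    if (model == NULL) {
        return 1;
    }

    struct llama_context_params cparams = llama_context_default_params();
    if (llama_model_is_recurrent(model)) {
        // recurrent models keep a fixed-size state per sequence, so the cell
        // count follows n_seq_max rather than n_ctx
        printf("recurrent model: state cells scale with n_seq_max (%u)\n", cparams.n_seq_max);
    } else {
        printf("attention model: KV cache scales with n_ctx (%u)\n", cparams.n_ctx);
    }

    llama_free_model(model);
    llama_backend_free();
    return 0;
}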
src/llama.cpp (9 additions, 3 deletions)

@@ -3292,8 +3292,7 @@ static bool llama_kv_cache_init(
 
     cache.has_shift = false;
 
-    // TODO: find a nicer way to add other recurrent model architectures
-    cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+    cache.recurrent = llama_model_is_recurrent(&model);
     cache.v_trans = !cache.recurrent && !cparams.flash_attn;
 
     cache.head = 0;

@@ -17235,7 +17234,7 @@ struct llama_context * llama_new_context_with_model(
     ggml_type type_v = params.type_v;
 
     // Mamba only needs a constant number of KV cache cells per sequence
-    if (model->arch == LLM_ARCH_MAMBA) {
+    if (llama_model_is_recurrent(model)) {
         // Mamba needs at least as many KV cells as there are sequences kept at any time
         kv_size = std::max((uint32_t) 1, params.n_seq_max);
         // it's probably best to keep as much precision as possible for the states

@@ -17709,6 +17708,13 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
     return model->hparams.dec_start_token_id;
 }
 
+bool llama_model_is_recurrent(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_MAMBA: return true;
+        default:             return false;
+    }
+}
+
 uint32_t llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,

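Following the commit message, a sketch of how additional recurrent architectures could later be folded into the same switch; LLM_ARCH_RWKV6 and LLM_ARCH_MAMBA2 are assumed enum names that do not exist in the tree at this commit:

bool llama_model_is_recurrent(const struct llama_model * model) {
    switch (model->arch) {
        case LLM_ARCH_MAMBA:  return true;
        case LLM_ARCH_MAMBA2: return true; // assumed future enum value
        case LLM_ARCH_RWKV6:  return true; // assumed future enum value
        default:              return false;
    }
}

With a change like this, the KV-cache and context setup above would treat those models as recurrent without any further edits at the call sites.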