
Commit 7c7cfce

feat: Add llama_model_is_hybrid API call
Also, split llama_model_is_recurrent into llm_arch_is_recurrent in llama-arch, with llama_model_is_recurrent delegating to llm_arch_is_recurrent. The same split is done for hybrid. This is needed because there are places where the llama_model has not yet been initialized but we need to check whether the model is recurrent (specifically for the per-layer recurrent check array in hparams).

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 2589ad3 commit 7c7cfce

File tree

4 files changed: +33 -8 lines changed


include/llama.h

Lines changed: 3 additions & 0 deletions
@@ -557,6 +557,9 @@ extern "C" {
     // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
     LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
 
+    // Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.)
+    LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
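
For context, a minimal caller-side sketch of the new API. The model path is a placeholder, and the loading boilerplate assumes the standard llama.cpp entry points (llama_model_load_from_file, llama_model_default_params) rather than anything added by this commit:

#include "llama.h"
#include <stdio.h>

int main(void) {
    llama_backend_init();

    // "model.gguf" is a placeholder path for illustration
    struct llama_model * model = llama_model_load_from_file("model.gguf", llama_model_default_params());
    if (model == NULL) {
        return 1;
    }

    // query the architecture class of the loaded model
    printf("recurrent: %d\n", llama_model_is_recurrent(model));
    printf("hybrid:    %d\n", llama_model_is_hybrid(model));

    llama_model_free(model);
    llama_backend_free();
    return 0;
}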

src/llama-arch.cpp

Lines changed: 22 additions & 0 deletions
@@ -1747,3 +1747,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
     return LLM_TENSOR_INFOS.at(tensor);
 }
+
+bool llm_arch_is_recurrent(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_RWKV7:
+        case LLM_ARCH_ARWKV7:
+            return true;
+        default:
+            return false;
+    }
+}
+
+bool llm_arch_is_hybrid(const llm_arch & arch) {
+    // TODO: There are currently no hybrid models! Once there are, this will be
+    // the place to identify them
+    switch (arch) {
+        default:
+            return false;
+    }
+}
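
The hybrid check is deliberately an empty stub here. Once hybrid-recurrent architectures land, they would be added as cases, mirroring the recurrent check above. A hypothetical sketch — LLM_ARCH_JAMBA and LLM_ARCH_BAMBA are placeholders taken from the header comment's examples and do not exist in this commit:

bool llm_arch_is_hybrid(const llm_arch & arch) {
    switch (arch) {
        // hypothetical future hybrid-recurrent architectures
        case LLM_ARCH_JAMBA:
        case LLM_ARCH_BAMBA:
            return true;
        default:
            return false;
    }
}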

src/llama-arch.h

Lines changed: 3 additions & 0 deletions
@@ -437,3 +437,6 @@ const char * llm_arch_name(llm_arch arch);
 llm_arch llm_arch_from_string(const std::string & name);
 
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
+
+bool llm_arch_is_recurrent(const llm_arch& arch);
+bool llm_arch_is_hybrid(const llm_arch& arch);

src/llama-model.cpp

Lines changed: 5 additions & 8 deletions
@@ -13804,14 +13804,11 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
 }
 
 bool llama_model_is_recurrent(const llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_MAMBA: return true;
-        case LLM_ARCH_RWKV6: return true;
-        case LLM_ARCH_RWKV6QWEN2: return true;
-        case LLM_ARCH_RWKV7: return true;
-        case LLM_ARCH_ARWKV7: return true;
-        default: return false;
-    }
+    return llm_arch_is_recurrent(model->arch);
+}
+
+bool llama_model_is_hybrid(const llama_model * model) {
+    return llm_arch_is_hybrid(model->arch);
 }
 
 const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
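
The reason for the arch-level split is visible here: llm_arch_is_recurrent can be called from code that only has an llm_arch and hparams in hand, before any llama_model exists. A rough sketch of the intended use during hparams loading — the field name recurrent_layer_arr and the surrounding loop are assumptions for illustration, not part of this commit:

// while loading hparams, before the llama_model is constructed:
// a purely recurrent arch marks every layer as recurrent
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
    hparams.recurrent_layer_arr[il] = llm_arch_is_recurrent(arch);  // hypothetical per-layer flag
}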
