Skip to content

Commit e8d9499

Browse files
committed
feat: Add llama_model_is_hybrid API call
Also, split llama_model_is_recurrent into llm_arch_is_recurrent in llama-arch with llama_model_is_recurrent delegating to llm_arch_is_recurrent. The same split is done for hybird. This is needed because there are places where the llama_model has not yet been initialized but we need to check if the model is recurrent (specifically for the per-layer recurrent check array in hparams). Branch: GraniteFour Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 89fea80 commit e8d9499

File tree

4 files changed

+33
-8
lines changed

4 files changed

+33
-8
lines changed

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,9 @@ extern "C" {
572572
// Returns true if the model is recurrent (like Mamba, RWKV, etc.)
573573
LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
574574

575+
// Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.)
576+
LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
577+
575578
// Returns 0 on success
576579
LLAMA_API uint32_t llama_model_quantize(
577580
const char * fname_inp,

src/llama-arch.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1816,3 +1816,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
18161816
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
18171817
return LLM_TENSOR_INFOS.at(tensor);
18181818
}
1819+
1820+
bool llm_arch_is_recurrent(const llm_arch & arch) {
1821+
switch (arch) {
1822+
case LLM_ARCH_MAMBA:
1823+
case LLM_ARCH_RWKV6:
1824+
case LLM_ARCH_RWKV6QWEN2:
1825+
case LLM_ARCH_RWKV7:
1826+
case LLM_ARCH_ARWKV7:
1827+
return true;
1828+
default:
1829+
return false;
1830+
}
1831+
}
1832+
1833+
bool llm_arch_is_hybrid(const llm_arch & arch) {
1834+
// TODO: There are currently no hybrid models! Once there are, this will be
1835+
// the place to identify them
1836+
switch (arch) {
1837+
default:
1838+
return false;
1839+
}
1840+
}

src/llama-arch.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,3 +439,6 @@ const char * llm_arch_name(llm_arch arch);
439439
llm_arch llm_arch_from_string(const std::string & name);
440440

441441
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
442+
443+
bool llm_arch_is_recurrent(const llm_arch& arch);
444+
bool llm_arch_is_hybrid(const llm_arch& arch);

src/llama-model.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14377,14 +14377,11 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
1437714377
}
1437814378

1437914379
bool llama_model_is_recurrent(const llama_model * model) {
14380-
switch (model->arch) {
14381-
case LLM_ARCH_MAMBA: return true;
14382-
case LLM_ARCH_RWKV6: return true;
14383-
case LLM_ARCH_RWKV6QWEN2: return true;
14384-
case LLM_ARCH_RWKV7: return true;
14385-
case LLM_ARCH_ARWKV7: return true;
14386-
default: return false;
14387-
}
14380+
return llm_arch_is_recurrent(model->arch);
14381+
}
14382+
14383+
bool llama_model_is_hybrid(const llama_model * model) {
14384+
return llm_arch_is_hybrid(model->arch);
1438814385
}
1438914386

1439014387
const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {

0 commit comments

Comments
 (0)