Skip to content

Commit 40cde99

Browse files
committed
feat: Add llama_model_is_hybrid API call
Also, split llama_model_is_recurrent into llm_arch_is_recurrent in llama-arch with llama_model_is_recurrent delegating to llm_arch_is_recurrent. The same split is done for hybird. This is needed because there are places where the llama_model has not yet been initialized but we need to check if the model is recurrent (specifically for the per-layer recurrent check array in hparams). Branch: GraniteFour Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 165794d commit 40cde99

File tree

4 files changed

+33
-8
lines changed

4 files changed

+33
-8
lines changed

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,9 @@ extern "C" {
553553
// Returns true if the model is recurrent (like Mamba, RWKV, etc.)
554554
LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
555555

556+
// Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.)
557+
LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
558+
556559
// Returns 0 on success
557560
LLAMA_API uint32_t llama_model_quantize(
558561
const char * fname_inp,

src/llama-arch.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1744,3 +1744,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
17441744
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
17451745
return LLM_TENSOR_INFOS.at(tensor);
17461746
}
1747+
1748+
bool llm_arch_is_recurrent(const llm_arch & arch) {
1749+
switch (arch) {
1750+
case LLM_ARCH_MAMBA:
1751+
case LLM_ARCH_RWKV6:
1752+
case LLM_ARCH_RWKV6QWEN2:
1753+
case LLM_ARCH_RWKV7:
1754+
case LLM_ARCH_ARWKV7:
1755+
return true;
1756+
default:
1757+
return false;
1758+
}
1759+
}
1760+
1761+
bool llm_arch_is_hybrid(const llm_arch & arch) {
1762+
// TODO: There are currently no hybrid models! Once there are, this will be
1763+
// the place to identify them
1764+
switch (arch) {
1765+
default:
1766+
return false;
1767+
}
1768+
}

src/llama-arch.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,3 +435,6 @@ const char * llm_arch_name(llm_arch arch);
435435
llm_arch llm_arch_from_string(const std::string & name);
436436

437437
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
438+
439+
bool llm_arch_is_recurrent(const llm_arch& arch);
440+
bool llm_arch_is_hybrid(const llm_arch& arch);

src/llama-model.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13788,14 +13788,11 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
1378813788
}
1378913789

1379013790
bool llama_model_is_recurrent(const llama_model * model) {
13791-
switch (model->arch) {
13792-
case LLM_ARCH_MAMBA: return true;
13793-
case LLM_ARCH_RWKV6: return true;
13794-
case LLM_ARCH_RWKV6QWEN2: return true;
13795-
case LLM_ARCH_RWKV7: return true;
13796-
case LLM_ARCH_ARWKV7: return true;
13797-
default: return false;
13798-
}
13791+
return llm_arch_is_recurrent(model->arch);
13792+
}
13793+
13794+
bool llama_model_is_hybrid(const llama_model * model) {
13795+
return llm_arch_is_hybrid(model->arch);
1379913796
}
1380013797

1380113798
const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {

0 commit comments

Comments
 (0)