
Commit 17f403f

compilade authored and CISC committed
llama : support Jamba hybrid Transformer-Mamba models (ggml-org#7531)
* wip: llama : separate recurrent states from the KV cache. This will be necessary to support Jamba (and other recurrent models mixed with Attention). Doesn't compile yet, and finding a slot isn't yet done correctly for recurrent states.
* llama : use std::find for seq_nodes in llama_rs_cache
* llama : state checkpoints for recurrent models
* llama : correctly handle more edge cases for the rs cache
* llama : rename many llama_kv_cache_* functions
* llama : remove useless return value for some llama_cache_* functions
* llama : rethink recurrent state cell counts
* llama : begin work on support for variable GQA. This will also be useful for Jamba if we consider the Mamba layers to have 0 KV heads.
* llama : gracefully fail when not finding hybrid slot
* llama : support Jamba
* llama : fix BERT inference without KV cache
* convert-hf : check for unprocessed Jamba experts
* convert-hf : support Mini-Jamba conversion
* llama : fix Jamba quantization sanity checks
* llama : sequence-length-aware batch splitting
* llama : use equal-sequence-length sub-batches for recurrent models (see the sketch below)
* ggml : simplify SSM-related operators
* llama : make recurrent state slot allocation contiguous
* llama : adapt internal uses of batches to llama_ubatch
* llama : fix batch split output count for embeddings
* llama : minimize swaps when reordering logits. This reduces overhead when running hellaswag on thousands of sequences with very small 100k params Mamba models.
* llama : fix edge case finding batch seq_id of split recurrent cell. This otherwise was a problem when running the HellaSwag benchmark with small batch sizes, making it crash.
* llama : avoid copies for simple batch splits
* ggml : make ggml_ssm_scan not modify its source tensors
* llama : fix shared recurrent tail cell count for small ubatch sizes. Otherwise it was impossible to run the 'parallel' example with '-ub 1' with a Mamba or Jamba model.
* llama : fix .base() compilation error on Windows
* llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL
* ggml : allow GGML_OP_CONCAT to work on non-contiguous tensors. The implementation already supported it, and this makes Mamba's conv step slightly faster.
* mamba : fix non-contiguous usage of ggml_silu
* llama : session saving and reloading for hybrid models
* convert_hf : fix Jamba conversion
* llama : fix mixed signedness comparison
* llama : use unused n_embd_k_gqa in k_shift. This also slightly reduces the diff from the master branch.
* llama : begin renaming llama_past back to llama_kv_cache
* llama : remove implicit recurrent state rollbacks
* llama : partially apply clang-format style
* convert : fix jamba conv1d shape squeezing
* graph : add back hybrid memory graph input. But this time it contains the sub-cache graph inputs. This *should* make it easier to handle updating the inputs when caching the graph (eventually).
* model : add Jamba to Mamba-specific hparams printing
* jamba : remove redundant nullptr initializations
* model : remove unnecessary prefix for tensor loading constants. Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* model : use ggml_swiglu_split for Mamba. Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* model : make falcon-h1 use shared mamba2 layer builder
* memory : avoid referring to KV in recurrent cache logs
* gguf-py : avoid adding duplicate tensor mappings for Jamba. Some of the tensor names are common with Llama4.

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
1 parent fd415eb commit 17f403f
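
Most of the items above are internal to llama.cpp, but the batching idea behind "llama : use equal-sequence-length sub-batches for recurrent models" is easy to picture. The Python sketch below is only an illustration of the intent, not the llama.cpp implementation; the function name, the n_ubatch token budget, and the greedy splitting strategy are assumptions made for the example.

# Illustrative only: split a batch of sequences into sub-batches in which every
# included sequence contributes the same number of tokens, so all recurrent
# states advance in lockstep within a sub-batch. (Not the llama.cpp code.)
def split_equal_ubatches(seq_lens: dict[int, int], n_ubatch: int) -> list[dict[int, int]]:
    remaining = dict(seq_lens)
    ubatches: list[dict[int, int]] = []
    while remaining:
        seq_ids = list(remaining)
        # cap the per-sequence length by the shortest remaining sequence
        # and by an (assumed) per-sub-batch token budget
        per_seq = min(min(remaining.values()), max(1, n_ubatch // len(seq_ids)))
        ubatches.append({sid: per_seq for sid in seq_ids})
        for sid in seq_ids:
            remaining[sid] -= per_seq
            if remaining[sid] == 0:
                del remaining[sid]
    return ubatches

print(split_equal_ubatches({0: 5, 1: 3, 2: 3}, n_ubatch=8))
# [{0: 2, 1: 2, 2: 2}, {0: 1, 1: 1, 2: 1}, {0: 2}]

The point is simply that every recurrent sequence advances by the same number of steps within a sub-batch.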

File tree

10 files changed (+621, -422 lines)


convert_hf_to_gguf.py

Lines changed: 117 additions & 0 deletions
@@ -4974,6 +4974,123 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield (new_name, data_torch)
 
 
+@ModelBase.register("JambaForCausalLM")
+class JambaModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.JAMBA
+
+    def get_vocab_base_pre(self, tokenizer) -> str:
+        del tokenizer  # unused
+
+        return "gpt-2"
+
+    def set_vocab(self):
+        if (self.dir_model / "tokenizer.model").is_file():
+            # Using Jamba's tokenizer.json causes errors on model load
+            # (something about "byte not found in vocab"),
+            # but there's a working tokenizer.model
+            self._set_vocab_sentencepiece()
+        else:
+            # Some Jamba models only have a tokenizer.json, which works.
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
+        d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4
+        d_inner = self.hparams["mamba_expand"] * d_model
+        d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16
+        # ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
+        dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16)
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6
+        n_kv_head = self.hparams["num_key_value_heads"]
+        attn_offset = self.hparams["attn_layer_offset"]
+        attn_period = self.hparams["attn_layer_period"]
+        n_kv_vec = [0 for _ in range(attn_offset)] + [
+            n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count)
+        ]
+
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"]))
+        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(n_kv_vec)
+        self.gguf_writer.add_ssm_conv_kernel(d_conv)
+        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_state_size(d_state)
+        self.gguf_writer.add_ssm_time_step_rank(dt_rank)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
+        self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+
+        # Mini-Jamba
+        name = name.replace(".moe.", ".feed_forward.")
+        if bid is not None:
+            moe_offset = self.hparams["expert_layer_offset"]
+            moe_period = self.hparams["expert_layer_period"]
+
+            if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0):
+                name = name.replace(".experts.0.", ".")
+
+        # process the experts separately
+        if ".feed_forward.experts." in name:
+            n_experts = self.hparams["num_experts"]
+
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+
+                # merge the experts into a single 3d tensor
+                for wid in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    # using the same merged name as qwen2moe
+                    merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    yield new_name, data_torch
+            return
+
+        new_name = self.map_tensor_name(name)
+
+        if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
+            data_torch = data_torch.squeeze()
+
+        if name.endswith(".A_log"):
+            logger.debug("A_log --> A ==> " + new_name)
+            data_torch = -torch.exp(data_torch)
+
+        yield (new_name, data_torch)
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("CohereForCausalLM")
 class CommandR2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
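
To make the per-layer KV head bookkeeping in set_gguf_parameters() concrete, here is a small worked example. The hyperparameter values below are assumptions chosen for illustration, not values read from a real Jamba config.

# Assumed hparams, for illustration only
block_count = 16
n_kv_head   = 8
attn_offset = 4   # assumed attn_layer_offset: index of the first attention block
attn_period = 8   # assumed attn_layer_period: one attention block every 8 blocks

# Same expression as in set_gguf_parameters() above
n_kv_vec = [0 for _ in range(attn_offset)] + [
    n_kv_head if (i - attn_offset) % attn_period == 0 else 0
    for i in range(attn_offset, block_count)
]
print(n_kv_vec)  # [0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0]
# Only blocks 4 and 12 are attention blocks and carry KV heads; the Mamba
# blocks report 0 KV heads, which is how the variable-GQA support mentioned
# in the commit message represents them.

# dt_rank falls back to ceil(d_model / 16), written as a negative floor division:
d_model = 4096
print(-(d_model // -16))  # 256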

gguf-py/gguf/constants.py

Lines changed: 36 additions & 0 deletions
@@ -330,6 +330,7 @@ class MODEL_ARCH(IntEnum):
     ARWKV7 = auto()
     MAMBA = auto()
     MAMBA2 = auto()
+    JAMBA = auto()
     XVERSE = auto()
     COMMAND_R = auto()
     COHERE2 = auto()
@@ -432,7 +433,10 @@ class MODEL_TENSOR(IntEnum):
     SSM_CONV1D = auto()
     SSM_X = auto()
     SSM_DT = auto()
+    SSM_DT_NORM = auto()
     SSM_A = auto()
+    SSM_B_NORM = auto()
+    SSM_C_NORM = auto()
     SSM_D = auto()
     SSM_NORM = auto()
     SSM_OUT = auto()
@@ -635,6 +639,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.ARWKV7: "arwkv7",
     MODEL_ARCH.MAMBA: "mamba",
     MODEL_ARCH.MAMBA2: "mamba2",
+    MODEL_ARCH.JAMBA: "jamba",
     MODEL_ARCH.XVERSE: "xverse",
     MODEL_ARCH.COMMAND_R: "command-r",
     MODEL_ARCH.COHERE2: "cohere2",
@@ -738,7 +743,10 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
     MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
     MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
+    MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm",
     MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
+    MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm",
+    MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm",
     MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
     MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
     MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
@@ -1738,6 +1746,34 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.SSM_NORM,
         MODEL_TENSOR.SSM_OUT,
     ],
+    MODEL_ARCH.JAMBA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_X,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_DT_NORM,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_B_NORM,
+        MODEL_TENSOR.SSM_C_NORM,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_OUT,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     MODEL_ARCH.XVERSE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
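
With the new JAMBA entries in place, gguf-py can already enumerate the tensor names this architecture may use. The sketch below relies only on the public tables in gguf.constants (MODEL_TENSORS and TENSOR_NAMES); the block indices are arbitrary, and a real Jamba GGUF contains only the subset that matches each block's type (attention tensors in attention blocks, expert tensors in MoE blocks).

from gguf.constants import MODEL_ARCH, MODEL_TENSORS, TENSOR_NAMES

# Print the formatted per-block tensor names registered for Jamba.
for bid in (0, 4):  # arbitrary example block indices
    for t in MODEL_TENSORS[MODEL_ARCH.JAMBA]:
        fmt = TENSOR_NAMES[t]
        if "{bid}" in fmt:
            print(fmt.format(bid=bid))
# e.g. blk.0.ssm_dt_norm, blk.0.ssm_b_norm, blk.0.ssm_c_norm, blk.4.attn_q, ...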

gguf-py/gguf/tensor_mapping.py

Lines changed: 41 additions & 24 deletions
@@ -279,6 +279,8 @@ class TensorNameMap:
             "transformer.decoder_layer.{bid}.rms_norm_2",  # Grok
             "encoder.layers.{bid}.post_attention_layernorm",  # chatglm
             "transformer.layers.{bid}.ffn_norm",  # openelm
+            "model.layers.{bid}.pre_ff_layernorm",  # jamba
+            "model.layers.{bid}.pre_moe_layernorm",  # mini-jamba
             "model.layers.{bid}.post_attention_layernorm",  # llama4
             "transformer_encoder.{bid}.ffn_norm",  # neobert
         ),
@@ -303,7 +305,7 @@ class TensorNameMap:
             "transformer.decoder_layer.{bid}.router",  # Grok
             "transformer.blocks.{bid}.ffn.router.layer",  # dbrx
             "model.layers.{bid}.block_sparse_moe.router.layer",  # granitemoe
-            "model.layers.{bid}.feed_forward.router",  # llama4
+            "model.layers.{bid}.feed_forward.router",  # llama4 jamba
             "encoder.layers.{bid}.mlp.router.layer",  # nomic-bert-moe
             "model.layers.{bid}.mlp.gate.wg",  # hunyuan
         ),
@@ -347,7 +349,7 @@ class TensorNameMap:
             "model.layers.{bid}.residual_mlp.w3",  # arctic
             "encoder.layers.{bid}.mlp.dense_h_to_4h",  # chatglm
             "transformer.h.{bid}.mlp.c_fc_1",  # exaone
-            "model.layers.{bid}.feed_forward.up_proj",  # llama4
+            "model.layers.{bid}.feed_forward.up_proj",  # llama4 jamba
             "transformer_encoder.{bid}.ffn.w12",  # neobert
         ),
 
@@ -387,7 +389,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.linear_1",  # refact
             "model.layers.{bid}.residual_mlp.w1",  # arctic
             "transformer.h.{bid}.mlp.c_fc_0",  # exaone
-            "model.layers.{bid}.feed_forward.gate_proj",  # llama4
+            "model.layers.{bid}.feed_forward.gate_proj",  # llama4 jamba
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -433,7 +435,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.mlp.down_layer",  # jina-bert-v2
             "encoder.layers.{bid}.mlp.dense_4h_to_h",  # chatglm
             "model.layers.h.{bid}.mlp.c_proj",  # exaone
-            "model.layers.{bid}.feed_forward.down_proj",  # llama4
+            "model.layers.{bid}.feed_forward.down_proj",  # llama4 jamba
             "transformer_encoder.{bid}.ffn.w3",  # neobert
         ),
 
@@ -554,38 +556,53 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.SSM_IN: (
-            "model.layers.{bid}.in_proj",
-            "backbone.layers.{bid}.mixer.in_proj",
-            "model.layers.{bid}.mamba.in_proj",
+            "model.layers.{bid}.in_proj",  # mamba-hf
+            "backbone.layers.{bid}.mixer.in_proj",  # mamba
+            "model.layers.{bid}.mamba.in_proj",  # jamba falcon-h1
         ),
 
         MODEL_TENSOR.SSM_CONV1D: (
-            "model.layers.{bid}.conv1d",
-            "backbone.layers.{bid}.mixer.conv1d",
-            "model.layers.{bid}.mamba.conv1d",
+            "model.layers.{bid}.conv1d",  # mamba-hf
+            "backbone.layers.{bid}.mixer.conv1d",  # mamba
+            "model.layers.{bid}.mamba.conv1d",  # jamba falcon-h1
         ),
 
         MODEL_TENSOR.SSM_X: (
-            "model.layers.{bid}.x_proj",
-            "backbone.layers.{bid}.mixer.x_proj",
+            "model.layers.{bid}.x_proj",  # mamba-hf
+            "backbone.layers.{bid}.mixer.x_proj",  # mamba
+            "model.layers.{bid}.mamba.x_proj",  # jamba
         ),
 
         MODEL_TENSOR.SSM_DT: (
-            "model.layers.{bid}.dt_proj",
-            "backbone.layers.{bid}.mixer.dt_proj",
-            "model.layers.{bid}.mamba.dt_proj",
+            "model.layers.{bid}.dt_proj",  # mamba-hf
+            "backbone.layers.{bid}.mixer.dt_proj",  # mamba
+            "model.layers.{bid}.mamba.dt_proj",  # jamba falcon-h1
+        ),
+
+        MODEL_TENSOR.SSM_DT_NORM: (
+            "model.layers.{bid}.mamba.dt_layernorm",  # jamba
         ),
 
         MODEL_TENSOR.SSM_A: (
-            "model.layers.{bid}.A_log",
-            "backbone.layers.{bid}.mixer.A_log",
-            "model.layers.{bid}.mamba.A_log",
+            "model.layers.{bid}.A_log",  # mamba-hf
+            "backbone.layers.{bid}.mixer.A_log",  # mamba
+            "model.layers.{bid}.mamba.A_log",  # jamba falcon-h1
+        ),
+
+        MODEL_TENSOR.SSM_B_NORM: (
+            "model.layers.{bid}.mamba.b_layernorm",  # jamba
+            "model.layers.{bid}.mamba.B_layernorm",  # mini-jamba
+        ),
+
+        MODEL_TENSOR.SSM_C_NORM: (
+            "model.layers.{bid}.mamba.c_layernorm",  # jamba
+            "model.layers.{bid}.mamba.C_layernorm",  # mini-jamba
         ),
 
         MODEL_TENSOR.SSM_D: (
-            "model.layers.{bid}.D",
-            "backbone.layers.{bid}.mixer.D",
-            "model.layers.{bid}.mamba.D",
+            "model.layers.{bid}.D",  # mamba-hf
+            "backbone.layers.{bid}.mixer.D",  # mamba
+            "model.layers.{bid}.mamba.D",  # jamba falcon-h1
         ),
 
         MODEL_TENSOR.SSM_NORM: (
@@ -594,9 +611,9 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.SSM_OUT: (
-            "model.layers.{bid}.out_proj",
-            "backbone.layers.{bid}.mixer.out_proj",
-            "model.layers.{bid}.mamba.out_proj",  # falcon-h1
+            "model.layers.{bid}.out_proj",  # mamba-hf
+            "backbone.layers.{bid}.mixer.out_proj",  # mamba
+            "model.layers.{bid}.mamba.out_proj",  # jamba falcon-h1
         ),
 
         MODEL_TENSOR.TIME_MIX_W0: (
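
The mapping entries above are what the convert script uses to translate Hugging Face tensor names into GGUF names. A quick check of the new Jamba mappings, assuming the usual gguf-py helpers (get_tensor_name_map and TensorNameMap.get_name) and Hugging-Face-style layer names, might look like this; the block count, layer indices, and the attention projection name are assumptions for illustration.

import gguf

# Build the HF-name -> GGUF-name map for an assumed 32-block Jamba model.
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.JAMBA, 32)

for hf_name in (
    "model.layers.0.mamba.dt_layernorm.weight",  # new SSM_DT_NORM mapping
    "model.layers.0.mamba.b_layernorm.weight",   # new SSM_B_NORM mapping
    "model.layers.4.self_attn.q_proj.weight",    # assumed attention-layer name
):
    print(hf_name, "->", tmap.get_name(hf_name, try_suffixes=(".weight", ".bias")))
# With the mappings above, this should print:
#   model.layers.0.mamba.dt_layernorm.weight -> blk.0.ssm_dt_norm.weight
#   model.layers.0.mamba.b_layernorm.weight -> blk.0.ssm_b_norm.weight
#   model.layers.4.self_attn.q_proj.weight -> blk.4.attn_q.weight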

src/llama-arch.cpp

Lines changed: 36 additions & 0 deletions
@@ -46,6 +46,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA, "mamba" },
     { LLM_ARCH_MAMBA2, "mamba2" },
+    { LLM_ARCH_JAMBA, "jamba" },
     { LLM_ARCH_FALCON_H1, "falcon-h1" },
     { LLM_ARCH_XVERSE, "xverse" },
     { LLM_ARCH_COMMAND_R, "command-r" },
@@ -1025,6 +1026,37 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
         },
     },
+    {
+        LLM_ARCH_JAMBA,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
+            { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" },
+            { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" },
+            { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" },
+            { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_FALCON_H1,
         {
@@ -1845,6 +1877,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}},
     {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
     {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
+    {LLM_TENSOR_SSM_DT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_SSM_B_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -1994,6 +2029,7 @@ bool llm_arch_is_recurrent(const llm_arch & arch) {
 bool llm_arch_is_hybrid(const llm_arch & arch) {
     // List all mamba-attention hybrid models here
     switch (arch) {
+        case LLM_ARCH_JAMBA:
         case LLM_ARCH_FALCON_H1:
             return true;
         default:
