
Commit 0465506

ibrahimkhadraoui, younesbelkada, ggerganov, CISC, and compilade authored
model : add support for Falcon-H1 family (#14534)
* v1
* push more fixes
* another fix
* fix
* more fixes
* minor fix
* more cleaning on python code
* python fixes
* changed precision for multipliers float 32->64
* fixes
* another fix
* fix
* pre-norm -> norm
* fix
* Revert "fix"
  This reverts commit 243e4d1.
* fix
* small fix ffn_norm
* try
* mix instead of max
* fix vocab size
* conflict solve
* fixed multipliers
* falcon-h1 specefic vocab resolved
* read arch from gguf.MODEL_ARCH
* mamba_d_ssm added to d_inner find_hparam
* remove unused functions from gguf_writer.py
* override modify_tensors instead of get_tensors
* fix conversion and d_inner
* added some cb functions for debugging puposes
* inp_out_ids moved outside of layers loop
* mup_vec create as float64
* fix rope_theta
* injected mup
* clean ups
* rm extra space
* rm unused MAMBA_CHUNK_SIZE
* rm unused key
* add bos False
* changed ROPE_TYPE
* cleaning debugging stuff
* cleaning debug quant
* fix comment
* some cleanups
* some cleanups
* Update src/llama-model-loader.cpp
* more cleanups
* moe cleanuips
* d_ssm -> d_inner;
* cleaning unused hparams
* cleanup
* more cleanups
* more cleanups on python conversion;
* minor cleanups
* Apply suggestions from code review
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* remove todo
* added falcon-h1
* tensor not required
* clean
* remove unneeded attributes
* more cleanups and fixed conversion
* remove final_norm
* flake8 fixes
* Update src/llama-model.cpp
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* flake8 fixes
* Update src/llama-hparams.cpp
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Update src/llama-model.cpp
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Update src/llama-model.cpp
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Update src/llama-arch.cpp
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Update convert_hf_to_gguf.py
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* added hashes
* Update src/llama-arch.cpp
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* Update src/llama-vocab.cpp
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* update the update file
* Revert "update the update file"
  This reverts commit 082ab4a.
* fix: address suggestions
* fix: update convert_hf_to_gguf.py
* Update gguf-py/gguf/constants.py
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Update src/llama-model-loader.cpp
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* d_inner fixed
* Update src/llama-model.cpp
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* reshaping ssm_norm for 34B
* removing generate_mup
* remove duplicates metadata keys
* rm comment
* final comment
* fix unused args
* fix constants
* fix bad merge
* Update src/llama-model.cpp
  Co-authored-by: compilade <git@compilade.net>
* falcon-h1: remove unused ssm_in_b and bad merge
* Update src/llama-model.cpp
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* falcon-h1: fix last comment
* Update convert_hf_to_gguf.py
  Co-authored-by: compilade <git@compilade.net>
* falcon-h1: revert add_add_bos(False)
* falcon-h1: fix tied weights
* falcon-h1: remove whitespace
* falcon-h1: fix wrong size param
* falcon-h1: fix whitespace issues

---------

Co-authored-by: younesbelkada <younes.belkada@tii.ae>
Co-authored-by: Younes B <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: compilade <git@compilade.net>
1 parent 20b7bf8 commit 0465506

8 files changed: +585 / -9 lines changed

convert_hf_to_gguf.py

Lines changed: 126 additions & 5 deletions
@@ -818,6 +818,18 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
             # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
             res = "hunyuan"
+        if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
+            res = "falcon-h1"
+        if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
+            res = "falcon-h1"
+        if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
+            res = "falcon-h1"
+        if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
+            res = "falcon-h1"

         if res is None:
             logger.warning("\n")
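These checksums are how the converter decides which pre-tokenizer configuration to apply. For reference, a rough sketch of how such a hash can be reproduced, assuming the same probe text (chktxt) that convert_hf_to_gguf.py hashes; the probe string is abbreviated here and "transformers" is assumed to be installed:

from hashlib import sha256
from transformers import AutoTokenizer

chktxt = "..."  # placeholder: the multi-script probe string used by convert_hf_to_gguf.py

tokenizer = AutoTokenizer.from_pretrained("tiiuae/Falcon-H1-0.5B-Base")
chkhsh = sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()
print(chkhsh)  # should match "a6b57017..." from the table above if the probe text is identical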
@@ -4899,17 +4911,19 @@ def set_vocab(self):
     def set_gguf_parameters(self):
         d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
         d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
-        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+        d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
         d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
-        head_dim = self.find_hparam(["head_dim"], optional=True) or 64
+        head_dim = self.find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64
         n_group = self.find_hparam(["n_groups"], optional=True) or 1

         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

         # Fail early for models which don't have a block expansion factor of 2
         # TODO: does this really matter?
-        assert d_inner == 2 * d_model
-        assert d_inner % head_dim == 0
+        # skip the assertion for FalconH1 Model
+        if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
+            assert d_inner == 2 * d_model
+            assert d_inner % head_dim == 0

         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
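The lookup order above is what lets Falcon-H1 reuse the Mamba2 converter: mamba_d_ssm wins when present, and the 2 * d_model assertion is gated on the architecture because that relationship does not have to hold for this family. A minimal sketch of the fallback chain, with purely illustrative hparam values:

hparams = {"hidden_size": 1024, "mamba_d_ssm": 1536, "mamba_d_head": 64}  # hypothetical config values

def find_hparam(keys, optional=False):
    # simplified stand-in for the converter's find_hparam: the first key present wins
    for k in keys:
        if k in hparams:
            return hparams[k]
    if optional:
        return None
    raise KeyError(f"none of {keys} found")

d_model  = find_hparam(["hidden_size", "d_model", "dim"])
d_inner  = find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
head_dim = find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64
print(d_inner, d_inner == 2 * d_model)  # 1536 False -> why the assert is skipped for FALCON_H1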
@@ -4946,7 +4960,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             data_torch = data_torch.reshape((*data_torch.shape, 1))
         elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
             d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+            d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
             n_group = self.hparams.get("n_groups", 1)
             data_torch = data_torch.reshape((n_group, d_inner // n_group))

@@ -6539,6 +6553,113 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])


+@ModelBase.register("FalconH1ForCausalLM")
+class FalconH1Model(Mamba2Model):
+    model_arch = gguf.MODEL_ARCH.FALCON_H1
+
+    def __init__(self, *args, **kwargs):
+        # Set the hparam prefixes for Falcon Mamba2
+        self.hparam_prefixes = ["mamba"]
+
+        # Initialize the base Mamba2Model
+        super().__init__(*args, **kwargs)
+
+        # Use Llama conversion for attention
+        self._transformer_model_class = LlamaModel
+
+        # n_group and d_inner are used during reshape_tensors for mamaba2
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["mamba_d_ssm"])
+        self.d_head = self.find_hparam(["d_head"])
+
+        # Initialize any Falcon Mamba2 specific attributes
+        self.has_attention = True  # Falcon Mamba2 has attention components
+
+        # Load Falcon-H1 multipliers from hyperparameters
+        self.attention_in_multiplier = self.find_hparam(["attention_in_multiplier"], optional=True)
+        self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True)
+        self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True)
+        self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True)
+        self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True)
+        self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True)
+        self.intermediate_size = self.find_hparam(["intermediate_size"])
+        self.key_multiplier = self.find_hparam(["key_multiplier"], optional=True)
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return super().find_hparam(keys, *args, **kwargs)
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        tensors = list(super().modify_tensors(data_torch, name, bid))
+        tensor = tensors[0][1]
+
+        if "down_proj" in name:
+            tensor = tensor * self.mlp_multipliers[1]
+        elif "gate_proj" in name:
+            tensor = tensor * self.mlp_multipliers[0]
+        elif "k_proj" in name:
+            tensor = tensor * self.key_multiplier * self.attention_in_multiplier
+        elif "q_proj" in name:
+            tensor = tensor * self.attention_in_multiplier
+        elif "v_proj" in name:
+            tensor = tensor * self.attention_in_multiplier
+        elif "o_proj" in name:
+            tensor = tensor * self.attention_out_multiplier
+        elif "out_proj" in name:
+            tensor = tensor * self.ssm_out_multiplier
+        elif "in_proj" in name:
+            tensor = tensor * self.ssm_in_multiplier
+            zxbcdt_multipliers = self.hparams["ssm_multipliers"]
+            intermediate_size = self.hparams["mamba_d_ssm"]
+            groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
+            tensor[:intermediate_size, :] *= zxbcdt_multipliers[0]
+            tensor[intermediate_size:2 * intermediate_size, :] *= zxbcdt_multipliers[1]
+            tensor[2 * intermediate_size:2 * intermediate_size + groups_time_state_size, :] *= zxbcdt_multipliers[2]
+            tensor[2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size, :] *= zxbcdt_multipliers[3]
+            tensor[2 * intermediate_size + 2 * groups_time_state_size:, :] *= zxbcdt_multipliers[4]
+        elif "lm_head" in name:
+            tensor = tensor * self.hparams["lm_head_multiplier"]
+        elif "embed_tokens" in name:
+            tensor = tensor * self.hparams["embedding_multiplier"]
+        elif "mamba.norm" in name:
+            tensor = tensor.reshape(self.n_group, self.d_inner // self.n_group)
+
+        tensors = [(tensors[0][0], tensor)]
+        return tensors
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        ## General Params ##
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        # Override some Mamba2 defaults
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+
+        ## Attention params ##
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])  # Override value 0 from Mamba2
+        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
+        self.gguf_writer.add_key_length(self.hparams["head_dim"])
+        self.gguf_writer.add_value_length(self.hparams["head_dim"])
+
+        ## Validation ##
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % self.d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {self.d_head}"
+
+        # Add any other Falcon Mamba2 specific configuration
+        self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
+
+
 @ModelBase.register("HunYuanMoEV1ForCausalLM")
 class HunYuanMoEModel(TextModel):
     model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
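The most involved piece of the new class is the in_proj branch of modify_tensors: Mamba2's in_proj weight concatenates the [z | x | B | C | dt] blocks along its output rows, and Falcon-H1 ships one multiplier per block (ssm_multipliers, read as zxbcdt_multipliers). A small sketch of how those row ranges fall out, with illustrative sizes only (not taken from a real Falcon-H1 config):

d_ssm    = 1536                # hparams["mamba_d_ssm"]
n_groups = 1                   # hparams["mamba_n_groups"]
d_state  = 128                 # hparams["mamba_d_state"]
n_heads  = d_ssm // 64         # rows left over for dt; head count is an assumption here
gts      = n_groups * d_state  # groups_time_state_size

bounds = [0, d_ssm, 2 * d_ssm, 2 * d_ssm + gts, 2 * d_ssm + 2 * gts, 2 * d_ssm + 2 * gts + n_heads]
labels = ["z", "x", "B", "C", "dt"]
for i, (label, lo, hi) in enumerate(zip(labels, bounds[:-1], bounds[1:])):
    # each segment of in_proj rows is scaled by the matching multiplier
    print(f"{label}: rows [{lo}, {hi}) scaled by zxbcdt_multipliers[{i}]")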

convert_hf_to_gguf_update.py

Lines changed: 5 additions & 0 deletions
@@ -138,6 +138,11 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
     {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
     {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
+    # falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"},
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"},
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"},
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
 ]

gguf-py/gguf/constants.py

Lines changed: 36 additions & 0 deletions
@@ -288,6 +288,7 @@ class MODEL_ARCH(IntEnum):
     LLAMA4 = auto()
     DECI = auto()
     FALCON = auto()
+    FALCON_H1 = auto()
     BAICHUAN = auto()
     GROK = auto()
     GPT2 = auto()
@@ -662,6 +663,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.DOTS1: "dots1",
     MODEL_ARCH.ARCEE: "arcee",
     MODEL_ARCH.ERNIE4_5: "ernie4_5",
+    MODEL_ARCH.FALCON_H1: "falcon-h1",
     MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
     MODEL_ARCH.SMOLLM3: "smollm3",
 }
@@ -2215,6 +2217,40 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.FALCON_H1: [
+        # Token embedding
+        MODEL_TENSOR.TOKEN_EMBD,
+
+        # Input layernorm
+        MODEL_TENSOR.ATTN_NORM,
+
+        # Attention components
+        MODEL_TENSOR.ATTN_Q,    # Query projection
+        MODEL_TENSOR.ATTN_K,    # Key projection
+        MODEL_TENSOR.ATTN_V,    # Value projection
+        MODEL_TENSOR.ATTN_OUT,  # Output projection
+
+        # SSM components (Mamba2 specific)
+        MODEL_TENSOR.SSM_IN,      # Input projection for SSM
+        MODEL_TENSOR.SSM_CONV1D,  # Convolution layer
+        MODEL_TENSOR.SSM_DT,      # Delta time projection
+        MODEL_TENSOR.SSM_A,       # A parameter (log form)
+        MODEL_TENSOR.SSM_D,       # D parameter
+        MODEL_TENSOR.SSM_NORM,    # Normalization in SSM
+        MODEL_TENSOR.SSM_OUT,     # Output projection
+
+        # Pre-feedforward layernorm
+        MODEL_TENSOR.FFN_PRE_NORM,
+
+        # Feed-forward network components
+        MODEL_TENSOR.FFN_GATE,  # Gate projection (SwiGLU)
+        MODEL_TENSOR.FFN_DOWN,  # Down projection
+        MODEL_TENSOR.FFN_UP,    # Up projection
+
+        # Post-feedforward layernorm
+        MODEL_TENSOR.OUTPUT_NORM,  # Final layer norm
+        MODEL_TENSOR.OUTPUT,       # Output projection (lm_head)
+    ],
     MODEL_ARCH.HUNYUAN_MOE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
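The entries above are plain data tables, so the new architecture can be inspected directly from gguf-py. A quick sketch, assuming the in-tree gguf package is importable:

from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES, MODEL_TENSORS, TENSOR_NAMES

print(MODEL_ARCH_NAMES[MODEL_ARCH.FALCON_H1])  # "falcon-h1"
for t in MODEL_TENSORS[MODEL_ARCH.FALCON_H1]:
    # prints the GGUF-side name templates, e.g. "blk.{bid}.attn_q", "blk.{bid}.ssm_in", ...
    print(TENSOR_NAMES[t])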

gguf-py/gguf/tensor_mapping.py

Lines changed: 10 additions & 0 deletions
@@ -286,12 +286,14 @@ class TensorNameMap:
         # Post feed-forward norm
         MODEL_TENSOR.FFN_PRE_NORM: (
             "model.layers.{bid}.pre_feedforward_layernorm", # gemma2
+            "model.layers.{bid}.pre_ff_layernorm.weight",
         ),

         # Post feed-forward norm
         MODEL_TENSOR.FFN_POST_NORM: (
             "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
             "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
+            "model.layers.{bid}.feed_forward.up_proj",
         ),

         MODEL_TENSOR.FFN_GATE_INP: (
@@ -363,6 +365,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
             "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
             "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
+            "model.layers.{bid}.feed_forward.down_proj",
             "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
         ),

@@ -553,11 +556,13 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_IN: (
             "model.layers.{bid}.in_proj",
             "backbone.layers.{bid}.mixer.in_proj",
+            "model.layers.{bid}.mamba.in_proj",
         ),

         MODEL_TENSOR.SSM_CONV1D: (
             "model.layers.{bid}.conv1d",
             "backbone.layers.{bid}.mixer.conv1d",
+            "model.layers.{bid}.mamba.conv1d",
         ),

         MODEL_TENSOR.SSM_X: (
@@ -568,25 +573,30 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_DT: (
             "model.layers.{bid}.dt_proj",
             "backbone.layers.{bid}.mixer.dt_proj",
+            "model.layers.{bid}.mamba.dt_proj",
         ),

         MODEL_TENSOR.SSM_A: (
             "model.layers.{bid}.A_log",
             "backbone.layers.{bid}.mixer.A_log",
+            "model.layers.{bid}.mamba.A_log",
         ),

         MODEL_TENSOR.SSM_D: (
             "model.layers.{bid}.D",
             "backbone.layers.{bid}.mixer.D",
+            "model.layers.{bid}.mamba.D",
         ),

         MODEL_TENSOR.SSM_NORM: (
+            "model.layers.{bid}.mamba.norm", # falcon-h1
             "backbone.layers.{bid}.mixer.norm", # mamba2
         ),

         MODEL_TENSOR.SSM_OUT: (
             "model.layers.{bid}.out_proj",
             "backbone.layers.{bid}.mixer.out_proj",
+            "model.layers.{bid}.mamba.out_proj", # falcon-h1
         ),

         MODEL_TENSOR.TIME_MIX_W0: (
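These HF-side names are what let the converter route Falcon-H1's model.layers.N.mamba.* tensors onto the existing ssm_* GGUF names. A rough usage sketch, with method names as I understand gguf.tensor_mapping and an illustrative block count:

from gguf.constants import MODEL_ARCH
from gguf.tensor_mapping import get_tensor_name_map

tmap = get_tensor_name_map(MODEL_ARCH.FALCON_H1, n_blocks=36)  # 36 layers is an assumption
name = tmap.get_name("model.layers.0.mamba.in_proj.weight", try_suffixes=(".weight", ".bias"))
print(name)  # expected to resolve to the mapped "blk.0.ssm_in" name (with the suffix preserved)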

src/llama-arch.cpp

Lines changed: 28 additions & 2 deletions
@@ -46,6 +46,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA, "mamba" },
     { LLM_ARCH_MAMBA2, "mamba2" },
+    { LLM_ARCH_FALCON_H1, "falcon-h1" },
     { LLM_ARCH_XVERSE, "xverse" },
     { LLM_ARCH_COMMAND_R, "command-r" },
     { LLM_ARCH_COHERE2, "cohere2" },
@@ -1024,6 +1025,30 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
         },
     },
+    {
+        LLM_ARCH_FALCON_H1,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
+            { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_XVERSE,
         {
@@ -1967,9 +1992,10 @@ bool llm_arch_is_recurrent(const llm_arch & arch) {
 }

 bool llm_arch_is_hybrid(const llm_arch & arch) {
-    // TODO: There are currently no hybrid models! Once there are, this will be
-    // the place to identify them
+    // List all mamba-attention hybrid models here
     switch (arch) {
+        case LLM_ARCH_FALCON_H1:
+            return true;
         default:
             return false;
     }
src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ enum llm_arch {
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_MAMBA2,
+    LLM_ARCH_FALCON_H1,
     LLM_ARCH_XVERSE,
     LLM_ARCH_COMMAND_R,
     LLM_ARCH_COHERE2,
