
Commit 1334c71

Merge remote-tracking branch 'origin/master' into GraniteFour
* origin/master:
  ggml : prevent integer overflow in gguf tensor size calculation (ggml-org#14595)
  model : add skt/A.X-4.0 model vocabulary (ggml-org#14589)
  llama : remove unintended whitespace (ggml-org#14592)
  model : add support for Falcon-H1 family (ggml-org#14534)
  convert : fix smollm3 jinja template (ggml-org#14586)
2 parents 0b84bd5 + 26a48ad commit 1334c71

File tree: 9 files changed (+608, -16 lines)

convert_hf_to_gguf.py

Lines changed: 140 additions & 6 deletions
@@ -818,6 +818,21 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
             # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
             res = "hunyuan"
+        if chkhsh == "b0a6b1c0bd5998ebd9df08611efde34a4ff03faed45ae09c43e6b31ebd4b94cf":
+            # ref: https://huggingface.co/skt/A.X-4.0
+            res = "a.x-4.0"
+        if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
+            res = "falcon-h1"
+        if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
+            res = "falcon-h1"
+        if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
+            res = "falcon-h1"
+        if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
+            res = "falcon-h1"

         if res is None:
             logger.warning("\n")
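A note on where these chkhsh values come from: the converter fingerprints a tokenizer by hashing the token IDs it produces for a fixed probe string, so a tokenizer that pre-tokenizes differently yields a new hash and needs a new entry here. A minimal sketch of the idea, assuming a Hugging Face tokenizer and an illustrative probe string rather than the converter's exact one:

    from hashlib import sha256
    from transformers import AutoTokenizer

    def tokenizer_fingerprint(model_dir: str, probe_text: str) -> str:
        # Encode the probe string and hash the resulting token-ID list;
        # identical pre-tokenization gives an identical fingerprint.
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        token_ids = tokenizer.encode(probe_text)
        return sha256(str(token_ids).encode()).hexdigest()

    # e.g. compare tokenizer_fingerprint(model_dir, probe) against the hashes
    # registered above to pick "a.x-4.0", "falcon-h1", etc.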
@@ -4876,7 +4891,7 @@ def __init__(self, dir_model: Path, *args, **kwargs):
             hparams = json.load(f)
         super().__init__(dir_model, *args, hparams=hparams, **kwargs)
         self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-        self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
+        self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
         self.n_group = self.find_hparam(["n_groups"], optional=True) or 1

     def set_vocab(self):
@@ -4900,16 +4915,18 @@ def set_vocab(self):
             self._set_vocab_builtin("gpt-neox", vocab_size)

     def set_gguf_parameters(self):
-        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
-        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
-        head_dim = self.find_hparam(["head_dim"], optional=True) or 64
+        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
+        head_dim = self.find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64

         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

         # Fail early for models which don't have a block expansion factor of 2
         # TODO: does this really matter?
-        assert self.d_inner == 2 * self.d_model
-        assert self.d_inner % head_dim == 0
+        # skip the assertion for FalconH1 Model
+        if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
+            assert self.d_inner == 2 * self.d_model
+            assert self.d_inner % head_dim == 0

         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(self.d_model)
@@ -6804,6 +6821,113 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])


+@ModelBase.register("FalconH1ForCausalLM")
+class FalconH1Model(Mamba2Model):
+    model_arch = gguf.MODEL_ARCH.FALCON_H1
+
+    def __init__(self, *args, **kwargs):
+        # Set the hparam prefixes for Falcon Mamba2
+        self.hparam_prefixes = ["mamba"]
+
+        # Initialize the base Mamba2Model
+        super().__init__(*args, **kwargs)
+
+        # Use Llama conversion for attention
+        self._transformer_model_class = LlamaModel
+
+        # n_group and d_inner are used during reshape_tensors for mamba2
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["mamba_d_ssm"])
+        self.d_head = self.find_hparam(["d_head"])
+
+        # Initialize any Falcon Mamba2 specific attributes
+        self.has_attention = True  # Falcon Mamba2 has attention components
+
+        # Load Falcon-H1 multipliers from hyperparameters
+        self.attention_in_multiplier = self.find_hparam(["attention_in_multiplier"], optional=True)
+        self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True)
+        self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True)
+        self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True)
+        self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True)
+        self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True)
+        self.intermediate_size = self.find_hparam(["intermediate_size"])
+        self.key_multiplier = self.find_hparam(["key_multiplier"], optional=True)
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return super().find_hparam(keys, *args, **kwargs)
+
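The find_hparam override widens the hyperparameter lookup: every requested key is also tried with each prefix in self.hparam_prefixes prepended, so asking for d_ssm also matches mamba_d_ssm in config.json. A standalone sketch of that behaviour, using a hypothetical config dict rather than the converter's real hparam plumbing:

    from typing import Any, Iterable

    def find_hparam(config: dict, keys: Iterable[str], prefixes: tuple[str, ...] = ("mamba",)) -> Any:
        # Try each key as given, then with every prefix prepended
        # (e.g. "d_ssm" also matches "mamba_d_ssm").
        keys = list(keys)
        candidates = keys + ["_".join([pfx, k]) for pfx in prefixes for k in keys]
        for key in candidates:
            if key in config:
                return config[key]
        raise KeyError(f"none of {candidates} found")

    config = {"mamba_d_ssm": 3072, "hidden_size": 2048}
    print(find_hparam(config, ["d_ssm"]))        # 3072, matched via the "mamba" prefix
    print(find_hparam(config, ["hidden_size"]))  # 2048, direct hit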
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        tensors = list(super().modify_tensors(data_torch, name, bid))
+        tensor = tensors[0][1]
+
+        if "down_proj" in name:
+            tensor = tensor * self.mlp_multipliers[1]
+        elif "gate_proj" in name:
+            tensor = tensor * self.mlp_multipliers[0]
+        elif "k_proj" in name:
+            tensor = tensor * self.key_multiplier * self.attention_in_multiplier
+        elif "q_proj" in name:
+            tensor = tensor * self.attention_in_multiplier
+        elif "v_proj" in name:
+            tensor = tensor * self.attention_in_multiplier
+        elif "o_proj" in name:
+            tensor = tensor * self.attention_out_multiplier
+        elif "out_proj" in name:
+            tensor = tensor * self.ssm_out_multiplier
+        elif "in_proj" in name:
+            tensor = tensor * self.ssm_in_multiplier
+            zxbcdt_multipliers = self.hparams["ssm_multipliers"]
+            intermediate_size = self.hparams["mamba_d_ssm"]
+            groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
+            tensor[:intermediate_size, :] *= zxbcdt_multipliers[0]
+            tensor[intermediate_size:2 * intermediate_size, :] *= zxbcdt_multipliers[1]
+            tensor[2 * intermediate_size:2 * intermediate_size + groups_time_state_size, :] *= zxbcdt_multipliers[2]
+            tensor[2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size, :] *= zxbcdt_multipliers[3]
+            tensor[2 * intermediate_size + 2 * groups_time_state_size:, :] *= zxbcdt_multipliers[4]
+        elif "lm_head" in name:
+            tensor = tensor * self.hparams["lm_head_multiplier"]
+        elif "embed_tokens" in name:
+            tensor = tensor * self.hparams["embedding_multiplier"]
+        elif "mamba.norm" in name:
+            tensor = tensor.reshape(self.n_group, self.d_inner // self.n_group)
+
+        tensors = [(tensors[0][0], tensor)]
+        return tensors
+
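The in_proj branch is the subtle one: Falcon-H1's fused in_proj weight packs the z, x, B, C and dt blocks along its output dimension, and each block gets its own entry from ssm_multipliers. A toy sketch of that per-block scaling, with made-up sizes instead of real checkpoint shapes:

    import torch

    # Illustrative sizes only (not real Falcon-H1 values)
    d_ssm, n_groups, d_state, n_heads = 8, 2, 4, 4
    zxbcdt = [1.0, 0.5, 2.0, 2.0, 0.25]           # multipliers for z, x, B, C, dt
    rows = 2 * d_ssm + 2 * n_groups * d_state + n_heads
    in_proj = torch.ones(rows, 16)                # fused projection weight [rows, d_model]

    gts = n_groups * d_state                      # groups_time_state_size
    in_proj[:d_ssm, :]                               *= zxbcdt[0]  # z block
    in_proj[d_ssm:2 * d_ssm, :]                      *= zxbcdt[1]  # x block
    in_proj[2 * d_ssm:2 * d_ssm + gts, :]            *= zxbcdt[2]  # B block
    in_proj[2 * d_ssm + gts:2 * d_ssm + 2 * gts, :]  *= zxbcdt[3]  # C block
    in_proj[2 * d_ssm + 2 * gts:, :]                 *= zxbcdt[4]  # dt block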
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        ## General Params ##
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        # Override some Mamba2 defaults
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+
+        ## Attention params ##
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])  # Override value 0 from Mamba2
+        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
+        self.gguf_writer.add_key_length(self.hparams["head_dim"])
+        self.gguf_writer.add_value_length(self.hparams["head_dim"])
+
+        ## Validation ##
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % self.d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {self.d_head}"
+
+        # Add any other Falcon Mamba2 specific configuration
+        self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
+
+
 @ModelBase.register("HunYuanMoEV1ForCausalLM")
 class HunYuanMoEModel(TextModel):
     model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
@@ -6957,6 +7081,16 @@ def prepare_tensors(self):
 class SmolLM3Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.SMOLLM3

+    def set_vocab(self):
+        super().set_vocab()
+        # remove unsupported array slicing in chat template
+        # ref: https://huggingface.co/ggml-org/SmolLM3-3B-GGUF/discussions/1
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+        if tokenizer.chat_template is not None:
+            chat_template = tokenizer.chat_template.replace("[:]", "")
+            self.gguf_writer.add_chat_template(chat_template)
+
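The fix is a plain text substitution on the Jinja chat template: the unsupported [:] slice is stripped before the template is written into the GGUF. An illustrative snippet, not the actual SmolLM3 template:

    template = "{%- for message in messages[:] %}{{ message.content }}{%- endfor %}"
    print(template.replace("[:]", ""))
    # {%- for message in messages %}{{ message.content }}{%- endfor %}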
 ###### CONVERSION LOGIC ######

convert_hf_to_gguf_update.py

Lines changed: 6 additions & 0 deletions
@@ -128,6 +128,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
     {"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
     {"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
+    {"name": "a.x-4.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/skt/A.X-4.0", },
 ]

 # some models are known to be broken upstream, so we will skip them as exceptions
@@ -138,6 +139,11 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
     {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
     {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
+    # falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"},
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"},
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"},
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
 ]

ggml/src/gguf.cpp

Lines changed: 8 additions & 1 deletion
@@ -631,7 +631,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
                 gguf_free(ctx);
                 return nullptr;
             }
-            ctx->size += GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment);
+            size_t padded_size = GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment);
+            if (SIZE_MAX - ctx->size < padded_size) {
+                GGML_LOG_ERROR("%s: tensor '%s' size overflow, cannot accumulate size %zu + %zu\n",
+                               __func__, ti.t.name, ctx->size, padded_size);
+                gguf_free(ctx);
+                return nullptr;
+            }
+            ctx->size += padded_size;
         }
     }
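The guard uses the standard pre-addition overflow check: rather than adding and then testing for wraparound, it first verifies that padded_size fits in the headroom left below SIZE_MAX. The same pattern sketched in Python against an explicit 64-bit bound (Python integers do not overflow, so the size_t limit is simulated):

    SIZE_MAX = 2**64 - 1  # simulate C's size_t range

    def checked_accumulate(total: int, padded_size: int) -> int:
        # Refuse the addition if it would exceed SIZE_MAX instead of letting it wrap.
        if SIZE_MAX - total < padded_size:
            raise OverflowError(f"cannot accumulate size {total} + {padded_size}")
        return total + padded_size

    total = 0
    for nbytes in (1 << 62, 1 << 63, 1 << 63):   # illustrative padded sizes; the third one overflows
        total = checked_accumulate(total, nbytes)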

gguf-py/gguf/constants.py

Lines changed: 36 additions & 0 deletions
@@ -291,6 +291,7 @@ class MODEL_ARCH(IntEnum):
     LLAMA4 = auto()
     DECI = auto()
     FALCON = auto()
+    FALCON_H1 = auto()
     BAICHUAN = auto()
     GROK = auto()
     GPT2 = auto()
@@ -602,6 +603,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.LLAMA4: "llama4",
     MODEL_ARCH.DECI: "deci",
     MODEL_ARCH.FALCON: "falcon",
+    MODEL_ARCH.FALCON_H1: "falcon-h1",
     MODEL_ARCH.BAICHUAN: "baichuan",
     MODEL_ARCH.GROK: "grok",
     MODEL_ARCH.GPT2: "gpt2",
@@ -2313,6 +2315,40 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.FALCON_H1: [
+        # Token embedding
+        MODEL_TENSOR.TOKEN_EMBD,
+
+        # Input layernorm
+        MODEL_TENSOR.ATTN_NORM,
+
+        # Attention components
+        MODEL_TENSOR.ATTN_Q,        # Query projection
+        MODEL_TENSOR.ATTN_K,        # Key projection
+        MODEL_TENSOR.ATTN_V,        # Value projection
+        MODEL_TENSOR.ATTN_OUT,      # Output projection
+
+        # SSM components (Mamba2 specific)
+        MODEL_TENSOR.SSM_IN,        # Input projection for SSM
+        MODEL_TENSOR.SSM_CONV1D,    # Convolution layer
+        MODEL_TENSOR.SSM_DT,        # Delta time projection
+        MODEL_TENSOR.SSM_A,         # A parameter (log form)
+        MODEL_TENSOR.SSM_D,         # D parameter
+        MODEL_TENSOR.SSM_NORM,      # Normalization in SSM
+        MODEL_TENSOR.SSM_OUT,       # Output projection
+
+        # Pre-feedforward layernorm
+        MODEL_TENSOR.FFN_PRE_NORM,
+
+        # Feed-forward network components
+        MODEL_TENSOR.FFN_GATE,      # Gate projection (SwiGLU)
+        MODEL_TENSOR.FFN_DOWN,      # Down projection
+        MODEL_TENSOR.FFN_UP,        # Up projection
+
+        # Post-feedforward layernorm
+        MODEL_TENSOR.OUTPUT_NORM,   # Final layer norm
+        MODEL_TENSOR.OUTPUT,        # Output projection (lm_head)
+    ],
     MODEL_ARCH.HUNYUAN_MOE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,

gguf-py/gguf/tensor_mapping.py

Lines changed: 10 additions & 6 deletions
@@ -289,12 +289,14 @@ class TensorNameMap:
     # Post feed-forward norm
     MODEL_TENSOR.FFN_PRE_NORM: (
         "model.layers.{bid}.pre_feedforward_layernorm", # gemma2
+        "model.layers.{bid}.pre_ff_layernorm.weight",
     ),

     # Post feed-forward norm
     MODEL_TENSOR.FFN_POST_NORM: (
         "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
         "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
+        "model.layers.{bid}.feed_forward.up_proj",
     ),

     MODEL_TENSOR.FFN_GATE_INP: (
@@ -369,6 +371,7 @@ class TensorNameMap:
         "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
         "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
         "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
+        "model.layers.{bid}.feed_forward.down_proj",
         "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
     ),

@@ -563,13 +566,13 @@ class TensorNameMap:
     MODEL_TENSOR.SSM_IN: (
         "model.layers.{bid}.in_proj", # mamba-hf
         "backbone.layers.{bid}.mixer.in_proj", # mamba
-        "model.layers.{bid}.mamba.in_proj", # jamba, bamba
+        "model.layers.{bid}.mamba.in_proj", # falcon-h1, jamba, bamba
     ),

     MODEL_TENSOR.SSM_CONV1D: (
         "model.layers.{bid}.conv1d", # mamba-hf
         "backbone.layers.{bid}.mixer.conv1d", # mamba
-        "model.layers.{bid}.mamba.conv1d", # jamba, bamba
+        "model.layers.{bid}.mamba.conv1d", # falcon-h1, jamba, bamba
     ),

     MODEL_TENSOR.SSM_X: (
@@ -581,7 +584,7 @@ class TensorNameMap:
     MODEL_TENSOR.SSM_DT: (
         "model.layers.{bid}.dt_proj", # mamba-hf
         "backbone.layers.{bid}.mixer.dt_proj", # mamba
-        "model.layers.{bid}.mamba.dt_proj", # jamba, bamba
+        "model.layers.{bid}.mamba.dt_proj", # falcon-h1, jamba, bamba
     ),

     MODEL_TENSOR.SSM_DT_NORM: (
@@ -591,7 +594,7 @@ class TensorNameMap:
     MODEL_TENSOR.SSM_A: (
         "model.layers.{bid}.A_log", # mamba-hf
         "backbone.layers.{bid}.mixer.A_log", # mamba
-        "model.layers.{bid}.mamba.A_log", # jamba, bamba
+        "model.layers.{bid}.mamba.A_log", # falcon-h1, jamba, bamba
     ),

     MODEL_TENSOR.SSM_B_NORM: (
@@ -607,18 +610,19 @@ class TensorNameMap:
     MODEL_TENSOR.SSM_D: (
         "model.layers.{bid}.D", # mamba-hf
         "backbone.layers.{bid}.mixer.D", # mamba
-        "model.layers.{bid}.mamba.D", # jamba, bamba
+        "model.layers.{bid}.mamba.D", # falcon-h1, jamba, bamba
     ),

     MODEL_TENSOR.SSM_NORM: (
+        "model.layers.{bid}.mamba.norm", # falcon-h1
         "backbone.layers.{bid}.mixer.norm", # mamba2
         "model.layers.{bid}.mamba.norm", # bamba
     ),

     MODEL_TENSOR.SSM_OUT: (
         "model.layers.{bid}.out_proj", # mamba-hf
         "backbone.layers.{bid}.mixer.out_proj", # mamba
-        "model.layers.{bid}.mamba.out_proj", # jamba, bamba
+        "model.layers.{bid}.mamba.out_proj", # falcon-h1, jamba, bamba
     ),

     MODEL_TENSOR.TIME_MIX_W0: (
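For reference, the mapping table above is consumed by formatting each template with a block index and matching it against the checkpoint's tensor names. A simplified sketch of that lookup for SSM_IN, not the real TensorNameMap class, with the GGUF-side name "blk.{bid}.ssm_in" assumed for illustration:

    # Candidate HF templates for one GGUF tensor, as in the table above
    SSM_IN_TEMPLATES = [
        "model.layers.{bid}.in_proj",           # mamba-hf
        "backbone.layers.{bid}.mixer.in_proj",  # mamba
        "model.layers.{bid}.mamba.in_proj",     # falcon-h1, jamba, bamba
    ]

    def map_ssm_in(hf_name: str, n_blocks: int = 64) -> str | None:
        # Format every template with every block index and look for an exact match.
        for bid in range(n_blocks):
            for template in SSM_IN_TEMPLATES:
                if hf_name == template.format(bid=bid):
                    return f"blk.{bid}.ssm_in"
        return None

    print(map_ssm_in("model.layers.7.mamba.in_proj"))  # blk.7.ssm_in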
