
Commit a60a24b

Merge branch 'master' into compilade/refactor-kv-cache
2 parents f7c7a92 + 26a48ad commit a60a24b

9 files changed: +608 -17 lines

convert_hf_to_gguf.py

Lines changed: 139 additions & 5 deletions
@@ -818,6 +818,21 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
             # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
             res = "hunyuan"
+        if chkhsh == "b0a6b1c0bd5998ebd9df08611efde34a4ff03faed45ae09c43e6b31ebd4b94cf":
+            # ref: https://huggingface.co/skt/A.X-4.0
+            res = "a.x-4.0"
+        if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
+            res = "falcon-h1"
+        if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
+            res = "falcon-h1"
+        if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
+            res = "falcon-h1"
+        if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
+            res = "falcon-h1"
 
         if res is None:
             logger.warning("\n")
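Note: each chkhsh above is a fingerprint of how a model's tokenizer splits a fixed probe string, which is why A.X-4.0 and every Falcon-H1 size get their own branch. A minimal sketch of the idea, assuming a local model directory and an illustrative probe string (names below are for illustration; the real wiring lives in get_vocab_base_pre):

from hashlib import sha256
from transformers import AutoTokenizer

def tokenizer_fingerprint(model_dir: str, probe_text: str) -> str:
    # Hash the token IDs produced for a fixed probe string; tokenizers that
    # split the text identically yield the same fingerprint.
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    ids = tokenizer.encode(probe_text)
    return sha256(str(ids).encode()).hexdigest()

# The result is compared against the known hashes to pick the pre-tokenizer
# name, e.g. "falcon-h1" for the four Falcon-H1 checkpoints listed above.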
@@ -4899,17 +4914,19 @@ def set_vocab(self):
     def set_gguf_parameters(self):
         d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
         d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
-        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+        d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
         d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
-        head_dim = self.find_hparam(["head_dim"], optional=True) or 64
+        head_dim = self.find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64
         n_group = self.find_hparam(["n_groups"], optional=True) or 1
 
         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
 
         # Fail early for models which don't have a block expansion factor of 2
         # TODO: does this really matter?
-        assert d_inner == 2 * d_model
-        assert d_inner % head_dim == 0
+        # skip the assertion for FalconH1 Model
+        if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
+            assert d_inner == 2 * d_model
+            assert d_inner % head_dim == 0
 
         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
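Note: the find_hparam([...], optional=True) or default pattern returns the first hyperparameter key present in the config and otherwise falls back to a default, so prepending "mamba_d_ssm" / "mamba_d_head" lets Falcon-H1 configs override the generic Mamba2 names without affecting other models. A standalone sketch of that lookup, using a hypothetical hparams dict with made-up values:

from typing import Any, Iterable

def find_hparam(hparams: dict[str, Any], keys: Iterable[str], optional: bool = False) -> Any:
    # Return the value of the first key that exists in the config dict.
    for key in keys:
        if key in hparams:
            return hparams[key]
    if optional:
        return None
    raise KeyError(f"could not find any of: {list(keys)}")

hparams = {"hidden_size": 1536, "mamba_d_ssm": 3072}  # made-up values
d_model = find_hparam(hparams, ["hidden_size", "d_model", "dim"])
# The Falcon-H1 key wins when present; otherwise fall back to 2 * d_model.
d_inner = find_hparam(hparams, ["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
assert d_inner == 3072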
@@ -4946,7 +4963,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             data_torch = data_torch.reshape((*data_torch.shape, 1))
         elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
             d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+            d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
             n_group = self.hparams.get("n_groups", 1)
             data_torch = data_torch.reshape((n_group, d_inner // n_group))
 
@@ -6656,6 +6673,113 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
 
 
+@ModelBase.register("FalconH1ForCausalLM")
+class FalconH1Model(Mamba2Model):
+    model_arch = gguf.MODEL_ARCH.FALCON_H1
+
+    def __init__(self, *args, **kwargs):
+        # Set the hparam prefixes for Falcon Mamba2
+        self.hparam_prefixes = ["mamba"]
+
+        # Initialize the base Mamba2Model
+        super().__init__(*args, **kwargs)
+
+        # Use Llama conversion for attention
+        self._transformer_model_class = LlamaModel
+
+        # n_group and d_inner are used during reshape_tensors for mamba2
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["mamba_d_ssm"])
+        self.d_head = self.find_hparam(["d_head"])
+
+        # Initialize any Falcon Mamba2 specific attributes
+        self.has_attention = True  # Falcon Mamba2 has attention components
+
+        # Load Falcon-H1 multipliers from hyperparameters
+        self.attention_in_multiplier = self.find_hparam(["attention_in_multiplier"], optional=True)
+        self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True)
+        self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True)
+        self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True)
+        self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True)
+        self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True)
+        self.intermediate_size = self.find_hparam(["intermediate_size"])
+        self.key_multiplier = self.find_hparam(["key_multiplier"], optional=True)
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return super().find_hparam(keys, *args, **kwargs)
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        tensors = list(super().modify_tensors(data_torch, name, bid))
+        tensor = tensors[0][1]
+
+        if "down_proj" in name:
+            tensor = tensor * self.mlp_multipliers[1]
+        elif "gate_proj" in name:
+            tensor = tensor * self.mlp_multipliers[0]
+        elif "k_proj" in name:
+            tensor = tensor * self.key_multiplier * self.attention_in_multiplier
+        elif "q_proj" in name:
+            tensor = tensor * self.attention_in_multiplier
+        elif "v_proj" in name:
+            tensor = tensor * self.attention_in_multiplier
+        elif "o_proj" in name:
+            tensor = tensor * self.attention_out_multiplier
+        elif "out_proj" in name:
+            tensor = tensor * self.ssm_out_multiplier
+        elif "in_proj" in name:
+            tensor = tensor * self.ssm_in_multiplier
+            zxbcdt_multipliers = self.hparams["ssm_multipliers"]
+            intermediate_size = self.hparams["mamba_d_ssm"]
+            groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
+            tensor[:intermediate_size, :] *= zxbcdt_multipliers[0]
+            tensor[intermediate_size:2 * intermediate_size, :] *= zxbcdt_multipliers[1]
+            tensor[2 * intermediate_size:2 * intermediate_size + groups_time_state_size, :] *= zxbcdt_multipliers[2]
+            tensor[2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size, :] *= zxbcdt_multipliers[3]
+            tensor[2 * intermediate_size + 2 * groups_time_state_size:, :] *= zxbcdt_multipliers[4]
+        elif "lm_head" in name:
+            tensor = tensor * self.hparams["lm_head_multiplier"]
+        elif "embed_tokens" in name:
+            tensor = tensor * self.hparams["embedding_multiplier"]
+        elif "mamba.norm" in name:
+            tensor = tensor.reshape(self.n_group, self.d_inner // self.n_group)
+
+        tensors = [(tensors[0][0], tensor)]
+        return tensors
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        ## General Params ##
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        # Override some Mamba2 defaults
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+
+        ## Attention params ##
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])  # Override value 0 from Mamba2
+        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
+        self.gguf_writer.add_key_length(self.hparams["head_dim"])
+        self.gguf_writer.add_value_length(self.hparams["head_dim"])
+
+        ## Validation ##
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % self.d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {self.d_head}"
+
+        # Add any other Falcon Mamba2 specific configuration
+        self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
+
 
 @ModelBase.register("HunYuanMoEV1ForCausalLM")
 class HunYuanMoEModel(TextModel):
     model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
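Note on the in_proj branch above: Mamba2 fuses the z, x, B, C and dt projections into a single in_proj weight, so the five ssm_multipliers each scale one row-slice of that fused matrix. A self-contained sketch of the slicing with made-up sizes (not real Falcon-H1 dimensions):

import torch

# Illustrative sizes only
d_ssm, n_groups, d_state, n_heads = 8, 2, 4, 4
ssm_multipliers = [0.5, 1.0, 1.5, 2.0, 0.25]   # z, x, B, C, dt scales

rows = 2 * d_ssm + 2 * n_groups * d_state + n_heads
w = torch.ones(rows, 3)                        # stand-in for the fused in_proj weight
gts = n_groups * d_state                       # "groups_time_state_size" in the code above

w[:d_ssm, :] *= ssm_multipliers[0]                               # z block
w[d_ssm:2 * d_ssm, :] *= ssm_multipliers[1]                      # x block
w[2 * d_ssm:2 * d_ssm + gts, :] *= ssm_multipliers[2]            # B block
w[2 * d_ssm + gts:2 * d_ssm + 2 * gts, :] *= ssm_multipliers[3]  # C block
w[2 * d_ssm + 2 * gts:, :] *= ssm_multipliers[4]                 # dt block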
@@ -6809,6 +6933,16 @@ def prepare_tensors(self):
 class SmolLM3Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.SMOLLM3
 
+    def set_vocab(self):
+        super().set_vocab()
+        # remove unsupported array slicing in chat template
+        # ref: https://huggingface.co/ggml-org/SmolLM3-3B-GGUF/discussions/1
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+        if tokenizer.chat_template is not None:
+            chat_template = tokenizer.chat_template.replace("[:]", "")
+            self.gguf_writer.add_chat_template(chat_template)
+
 ###### CONVERSION LOGIC ######
 
 
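Note: the stripped "[:]" is Jinja whole-array slicing, which the commit's comment flags as unsupported in the chat-template engine; removing the no-op slice preserves the template's meaning. An illustration with an invented template fragment (not the actual SmolLM3 template):

# Invented fragment in the style of a chat template that slices a list.
template = "{%- set tools = tools[:] if tools is defined else [] -%}"

fixed = template.replace("[:]", "")
print(fixed)  # {%- set tools = tools if tools is defined else [] -%}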
convert_hf_to_gguf_update.py

Lines changed: 6 additions & 0 deletions
@@ -128,6 +128,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
     {"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
     {"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
+    {"name": "a.x-4.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/skt/A.X-4.0", },
 ]
 
 # some models are known to be broken upstream, so we will skip them as exceptions
@@ -138,6 +139,11 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
     {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
     {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
+    # falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"},
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"},
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"},
+    {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
 ]
 
 
ggml/src/gguf.cpp

Lines changed: 8 additions & 1 deletion
@@ -631,7 +631,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
                 gguf_free(ctx);
                 return nullptr;
             }
-            ctx->size += GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment);
+            size_t padded_size = GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment);
+            if (SIZE_MAX - ctx->size < padded_size) {
+                GGML_LOG_ERROR("%s: tensor '%s' size overflow, cannot accumulate size %zu + %zu\n",
+                    __func__, ti.t.name, ctx->size, padded_size);
+                gguf_free(ctx);
+                return nullptr;
+            }
+            ctx->size += padded_size;
         }
     }
 
gguf-py/gguf/constants.py

Lines changed: 36 additions & 0 deletions
@@ -288,6 +288,7 @@ class MODEL_ARCH(IntEnum):
     LLAMA4 = auto()
     DECI = auto()
     FALCON = auto()
+    FALCON_H1 = auto()
     BAICHUAN = auto()
     GROK = auto()
     GPT2 = auto()
@@ -667,6 +668,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.DOTS1: "dots1",
     MODEL_ARCH.ARCEE: "arcee",
     MODEL_ARCH.ERNIE4_5: "ernie4_5",
+    MODEL_ARCH.FALCON_H1: "falcon-h1",
     MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
     MODEL_ARCH.SMOLLM3: "smollm3",
 }
@@ -2251,6 +2253,40 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.FALCON_H1: [
+        # Token embedding
+        MODEL_TENSOR.TOKEN_EMBD,
+
+        # Input layernorm
+        MODEL_TENSOR.ATTN_NORM,
+
+        # Attention components
+        MODEL_TENSOR.ATTN_Q,        # Query projection
+        MODEL_TENSOR.ATTN_K,        # Key projection
+        MODEL_TENSOR.ATTN_V,        # Value projection
+        MODEL_TENSOR.ATTN_OUT,      # Output projection
+
+        # SSM components (Mamba2 specific)
+        MODEL_TENSOR.SSM_IN,        # Input projection for SSM
+        MODEL_TENSOR.SSM_CONV1D,    # Convolution layer
+        MODEL_TENSOR.SSM_DT,        # Delta time projection
+        MODEL_TENSOR.SSM_A,         # A parameter (log form)
+        MODEL_TENSOR.SSM_D,         # D parameter
+        MODEL_TENSOR.SSM_NORM,      # Normalization in SSM
+        MODEL_TENSOR.SSM_OUT,       # Output projection
+
+        # Pre-feedforward layernorm
+        MODEL_TENSOR.FFN_PRE_NORM,
+
+        # Feed-forward network components
+        MODEL_TENSOR.FFN_GATE,      # Gate projection (SwiGLU)
+        MODEL_TENSOR.FFN_DOWN,      # Down projection
+        MODEL_TENSOR.FFN_UP,        # Up projection
+
+        # Post-feedforward layernorm
+        MODEL_TENSOR.OUTPUT_NORM,   # Final layer norm
+        MODEL_TENSOR.OUTPUT,        # Output projection (lm_head)
+    ],
     MODEL_ARCH.HUNYUAN_MOE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
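Note: MODEL_ARCH_NAMES and MODEL_TENSORS are the lookup tables the rest of gguf-py keys off, so the additions above are what make "falcon-h1" a recognized architecture with an allowed tensor set. A quick inspection sketch (assumes a gguf-py checkout that already includes this commit; printed names are indicative):

from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES, MODEL_TENSORS, TENSOR_NAMES

arch = MODEL_ARCH.FALCON_H1
print(MODEL_ARCH_NAMES[arch])            # expected: "falcon-h1"
for tensor in MODEL_TENSORS[arch][:6]:
    # Template names such as "blk.{bid}.attn_q" get a concrete block id later.
    print(TENSOR_NAMES[tensor])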

gguf-py/gguf/tensor_mapping.py

Lines changed: 10 additions & 6 deletions
@@ -288,12 +288,14 @@ class TensorNameMap:
         # Post feed-forward norm
         MODEL_TENSOR.FFN_PRE_NORM: (
             "model.layers.{bid}.pre_feedforward_layernorm", # gemma2
+            "model.layers.{bid}.pre_ff_layernorm.weight",
         ),
 
         # Post feed-forward norm
         MODEL_TENSOR.FFN_POST_NORM: (
             "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
             "model.layers.{bid}.post_mlp_layernorm",         # glm-4-0414
+            "model.layers.{bid}.feed_forward.up_proj",
         ),
 
         MODEL_TENSOR.FFN_GATE_INP: (
@@ -367,6 +369,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_expert.up_proj",  # qwen2moe
             "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
             "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
+            "model.layers.{bid}.feed_forward.down_proj",
             "model.layers.{bid}.mlp.shared_mlp.up_proj",     # hunyuan
         ),
 
@@ -559,13 +562,13 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_IN: (
             "model.layers.{bid}.in_proj",           # mamba-hf
             "backbone.layers.{bid}.mixer.in_proj",  # mamba
-            "model.layers.{bid}.mamba.in_proj",     # jamba
+            "model.layers.{bid}.mamba.in_proj",     # jamba falcon-h1
         ),
 
         MODEL_TENSOR.SSM_CONV1D: (
             "model.layers.{bid}.conv1d",           # mamba-hf
             "backbone.layers.{bid}.mixer.conv1d",  # mamba
-            "model.layers.{bid}.mamba.conv1d",     # jamba
+            "model.layers.{bid}.mamba.conv1d",     # jamba falcon-h1
         ),
 
         MODEL_TENSOR.SSM_X: (
@@ -577,7 +580,7 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_DT: (
             "model.layers.{bid}.dt_proj",           # mamba-hf
             "backbone.layers.{bid}.mixer.dt_proj",  # mamba
-            "model.layers.{bid}.mamba.dt_proj",     # jamba
+            "model.layers.{bid}.mamba.dt_proj",     # jamba falcon-h1
         ),
 
         MODEL_TENSOR.SSM_DT_NORM: (
@@ -587,7 +590,7 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_A: (
             "model.layers.{bid}.A_log",           # mamba-hf
             "backbone.layers.{bid}.mixer.A_log",  # mamba
-            "model.layers.{bid}.mamba.A_log",     # jamba
+            "model.layers.{bid}.mamba.A_log",     # jamba falcon-h1
         ),
 
         MODEL_TENSOR.SSM_B_NORM: (
@@ -603,17 +606,18 @@ class TensorNameMap:
         MODEL_TENSOR.SSM_D: (
             "model.layers.{bid}.D",           # mamba-hf
             "backbone.layers.{bid}.mixer.D",  # mamba
-            "model.layers.{bid}.mamba.D",     # jamba
+            "model.layers.{bid}.mamba.D",     # jamba falcon-h1
         ),
 
         MODEL_TENSOR.SSM_NORM: (
+            "model.layers.{bid}.mamba.norm",    # falcon-h1
             "backbone.layers.{bid}.mixer.norm", # mamba2
         ),
 
         MODEL_TENSOR.SSM_OUT: (
             "model.layers.{bid}.out_proj",           # mamba-hf
             "backbone.layers.{bid}.mixer.out_proj",  # mamba
-            "model.layers.{bid}.mamba.out_proj",     # jamba
+            "model.layers.{bid}.mamba.out_proj",     # jamba falcon-h1
         ),
 
         MODEL_TENSOR.TIME_MIX_W0: (
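Note: the "# jamba falcon-h1" annotations only document which architectures share an HF tensor name; the mapping itself is done by TensorNameMap, which expands the {bid} templates per block and translates HF names to GGUF names. A hedged usage sketch (again assumes gguf-py with this commit; the printed GGUF names are indicative):

import gguf

# Build the name map for Falcon-H1 with a small number of blocks.
name_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.FALCON_H1, n_blocks=2)

for hf_name in (
    "model.layers.0.mamba.in_proj.weight",   # SSM input projection
    "model.layers.1.mamba.norm.weight",      # new falcon-h1 SSM_NORM entry
):
    mapped = name_map.get_name(hf_name, try_suffixes=(".weight", ".bias"))
    print(hf_name, "->", mapped)             # e.g. blk.0.ssm_in, blk.1.ssm_norm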
