Commit 45ff609

llama : Add Gemma 3 text-only support
1 parent 10f2e81 commit 45ff609

6 files changed, 316 additions and 0 deletions


convert_hf_to_gguf.py

Lines changed: 80 additions & 0 deletions
@@ -861,6 +861,9 @@ def _create_vocab_sentencepiece(self):
         for token_id, token_data in added_tokens_decoder.items():
             token_id = int(token_id)
             token: str = token_data["content"]
+            if token_id >= vocab_size:
+                logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
+                continue
             if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
                 if tokens[token_id] != token.encode("utf-8"):
                     logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token!r}')
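Note on the guard above: `tokens` and `toktypes` are sized to the SentencePiece vocab, so an added token whose id is at or beyond `vocab_size` would otherwise index past the end of those lists. A minimal, self-contained illustration of what the guard skips (all values here are made up, not taken from the converter):

# hypothetical data, for illustration only
vocab_size = 4
tokens = [b"<pad>", b"<bos>", b"<eos>", b"hi"]
added_tokens_decoder = {"2": {"content": "<eos>"}, "7": {"content": "<custom>"}}  # id 7 is out of range

for token_id, token_data in added_tokens_decoder.items():
    token_id = int(token_id)
    if token_id >= vocab_size:
        print(f"ignore token {token_id}: id is out of range, max={vocab_size - 1}")
        continue
    tokens[token_id] = token_data["content"].encode("utf-8")  # safe: id is within the vocab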
@@ -3322,6 +3325,83 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return [(self.map_tensor_name(name), data_torch)]


+@Model.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
+class Gemma3Model(Model):
+    model_arch = gguf.MODEL_ARCH.GEMMA3
+    has_vision: bool = False
+
+    # we need to merge the text_config into the root level of hparams
+    def __init__(self, *args, **kwargs):
+        hparams = Model.load_hparams(kwargs["dir_model"])
+        if "text_config" in hparams:
+            hparams = {**hparams, **hparams["text_config"]}
+        kwargs["hparams"] = hparams
+        super().__init__(*args, **kwargs)
+        if "vision_config" in hparams:
+            logger.info("Has vision encoder, but it will be ignored")
+            self.has_vision = True
+
+    def write(self):
+        super().write()
+        if self.has_vision:
+            logger.info("NOTE: this script only convert the language model to GGUF")
+            logger.info("      for the vision model, please use gemma3_convert_encoder_to_gguf.py")
+
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+        self.gguf_writer.add_add_space_prefix(False)
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        # some default values are not specified in the hparams
+        self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
+        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+        self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
+        self.gguf_writer.add_key_length(hparams.get("head_dim", 256))
+        self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0))  # for global layers
+        # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
+        assert hparams.get("attn_logit_softcapping") == None
+        assert hparams.get("final_logit_softcapping") == None
+        self.gguf_writer.add_sliding_window(hparams["sliding_window"])
+        self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
+        if hparams.get("rope_scaling") is not None:
+            assert hparams["rope_scaling"]["rope_type"] == "linear"
+            # important: this rope_scaling is only applied for global layers, and not used by 1B model
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if name.startswith("language_model."):
+            name = name.replace("language_model.", "")
+        elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
+                or name.startswith("multimodal_projector.") or name.startswith("vision_model."):  # this is for old HF model, should be removed later
+            # ignore vision tensors
+            return []
+
+        # remove OOV (out-of-vocabulary) rows in token_embd
+        if "embed_tokens.weight" in name:
+            vocab = self._create_vocab_sentencepiece()
+            tokens = vocab[0]
+            data_torch = data_torch[:len(tokens)]
+
+        # ref code in Gemma3RMSNorm
+        # output = output * (1.0 + self.weight.float())
+        if name.endswith("norm.weight"):
+            data_torch = data_torch + 1
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @Model.register("Starcoder2ForCausalLM")
 class StarCoder2Model(Model):
     model_arch = gguf.MODEL_ARCH.STARCODER2
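One conversion detail worth calling out: Gemma3RMSNorm scales the normalized activations by (1.0 + weight), while the GGUF consumer applies a plain multiplicative RMS-norm weight, so the converter stores weight + 1 for every *norm.weight tensor. A small sketch of that equivalence, assuming a standard RMS norm with eps = 1e-6 (illustrative only, not part of the converter):

import torch

def gemma3_rmsnorm(x, w, eps=1e-6):
    # reference behaviour quoted in the diff: normalize, then scale by (1.0 + weight)
    x_normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x_normed * (1.0 + w.float())

def plain_rmsnorm(x, w, eps=1e-6):
    # generic RMS norm with a stored multiplicative weight
    x_normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x_normed * w.float()

x = torch.randn(4, 8)
w = torch.randn(8)
# storing (w + 1) in the GGUF makes the generic form match the reference form
assert torch.allclose(gemma3_rmsnorm(x, w), plain_rmsnorm(x, w + 1))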

gguf-py/gguf/constants.py

Lines changed: 19 additions & 0 deletions
@@ -253,6 +253,7 @@ class MODEL_ARCH(IntEnum):
     MINICPM3 = auto()
     GEMMA = auto()
     GEMMA2 = auto()
+    GEMMA3 = auto()
     STARCODER2 = auto()
     RWKV6 = auto()
     RWKV6QWEN2 = auto()
@@ -440,6 +441,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.MINICPM3: "minicpm3",
     MODEL_ARCH.GEMMA: "gemma",
     MODEL_ARCH.GEMMA2: "gemma2",
+    MODEL_ARCH.GEMMA3: "gemma3",
     MODEL_ARCH.STARCODER2: "starcoder2",
     MODEL_ARCH.RWKV6: "rwkv6",
     MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
@@ -1077,6 +1079,23 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_PRE_NORM,
         MODEL_TENSOR.FFN_POST_NORM,
     ],
+    MODEL_ARCH.GEMMA3: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.FFN_PRE_NORM,
+        MODEL_TENSOR.FFN_POST_NORM,
+    ],
     MODEL_ARCH.STARCODER2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,

src/llama-arch.cpp

Lines changed: 21 additions & 0 deletions
@@ -36,6 +36,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_MINICPM3, "minicpm3" },
     { LLM_ARCH_GEMMA, "gemma" },
     { LLM_ARCH_GEMMA2, "gemma2" },
+    { LLM_ARCH_GEMMA3, "gemma3" },
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA, "mamba" },
     { LLM_ARCH_XVERSE, "xverse" },
@@ -766,6 +767,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
         },
     },
+    {
+        LLM_ARCH_GEMMA3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_STARCODER2,
         {

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ enum llm_arch {
     LLM_ARCH_MINICPM3,
     LLM_ARCH_GEMMA,
     LLM_ARCH_GEMMA2,
+    LLM_ARCH_GEMMA3,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,

src/llama-model.cpp

Lines changed: 48 additions & 0 deletions
@@ -864,6 +864,23 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GEMMA3:
+            {
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 26: type = LLM_TYPE_1B; break;
+                    case 34: type = LLM_TYPE_4B; break;
+                    case 48: type = LLM_TYPE_12B; break;
+                    case 62: type = LLM_TYPE_27B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+
+                hparams.f_attention_scale = type == LLM_TYPE_27B
+                    ? 1.0f / sqrtf(float(hparams.n_embd / hparams.n_head(0)))
+                    : 1.0f / sqrtf(float(hparams.n_embd_head_k));
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
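As the hunk above shows, the 27B model gets a different per-head attention scale than the smaller sizes: 1/sqrt(n_embd / n_head) instead of the usual 1/sqrt(head_dim). A quick sketch of the arithmetic; the dimension values below are illustrative assumptions, not taken from this commit:

import math

def gemma3_attn_scale(n_embd: int, n_head: int, n_embd_head_k: int, is_27b: bool) -> float:
    # mirrors hparams.f_attention_scale above: the 27B model divides the embedding
    # width by the head count, every other size uses the key head dimension
    return 1.0 / math.sqrt(n_embd // n_head if is_27b else n_embd_head_k)

# illustrative values only
print(gemma3_attn_scale(n_embd=5376, n_head=32, n_embd_head_k=128, is_27b=True))   # 1/sqrt(168) ~ 0.077
print(gemma3_attn_scale(n_embd=2560, n_head=8,  n_embd_head_k=256, is_27b=False))  # 1/sqrt(256) = 0.0625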
@@ -2454,6 +2471,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
                     layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);

+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                    layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                    layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                    layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                    layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+                }
+            } break;
+        case LLM_ARCH_GEMMA3:
+            {
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                // output
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                    layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+                    layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+                    layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+
                     layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
                     layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
                     layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
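The Q/K/V shapes in the hunk above follow the usual grouped-query layout: Q projects to n_embd_head_k * n_head, while K and V project to n_embd_k_gqa / n_embd_v_gqa, i.e. head_dim * n_head_kv. A tiny sketch with hypothetical dimensions (not taken from this commit) to make the sizes concrete:

# hypothetical GQA dimensions, for illustration only
n_embd, n_head, n_head_kv, head_dim = 2560, 8, 4, 256

wq_shape = (n_embd, head_dim * n_head)     # {n_embd, n_embd_head_k * n_head}
wk_shape = (n_embd, head_dim * n_head_kv)  # {n_embd, n_embd_k_gqa}
wv_shape = (n_embd, head_dim * n_head_kv)  # {n_embd, n_embd_v_gqa}
wo_shape = (head_dim * n_head, n_embd)     # {n_embd_head_k * n_head, n_embd}

print(wq_shape, wk_shape, wv_shape, wo_shape)  # (2560, 2048) (2560, 1024) (2560, 1024) (2048, 2560)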
@@ -3650,6 +3696,7 @@ void llama_model::print_info() const {
     LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n", __func__, hparams.f_clamp_kqv);
     LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
     LLAMA_LOG_INFO("%s: f_logit_scale    = %.1e\n", __func__, hparams.f_logit_scale);
+    LLAMA_LOG_INFO("%s: f_attn_scale     = %.1e\n", __func__, hparams.f_attention_scale);
     LLAMA_LOG_INFO("%s: n_ff             = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
     LLAMA_LOG_INFO("%s: n_expert         = %u\n", __func__, hparams.n_expert);
     LLAMA_LOG_INFO("%s: n_expert_used    = %u\n", __func__, hparams.n_expert_used);
@@ -3923,6 +3970,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
         case LLM_ARCH_PHIMOE:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
+        case LLM_ARCH_GEMMA3:
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
