Skip to content

Add EXAONE 4.0 model architecture #14630

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jul 18, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "4e2b24cc4770243d65a2c9ec19770a72f08cffc161adbb73fcbb6b7dd45a0aae":
# ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct
res = "exaone"
if chkhsh == "2085e1638f6c377a0aa4ead21b27bb4cb941bf800df86ed391011769c1758dfb":
# ref: temporary model
res = "exaone4"
if chkhsh == "fcace8b9cac38ce847670c970cd5892031a753a1ef381abd1d9af00f713da085":
# ref: https://huggingface.co/microsoft/phi-2
res = "phi-2"
Expand Down Expand Up @@ -6388,6 +6391,77 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))


@ModelBase.register("Exaone4ForCausalLM")
class Exaone4Model(TextModel):
model_arch = gguf.MODEL_ARCH.EXAONE4

def set_vocab(self):
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
special_vocab.chat_template = "exaone4"
special_vocab.add_to_gguf(self.gguf_writer)

def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])

self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 4 * hparams["hidden_size"]))
self.gguf_writer.add_block_count(hparams["num_hidden_layers"])
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", hparams["num_attention_heads"]))
self.gguf_writer.add_layer_norm_rms_eps(hparams.get("layer_norm_epsilon", 1e-5))
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0))
self.gguf_writer.add_key_length(hparams["head_dim"])
self.gguf_writer.add_value_length(hparams["head_dim"])
self.gguf_writer.add_file_type(self.ftype)

if hparams.get("sliding_window") is not None:
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
# sliding window pattern 어떻게?

rope_scaling = self.hparams.get("rope_scaling") or {}
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])

def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
if rope_scaling.get("rope_type", '').lower() == "llama3":
base = self.hparams.get("rope_theta", 1_000_000.0)
if (dim := self.hparams.get("head_dim")) is None:
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

factor = rope_scaling.get("factor", 16.0)
low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor

rope_factors = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
rope_factors.append(1)
elif wavelen > low_freq_wavelen:
rope_factors.append(factor)
else:
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
rope_factors.append(1 / ((1 - smooth) / factor + smooth))

yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))


@ModelBase.register("GraniteForCausalLM")
class GraniteModel(LlamaModel):
"""Conversion for IBM's GraniteForCausalLM"""
Expand Down
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class TOKENIZER_TYPE(IntEnum):
{'name': "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom", },
{'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", },
{"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", },
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "temporary model", },
{"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", },
{"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", },
{"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"},
Expand Down
19 changes: 19 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ class MODEL_ARCH(IntEnum):
JAIS = auto()
NEMOTRON = auto()
EXAONE = auto()
EXAONE4 = auto()
GRANITE = auto()
GRANITE_MOE = auto()
GRANITE_HYBRID = auto()
Expand Down Expand Up @@ -660,6 +661,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.EXAONE4: "exaone4",
MODEL_ARCH.GRANITE: "granite",
MODEL_ARCH.GRANITE_MOE: "granitemoe",
MODEL_ARCH.GRANITE_HYBRID: "granitehybrid",
Expand Down Expand Up @@ -2113,6 +2115,23 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.EXAONE4: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_POST_NORM,
],
MODEL_ARCH.GRANITE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down
28 changes: 14 additions & 14 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class TensorNameMap:
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 granite-hybrid
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 granite-hybrid exaone4
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert nomic-bert
"language_model.embedding.word_embeddings", # persimmon
Expand Down Expand Up @@ -62,7 +62,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone exaone4 olmoe olmo2 phimoe
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
Expand All @@ -76,7 +76,7 @@ class TensorNameMap:
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe exaone4
"norm", # llama-pth
"transformer.norm_f", # mpt dbrx
"ln_f", # refact bloom qwen gpt2
Expand Down Expand Up @@ -168,7 +168,7 @@ class TensorNameMap:

# Attention query
MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe exaone4
"model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom
"layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert
Expand All @@ -183,7 +183,7 @@ class TensorNameMap:

# Attention key
MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe exaone4
"model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom
"layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert
Expand All @@ -199,7 +199,7 @@ class TensorNameMap:

# Attention value
MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe exaone4
"layers.{bid}.attention.wv", # llama-pth
"encoder.layer.{bid}.attention.self.value", # bert
"transformer.layer.{bid}.attention.v_lin", # distillbert
Expand All @@ -219,7 +219,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe exaone4
"model.layers.{bid}.self_attn.linear_attn", # deci
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
Expand Down Expand Up @@ -252,7 +252,7 @@ class TensorNameMap:
),

MODEL_TENSOR.ATTN_POST_NORM: (
"model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge
"model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 exaone4 # ge
"model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414
),

Expand Down Expand Up @@ -293,7 +293,7 @@ class TensorNameMap:

# Post feed-forward norm
MODEL_TENSOR.FFN_POST_NORM: (
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2 exaone4
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
"model.layers.{bid}.feed_forward.up_proj",
),
Expand Down Expand Up @@ -325,7 +325,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.ffn.up_proj", # mpt
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
"h.{bid}.mlp.dense_h_to_4h", # bloom
"model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2
"model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2 exaone4
"layers.{bid}.feed_forward.w3", # llama-pth
"encoder.layer.{bid}.intermediate.dense", # bert
"transformer.layer.{bid}.ffn.lin1", # distillbert
Expand Down Expand Up @@ -378,7 +378,7 @@ class TensorNameMap:

# Feed-forward gate
MODEL_TENSOR.FFN_GATE: (
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo2
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo2 exaone4
"layers.{bid}.feed_forward.w1", # llama-pth
"transformer.h.{bid}.mlp.w2", # qwen
"transformer.h.{bid}.mlp.c_fc2", # jais
Expand Down Expand Up @@ -415,7 +415,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.ffn.down_proj", # mpt
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
"h.{bid}.mlp.dense_4h_to_h", # bloom
"model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2
"model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2 exaone4
"layers.{bid}.feed_forward.w2", # llama-pth
"encoder.layer.{bid}.output.dense", # bert
"transformer.layer.{bid}.ffn.lin2", # distillbert
Expand Down Expand Up @@ -462,7 +462,7 @@ class TensorNameMap:
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
"model.layers.{bid}.self_attn.query_layernorm", # hunyuan
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2 exaone4
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
"transformer.layers.{bid}.attn.q_norm", # openelm
Expand All @@ -472,7 +472,7 @@ class TensorNameMap:
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
"model.layers.{bid}.self_attn.key_layernorm", # hunyuan
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2 exaone4
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
"transformer.layers.{bid}.attn.k_norm", # openelm
Expand Down
23 changes: 12 additions & 11 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,18 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
LLAMA_VOCAB_PRE_TYPE_EXAONE4 = 26,
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 27,
LLAMA_VOCAB_PRE_TYPE_MINERVA = 28,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 29,
LLAMA_VOCAB_PRE_TYPE_GPT4O = 30,
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 31,
LLAMA_VOCAB_PRE_TYPE_TRILLION = 32,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 33,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 34,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 35,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 36,
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 37,
};

enum llama_rope_type {
Expand Down
21 changes: 21 additions & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_EXAONE4, "exaone4" },
{ LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
{ LLM_ARCH_RWKV7, "rwkv7" },
Expand Down Expand Up @@ -1474,6 +1475,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_EXAONE4,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
}
},
{
LLM_ARCH_RWKV6,
{
Expand Down
1 change: 1 addition & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ enum llm_arch {
LLM_ARCH_JAIS,
LLM_ARCH_NEMOTRON,
LLM_ARCH_EXAONE,
LLM_ARCH_EXAONE4,
LLM_ARCH_RWKV6,
LLM_ARCH_RWKV6QWEN2,
LLM_ARCH_RWKV7,
Expand Down
Loading