Commit e0cb5c5

model : add EXAONE 4.0 support (#14630)
1 parent f9a31ee commit e0cb5c5

9 files changed, +333 -0 lines changed

convert_hf_to_gguf.py

Lines changed: 72 additions & 0 deletions
@@ -843,6 +843,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "169bf0296a13c4d9b7672313f749eb36501d931022de052aad6e36f2bf34dd51":
             # ref: https://huggingface.co/LiquidAI/LFM2-Tokenizer
             res = "lfm2"
+        if chkhsh == "2085e1638f6c377a0aa4ead21b27bb4cb941bf800df86ed391011769c1758dfb":
+            # ref: https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B
+            res = "exaone4"
 
         if res is None:
             logger.warning("\n")
@@ -6780,6 +6783,75 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
                 yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
 
 
+@ModelBase.register("Exaone4ForCausalLM")
+class Exaone4Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.EXAONE4
+
+    def set_vocab(self):
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+
+        if hparams.get("sliding_window") is not None:
+            self.gguf_writer.add_sliding_window(hparams["sliding_window"])
+            if "layer_types" in hparams:
+                self.gguf_writer.add_sliding_window_pattern([t == "sliding_attention" for t in hparams["layer_types"]])
+            elif "sliding_window_pattern" in hparams:
+                sliding_window_pattern = []
+                if isinstance(hparams["sliding_window_pattern"], str):  # e.g. LLLG
+                    for i in range(hparams["num_hidden_layers"]):
+                        sliding_window_pattern.append(hparams["sliding_window_pattern"][i % len(hparams["sliding_window_pattern"])] == "L")
+                if isinstance(hparams["sliding_window_pattern"], int):  # e.g. 4
+                    for i in range(hparams["num_hidden_layers"]):
+                        sliding_window_pattern.append((i + 1) % hparams["sliding_window_pattern"] != 0)
+                if len(sliding_window_pattern) == hparams["num_hidden_layers"]:
+                    self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
+
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
+            if rope_scaling.get("rope_type", '').lower() == "llama3":
+                base = self.hparams.get("rope_theta", 10_000.0)
+                if (dim := self.hparams.get("head_dim")) is None:
+                    dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+                freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+
+                factor = rope_scaling.get("factor", 16.0)
+                low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
+                high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+                old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
+
+                low_freq_wavelen = old_context_len / low_freq_factor
+                high_freq_wavelen = old_context_len / high_freq_factor
+
+                rope_factors = []
+                for freq in freqs:
+                    wavelen = 2 * math.pi / freq
+                    if wavelen < high_freq_wavelen:
+                        rope_factors.append(1)
+                    elif wavelen > low_freq_wavelen:
+                        rope_factors.append(factor)
+                    else:
+                        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+                        rope_factors.append(1 / ((1 - smooth) / factor + smooth))
+
+                yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
+
+
 @ModelBase.register("GraniteForCausalLM")
 class GraniteModel(LlamaModel):
     """Conversion for IBM's GraniteForCausalLM"""

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions
@@ -129,6 +129,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "a.x-4.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/skt/A.X-4.0", },
     {"name": "midm-2.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct", },
     {"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
+    {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
 ]
 
 # some models are known to be broken upstream, so we will skip them as exceptions
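
Note: the chkhsh value checked in get_vocab_base_pre above is produced by this update script, which downloads each listed tokenizer and hashes the token ids of a fixed probe string. A rough sketch of that fingerprinting step (tokenizer_fingerprint and probe_text are placeholders; the script hard-codes its own probe text):

from hashlib import sha256
from transformers import AutoTokenizer

def tokenizer_fingerprint(model_dir: str, probe_text: str) -> str:
    # Two tokenizers that split the probe string into the same token ids
    # produce the same fingerprint, so one pre-tokenizer setting can be shared.
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    ids = tokenizer.encode(probe_text)
    return sha256(str(ids).encode()).hexdigest()

With the probe text used by convert_hf_to_gguf_update.py, the EXAONE-4.0-32B tokenizer should hash to the "2085e163..." value added to convert_hf_to_gguf.py above.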

gguf-py/gguf/constants.py

Lines changed: 19 additions & 0 deletions
@@ -354,6 +354,7 @@ class MODEL_ARCH(IntEnum):
     JAIS = auto()
     NEMOTRON = auto()
     EXAONE = auto()
+    EXAONE4 = auto()
     GRANITE = auto()
     GRANITE_MOE = auto()
     GRANITE_HYBRID = auto()
@@ -671,6 +672,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.JAIS: "jais",
     MODEL_ARCH.NEMOTRON: "nemotron",
     MODEL_ARCH.EXAONE: "exaone",
+    MODEL_ARCH.EXAONE4: "exaone4",
     MODEL_ARCH.GRANITE: "granite",
     MODEL_ARCH.GRANITE_MOE: "granitemoe",
     MODEL_ARCH.GRANITE_HYBRID: "granitehybrid",
@@ -2197,6 +2199,23 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.EXAONE4: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_POST_NORM,
+    ],
     MODEL_ARCH.GRANITE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
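
Note: these tables are what the Python side of the converter consults when writing GGUF metadata and tensors. As a quick illustration (assuming a gguf-py build that includes this change), the new architecture can be inspected directly:

import gguf

# Registered architecture name and the tensors expected for EXAONE 4.0.
arch = gguf.MODEL_ARCH.EXAONE4
print(gguf.MODEL_ARCH_NAMES[arch])            # "exaone4"

for tensor in gguf.MODEL_TENSORS[arch]:
    # TENSOR_NAMES holds the GGUF base name, e.g. "blk.{bid}.attn_q" for ATTN_Q
    print(tensor.name, "->", gguf.TENSOR_NAMES[tensor])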

src/llama-arch.cpp

Lines changed: 21 additions & 0 deletions
@@ -68,6 +68,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_JAIS, "jais" },
     { LLM_ARCH_NEMOTRON, "nemotron" },
     { LLM_ARCH_EXAONE, "exaone" },
+    { LLM_ARCH_EXAONE4, "exaone4" },
     { LLM_ARCH_RWKV6, "rwkv6" },
     { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
     { LLM_ARCH_RWKV7, "rwkv7" },
@@ -1510,6 +1511,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_EXAONE4,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+        }
+    },
     {
         LLM_ARCH_RWKV6,
         {
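
Note: the C++ table above addresses the same tensors with printf-style templates ("blk.%d.attn_q" and so on). For illustration, a sketch of the fully resolved per-layer weight names that would be looked up for layer 0, assuming the usual ".weight" suffix convention:

# Sketch: expand the per-layer name templates from LLM_ARCH_EXAONE4 for one layer.
per_layer = [
    "blk.%d.attn_q", "blk.%d.attn_q_norm", "blk.%d.attn_k", "blk.%d.attn_k_norm",
    "blk.%d.attn_v", "blk.%d.attn_output", "blk.%d.post_attention_norm",
    "blk.%d.ffn_gate", "blk.%d.ffn_down", "blk.%d.ffn_up", "blk.%d.post_ffw_norm",
]

layer = 0
for template in per_layer:
    print(template % layer + ".weight")   # e.g. "blk.0.attn_q.weight"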

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ enum llm_arch {
     LLM_ARCH_JAIS,
     LLM_ARCH_NEMOTRON,
     LLM_ARCH_EXAONE,
+    LLM_ARCH_EXAONE4,
     LLM_ARCH_RWKV6,
     LLM_ARCH_RWKV6QWEN2,
     LLM_ARCH_RWKV7,

src/llama-chat.cpp

Lines changed: 20 additions & 0 deletions
@@ -56,6 +56,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
     { "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
     { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
+    { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 },
     { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
     { "granite", LLM_CHAT_TEMPLATE_GRANITE },
     { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
@@ -168,6 +169,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
     } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) {
         return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
     } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
+        if (tmpl_contains("[|tool|]")) {
+            return LLM_CHAT_TEMPLATE_EXAONE_4;
+        }
         // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
         // EXAONE-3.0-7.8B-Instruct
         return LLM_CHAT_TEMPLATE_EXAONE_3;
@@ -532,6 +536,22 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "[|assistant|]";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_4) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "user") {
+                ss << "[|user|]" << trim(message->content) << "\n";
+            } else if (role == "assistant") {
+                ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "tool") {
+                ss << "[|tool|]" << trim(message->content) << "[|endofturn|]\n";
+            }
+        }
+        if (add_ass) {
+            ss << "[|assistant|]";
+        }
     } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
         // this template requires the model to have "\n\n" as EOT token
         for (size_t i = 0; i < chat.size(); i++) {
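
Note: the new EXAONE_4 branch renders each turn as a [|role|] header followed by the trimmed content, with [|endofturn|] closing system, assistant and tool turns (user turns end with a bare newline), and template auto-detection tells EXAONE 4.0 apart from EXAONE 3.0 by the "[|tool|]" marker, as added in llm_chat_detect_template above. A small Python mirror of that loop, for illustration only:

def exaone4_prompt(chat: list[dict[str, str]], add_assistant_prefix: bool = True) -> str:
    # Mirrors the C++ branch above; .strip() stands in for trim().
    out = []
    for msg in chat:
        role, content = msg["role"], msg["content"].strip()
        if role == "system":
            out.append(f"[|system|]{content}[|endofturn|]\n")
        elif role == "user":
            out.append(f"[|user|]{content}\n")
        elif role == "assistant":
            out.append(f"[|assistant|]{content}[|endofturn|]\n")
        elif role == "tool":
            out.append(f"[|tool|]{content}[|endofturn|]\n")
    if add_assistant_prefix:
        out.append("[|assistant|]")
    return "".join(out)

print(exaone4_prompt([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]))
# [|system|]You are a helpful assistant.[|endofturn|]
# [|user|]Hello!
# [|assistant|]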

src/llama-chat.h

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_GLMEDGE,
     LLM_CHAT_TEMPLATE_MINICPM,
     LLM_CHAT_TEMPLATE_EXAONE_3,
+    LLM_CHAT_TEMPLATE_EXAONE_4,
     LLM_CHAT_TEMPLATE_RWKV_WORLD,
     LLM_CHAT_TEMPLATE_GRANITE,
     LLM_CHAT_TEMPLATE_GIGACHAT,
