
Commit 68e37a6

mitmul, compilade, CISC, and ggerganov authored
model : add PLaMo-2 support (#14560)
* Add PLaMo-2 model using hybrid memory module
* Fix z shape
* Add cmath to include from llama-vocab.h
* Explicitly dequantize normalization weights before RoPE apply
* Revert unnecessary cast because the problem can be solved by excluding attn_k, attn_q when quantizing
* Use ATTN_K/Q_NORM for k,q weights to prevent quantization
* Remove SSM_BCDT that is not used from anywhere
* Do not duplicate embedding weights for output.weight
* Fix tokenizer encoding problem for multibyte strings
* Apply suggestion from @CISC
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Update src/llama-model.cpp
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Use LLM_FFN_SWIGLU instead of splitting ffn_gate and ffn_up
* Remove unnecessary part for Grouped Query Attention
* Fix how to load special token id to gguf
* Remove unused tensor mapping
* Update src/llama-model.cpp
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Remove llama_vocab_plamo2 class and replace it with llm_tokenizer_plamo2_session to follow the other tokenizer implementations
* Update src/llama-vocab.cpp
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* Update convert_hf_to_gguf.py
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Update src/llama-model.cpp
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Update src/llama-model.cpp
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Update convert_hf_to_gguf.py
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Update convert_hf_to_gguf.py
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Fix plamo2 tokenizer session to prevent multiple calls of build()

---------

Co-authored-by: Francis Couture-Harpin <git@compilade.net>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
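One item above replaces separate ffn_gate/ffn_up projections with the fused LLM_FFN_SWIGLU path. As a rough illustration only, the Python sketch below shows the SwiGLU feed-forward that such a fused gate+up projection computes; the packed weight layout and the gate-first split order are assumptions for illustration, not a description of the llama.cpp internals.

import torch
import torch.nn.functional as F

def swiglu_ffn(x: torch.Tensor, w_gate_up: torch.Tensor, w_down: torch.Tensor) -> torch.Tensor:
    # One matmul produces both gate and up activations (fused projection),
    # instead of two separate ffn_gate / ffn_up matmuls.
    gate, up = (x @ w_gate_up).chunk(2, dim=-1)  # split order is an assumption
    # SwiGLU: silu(gate) * up, followed by the down-projection
    return (F.silu(gate) * up) @ w_down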
1 parent cbc68be commit 68e37a6

File tree

8 files changed (+1048 lines, -44 lines)

convert_hf_to_gguf.py

Lines changed: 169 additions & 0 deletions
@@ -3508,6 +3508,175 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(new_name, data_torch)]
 
 
+@ModelBase.register("Plamo2ForCausalLM", "PLaMo2ForCausalLM")
+class Plamo2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.PLAMO2
+
+    def set_vocab(self):
+        # PLaMo 2 uses a custom tokenizer with a .jsonl file
+        # We need to handle this specially
+        tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl"
+        tokenizer_config_path = self.dir_model / "tokenizer_config.json"
+
+        if not tokenizer_jsonl_path.is_file():
+            raise FileNotFoundError(f"PLaMo 2 tokenizer file not found: {tokenizer_jsonl_path}")
+
+        # Load tokenizer config
+        with open(tokenizer_config_path, 'r', encoding='utf-8') as f:
+            tokenizer_config = json.load(f)
+
+        # Load tokens from JSONL file (actually a list format)
+        tokens = []
+        scores = []
+        toktypes = []
+
+        with open(tokenizer_jsonl_path, 'r', encoding='utf-8') as f:
+            for line_num, line in enumerate(f):
+                if line.strip():
+                    token_data = json.loads(line)
+                    # Format: [token, score, type, ?, ?, ?, ?]
+                    token = token_data[0].encode("utf-8")
+                    score = float(token_data[1])
+                    token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL"
+
+                    tokens.append(token)
+                    scores.append(score)
+
+                    # Map token type strings to GGUF token types
+                    if token_type_str == "UNKNOWN":
+                        toktypes.append(gguf.TokenType.UNKNOWN)
+                    elif token_type_str == "CONTROL":
+                        toktypes.append(gguf.TokenType.CONTROL)
+                    elif token_type_str == "BYTE":
+                        toktypes.append(gguf.TokenType.BYTE)
+                    else:
+                        # Check for PLaMo-2 special tokens
+                        token_str = token_data[0]
+                        if token_str.startswith("<|plamo:") and token_str.endswith("|>"):
+                            toktypes.append(gguf.TokenType.CONTROL)
+                        else:
+                            toktypes.append(gguf.TokenType.NORMAL)
+
+        vocab_size = self.hparams["vocab_size"]
+        if vocab_size > len(tokens):
+            pad_count = vocab_size - len(tokens)
+            logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
+            for i in range(1, pad_count + 1):
+                tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
+                scores.append(-1000.0)
+                toktypes.append(gguf.TokenType.UNUSED)
+
+        # Use "plamo2" tokenizer type for PLaMo-2's custom Aho-Corasick tokenizer
+        self.gguf_writer.add_tokenizer_model("plamo2")
+        self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        # Add special tokens from config
+        if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] is not None:
+            token_id = tokens.index(tokenizer_config["bos_token"].encode("utf-8"))
+            self.gguf_writer.add_bos_token_id(token_id)
+        if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] is not None:
+            token_id = tokens.index(tokenizer_config["eos_token"].encode("utf-8"))
+            self.gguf_writer.add_eos_token_id(token_id)
+        if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] is not None:
+            token_id = tokens.index(tokenizer_config["pad_token"].encode("utf-8"))
+            self.gguf_writer.add_pad_token_id(token_id)
+        if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] is not None:
+            token_id = tokens.index(tokenizer_config["sep_token"].encode("utf-8"))
+            self.gguf_writer.add_sep_token_id(token_id)
+        if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] is not None:
+            token_id = tokens.index(tokenizer_config["unk_token"].encode("utf-8"))
+            self.gguf_writer.add_unk_token_id(token_id)
+
+        # Add <|plamo:op|> as EOT to ensure appropriate end of generation
+        self.gguf_writer.add_eot_token_id(4)
+
+        self.gguf_writer.add_add_space_prefix(False)
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+
+        # Which layers are Mamba layers
+        # PLaMo 2 uses mamba_step to indicate the pattern (e.g., 2 means every other layer)
+        # This logic matches modeling_plamo.py's is_mamba function
+        mamba_step = hparams.get("mamba_step", 2)
+        mamba_enabled = hparams.get("mamba_enabled", True)
+        mamba_layers = []
+
+        if mamba_enabled:
+            for i in range(block_count):
+                if block_count <= (mamba_step // 2):
+                    # use attention in last layer
+                    is_mamba = (i != block_count - 1)
+                else:
+                    is_mamba = (i % mamba_step) != (mamba_step // 2)
+                if is_mamba:
+                    mamba_layers.append(0)
+                else:
+                    mamba_layers.append(hparams.get("num_key_value_heads", 4))
+
+        if mamba_layers:
+            self.gguf_writer.add_head_count_kv(mamba_layers)
+
+        self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048))
+        self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
+        self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
+        self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1000000.0))
+
+        # Mamba parameters
+        self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64))
+        self.gguf_writer.add_ssm_conv_kernel(hparams.get("mamba_d_conv", 4))
+        self.gguf_writer.add_ssm_time_step_rank(hparams.get("mamba_num_heads", 64))
+        intermediate_size = hparams.get("mamba_num_heads", 64) * hparams.get("hidden_size_per_head", 128)
+        self.gguf_writer.add_ssm_inner_size(intermediate_size)
+        self.gguf_writer.add_ssm_group_count(0)
+
+        # MLP feed forward parameters (for attention layers)
+        self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 16384))
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if name.endswith(".A_log"):
+            data_torch = -torch.exp(data_torch)
+        elif name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+        elif name.endswith(".dt_norm_weight"):
+            name = name.rpartition(".dt_norm_weight")[0] + ".dt_norm.weight"
+        elif name.endswith(".B_norm_weight"):
+            name = name.rpartition(".B_norm_weight")[0] + ".B_norm.weight"
+        elif name.endswith(".C_norm_weight"):
+            name = name.rpartition(".C_norm_weight")[0] + ".C_norm.weight"
+        elif name.endswith(".k_weight"):
+            name = name.rpartition(".k_weight")[0] + ".k.weight"
+        elif name.endswith(".q_weight"):
+            name = name.rpartition(".q_weight")[0] + ".q.weight"
+        elif name.endswith(".conv1d.weight"):
+            data_torch = torch.squeeze(data_torch)  # remove (, 1, )
+            assert data_torch.ndim == 2
+        elif name.endswith(".pre_mixer_norm.weight"):
+            data_torch += 1.0
+        elif name.endswith(".post_mixer_norm.weight"):
+            data_torch += 1.0 / 5
+        elif name.endswith(".pre_mlp_norm.weight"):
+            data_torch += 1.0
+        elif name.endswith(".post_mlp_norm.weight"):
+            data_torch += 1.0 / (5**1.5)
+        elif name.endswith(".norm.weight"):
+            data_torch += 1.0
+
+        new_name = self.map_tensor_name(name)
+
+        return [(new_name, data_torch)]
+
+
 @ModelBase.register("CodeShellForCausalLM")
 class CodeShellModel(TextModel):
     model_arch = gguf.MODEL_ARCH.CODESHELL
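For reference, here is a small standalone sketch of the mamba_step layer pattern that set_gguf_parameters above encodes into the per-layer KV-head counts. The concrete numbers (block_count = 8, mamba_step = 2, num_key_value_heads = 4) are illustrative values, not taken from a real PLaMo-2 config.

# Standalone illustration of the is_mamba pattern used in Plamo2Model.set_gguf_parameters.
# block_count, mamba_step and num_key_value_heads are illustrative values only.
block_count = 8
mamba_step = 2
num_key_value_heads = 4

mamba_layers = []
for i in range(block_count):
    if block_count <= (mamba_step // 2):
        # very shallow models: keep attention only in the last layer
        is_mamba = (i != block_count - 1)
    else:
        # one attention layer every mamba_step layers, offset by mamba_step // 2
        is_mamba = (i % mamba_step) != (mamba_step // 2)
    # Mamba layers get 0 KV heads; attention layers keep the real KV-head count
    # in the per-layer head_count_kv array.
    mamba_layers.append(0 if is_mamba else num_key_value_heads)

print(mamba_layers)  # -> [0, 4, 0, 4, 0, 4, 0, 4]

With these values every even-indexed layer is a Mamba layer and every odd-indexed layer is an attention layer, matching the "every other layer" pattern mentioned in the code comments.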

gguf-py/gguf/constants.py

Lines changed: 32 additions & 0 deletions
@@ -317,6 +317,7 @@ class MODEL_ARCH(IntEnum):
     PHI3 = auto()
     PHIMOE = auto()
     PLAMO = auto()
+    PLAMO2 = auto()
     CODESHELL = auto()
     ORION = auto()
     INTERNLM2 = auto()
@@ -631,6 +632,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.PHI3: "phi3",
     MODEL_ARCH.PHIMOE: "phimoe",
     MODEL_ARCH.PLAMO: "plamo",
+    MODEL_ARCH.PLAMO2: "plamo2",
     MODEL_ARCH.CODESHELL: "codeshell",
     MODEL_ARCH.ORION: "orion",
     MODEL_ARCH.INTERNLM2: "internlm2",
@@ -1369,6 +1371,36 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.PLAMO2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_POST_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_X,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_OUT,
+        MODEL_TENSOR.SSM_DT_NORM,
+        MODEL_TENSOR.SSM_B_NORM,
+        MODEL_TENSOR.SSM_C_NORM,
+    ],
     MODEL_ARCH.GPT2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.POS_EMBD,