
Commit d6dac85

add chatglm3-6b model support
huggingface model: https://hf-mirror.com/THUDM/chatglm3-6b

Signed-off-by: XingXing Qiao <qiaoxx@dingdao.com>
1 parent 8843a98 commit d6dac85

5 files changed: +393 -2 lines changed

convert-hf-to-gguf.py

Lines changed: 160 additions & 1 deletion
@@ -56,7 +56,7 @@ def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian:
         self.part_names = self._get_part_names()
         self.hparams = Model.load_hparams(self.dir_model)
         self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
-        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
 
     @property
     @abstractmethod
@@ -2903,6 +2903,165 @@ def write_tensors(self):
 
             self.gguf_writer.add_tensor(new_name, data)
 
+@Model.register("ChatGLMModel")
+class ChatGLMModel(Model):
+    model_arch = gguf.MODEL_ARCH.CHATGLM
+
+    def set_vocab(self):
+        dir_model = self.dir_model
+        hparams = self.hparams
+        tokens: list[bytearray] = []
+        toktypes: list[int] = []
+        scores: list[float] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+        vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
+        assert max(tokenizer.get_vocab().values()) < vocab_size
+
+        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.get_vocab().items()}
+
+        for token_id in range(vocab_size):
+            piece = tokenizer._convert_id_to_token(token_id)
+            if token_id == 0:
+                piece = "<unk>"
+            elif token_id == 1:
+                piece = "<bos>"
+            elif token_id == 2:
+                piece = "<eos>"
+
+            text = piece.encode("utf-8")
+            score = 0.0
+            if len(piece) != 0 and token_id < 64789:
+                score = tokenizer.tokenizer.sp_model.get_score(token_id)
+
+            if len(piece) == 0:
+                text = f"[PAD{token_id}]".encode("utf-8")
+
+            if token_id >= 64789:
+                toktype = SentencePieceTokenTypes.UNKNOWN
+                tokens.append(text)
+                scores.append(score)
+                toktypes.append(toktype)
+                continue
+
+            toktype = SentencePieceTokenTypes.NORMAL
+            if tokenizer.tokenizer.sp_model.is_unknown(token_id):
+                toktype = SentencePieceTokenTypes.UNKNOWN
+            elif tokenizer.tokenizer.sp_model.is_control(token_id):
+                toktype = SentencePieceTokenTypes.CONTROL
+            elif tokenizer.tokenizer.sp_model.is_unused(token_id):
+                toktype = SentencePieceTokenTypes.UNUSED
+            elif tokenizer.tokenizer.sp_model.is_byte(token_id):
+                toktype = SentencePieceTokenTypes.BYTE
+
+            tokens.append(text)
+            scores.append(score)
+            toktypes.append(toktype)
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_name("ChatGLM-6b-chat")
+        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
+        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
+        n_head_kv = self.hparams.get("multi_query_group_num", n_head)
+        self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
+        self.gguf_writer.add_embedding_length(n_embed)
+        self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", 4 * n_embed))
+        self.gguf_writer.add_block_count(self.hparams["num_layers"])
+        self.gguf_writer.add_head_count(n_head)
+        self.gguf_writer.add_head_count_kv(n_head_kv)
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layernorm_epsilon"])
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_rope_dimension_count(64)
+        self.gguf_writer.add_add_bos_token(False)
+
+    def write_tensors(self):
+        block_count = self.hparams["num_layers"]
+        tensors = dict(self.get_tensors())
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        has_lm_head = True
+        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
+        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
+
+        for name, data_torch in tensors.items():
+            if name.endswith(".rotary_pos_emb.inv_freq"):
+                continue
+
+            if "lm_head.weight" not in tensors.keys() and "output.weight" not in tensors.keys():
+                has_lm_head = False
+
+            name = re.sub(r'transformer\.', '', name)
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name):
+                # Map bloom-style qkv_linear to gpt-style qkv_linear
+                # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa
+                # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa
+                qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed))
+                data = np.concatenate(
+                    (
+                        qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
+                        qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
+                        qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
+                    ),
+                    axis=0,
+                )
+                print("re-format attention.linear_qkv.weight")
+            elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
+                qkv_bias = data.reshape((n_head, 3, n_embed // n_head))
+                data = np.concatenate(
+                    (
+                        qkv_bias[:, 0, :].reshape((n_embed,)),
+                        qkv_bias[:, 1, :].reshape((n_embed,)),
+                        qkv_bias[:, 2, :].reshape((n_embed,)),
+                    ),
+                    axis=0,
+                )
+                print("re-format attention.linear_qkv.bias")
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+            if not has_lm_head and name == "word_embeddings.weight":
+                self.gguf_writer.add_tensor("output.weight", data)
+                print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
 
 ###### CONVERSION LOGIC ######
 
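The branch above that matches `h.<N>.self_attention.query_key_value` regroups a bloom-style fused QKV weight (query, key and value interleaved per head) into the gpt-style layout (all query rows, then all key rows, then all value rows) before the tensor is written. A minimal standalone sketch of that re-layout, with toy dimensions invented for illustration:

import numpy as np

n_head, n_embed = 2, 8              # toy sizes; real checkpoints are far larger
head_dim = n_embed // n_head

# fused weight as stored: per head, its Q, K and V blocks follow one another
fused = np.arange(3 * n_embed * n_embed, dtype=np.float32).reshape(3 * n_embed, n_embed)

qkv = fused.reshape(n_head, 3, head_dim, n_embed)
regrouped = np.concatenate(
    (
        qkv[:, 0, :, :].reshape(-1, n_embed),   # all query rows
        qkv[:, 1, :, :].reshape(-1, n_embed),   # all key rows
        qkv[:, 2, :, :].reshape(-1, n_embed),   # all value rows
    ),
    axis=0,
)
assert regrouped.shape == fused.shape           # same rows, new grouping

Note that the regex only fires for bloom-style `h.<N>....` names; ChatGLM's own `encoder.layers.<N>....` tensors fall through unchanged and are renamed purely via the tensor name map below.
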
gguf-py/gguf/constants.py

Lines changed: 17 additions & 0 deletions
@@ -138,6 +138,7 @@ class MODEL_ARCH(IntEnum):
     COMMAND_R = auto()
     DBRX = auto()
     OLMO = auto()
+    CHATGLM = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -215,6 +216,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.COMMAND_R: "command-r",
     MODEL_ARCH.DBRX: "dbrx",
     MODEL_ARCH.OLMO: "olmo",
+    MODEL_ARCH.CHATGLM: "chatglm",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -725,6 +727,18 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.CHATGLM: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }

@@ -761,6 +775,9 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
     ],
+    MODEL_ARCH.CHATGLM: [
+        MODEL_TENSOR.ROPE_FREQS,
+    ],
 }
 
 #
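
The new MODEL_ARCH.CHATGLM entry declares which tensor kinds a chatglm GGUF file is expected to carry. As a quick sanity check it can be expanded into concrete GGUF tensor names with the existing TENSOR_NAMES table; the sketch below assumes the per-architecture dict being extended here is constants.MODEL_TENSORS (its usual name in gguf-py), and the block count of 28 is only illustrative (chatglm3-6b's "num_layers"):

from gguf.constants import MODEL_ARCH, MODEL_TENSORS, TENSOR_NAMES

num_layers = 28  # illustrative; chatglm3-6b's "num_layers" from its config
for tensor in MODEL_TENSORS[MODEL_ARCH.CHATGLM]:
    fmt = TENSOR_NAMES[tensor]          # e.g. "blk.{bid}.attn_qkv"
    if "{bid}" in fmt:
        print(fmt.format(bid=0), "...", fmt.format(bid=num_layers - 1))
    else:
        print(fmt)                      # e.g. "token_embd", "output_norm", "output"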

gguf-py/gguf/tensor_mapping.py

Lines changed: 10 additions & 0 deletions
@@ -24,6 +24,7 @@ class TensorNameMap:
             "backbone.embedding",         # mamba
             "backbone.embeddings",        # mamba-hf
             "transformer.in_out_embed",   # Grok
+            "embedding.word_embeddings",  # chatglm
         ),
 
         # Token type embeddings
@@ -52,6 +53,7 @@ class TensorNameMap:
             "output",                     # llama-pth bloom internlm2
             "word_embeddings_for_head",   # persimmon
             "lm_head.linear",             # phi2
+            "output_layer",               # chatglm
         ),
 
         # Output norm
@@ -68,11 +70,13 @@ class TensorNameMap:
             "model.norm_f",               # mamba-qbert
             "backbone.norm_f",            # mamba
             "transformer.rms_norm",       # Grok
+            "encoder.final_layernorm",    # chatglm
         ),
 
         # Rope frequencies
         MODEL_TENSOR.ROPE_FREQS: (
             "rope.freqs",                 # llama-pth
+            "rotary_pos_emb.inv_freq",    # chatglm
         ),
     }
 
@@ -97,6 +101,7 @@ class TensorNameMap:
             "backbone.layers.{bid}.norm",                      # mamba
             "transformer.decoder_layer.{bid}.rms_norm",        # Grok
             "transformer.blocks.{bid}.norm_attn_norm.norm_1",  # dbrx
+            "encoder.layers.{bid}.input_layernorm",            # chatglm
         ),
 
         # Attention norm 2
@@ -118,6 +123,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mixer.Wqkv",                        # phi2
             "encoder.layers.{bid}.attn.Wqkv",                        # nomic-bert
             "model.layers.{bid}.self_attn.qkv_proj",                 # phi3
+            "encoder.layers.{bid}.self_attention.query_key_value",   # chatglm
         ),
 
         # Attention query
@@ -173,6 +179,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.attn.out_proj",                           # nomic-bert
             "transformer.decoder_layer.{bid}.multi_head_attention.linear",  # Grok
             "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj",        # dbrx
+            "encoder.layers.{bid}.self_attention.dense",                    # chatglm
         ),
 
         # Attention output norm
@@ -204,6 +211,7 @@ class TensorNameMap:
             "h.{bid}.ln_2",                                   # gpt2
             "model.layers.{bid}.ffn_norm",                    # internlm2
             "transformer.decoder_layer.{bid}.rms_norm_2",     # Grok
+            "encoder.layers.{bid}.post_attention_layernorm",  # chatglm
         ),
 
         MODEL_TENSOR.FFN_GATE_INP: (
@@ -240,6 +248,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.w3",      # internlm2
             "encoder.layers.{bid}.mlp.fc11",           # nomic-bert
             "model.layers.{bid}.mlp.c_fc",             # starcoder2
+            "encoder.layers.{bid}.mlp.dense_h_to_4h",  # chatglm
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -299,6 +308,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.w2",      # internlm2
             "encoder.layers.{bid}.mlp.fc2",            # nomic-bert
             "model.layers.{bid}.mlp.c_proj",           # starcoder2
+            "encoder.layers.{bid}.mlp.dense_4h_to_h",  # chatglm
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
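
These entries are what let TensorNameMap translate chatglm checkpoint names into GGUF names: the converter hunk above builds the map with gguf.get_tensor_name_map(...) and then calls tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) for every tensor. A small sketch of that lookup, using example names from the chatglm3-6b checkpoint with the leading "transformer." prefix already stripped (as write_tensors does) and an illustrative block count of 28:

import gguf

tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CHATGLM, 28)

for hf_name in (
    "embedding.word_embeddings.weight",
    "encoder.layers.0.self_attention.query_key_value.weight",
    "encoder.layers.0.mlp.dense_4h_to_h.weight",
    "encoder.final_layernorm.weight",
    "output_layer.weight",
):
    # prints e.g. "token_embd.weight", "blk.0.attn_qkv.weight", ... or None if unmapped
    print(hf_name, "->", tensor_map.get_name(hf_name, try_suffixes=(".weight", ".bias")))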

gguf-py/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "gguf"
-version = "0.9.0"
+version = "0.9.1"
 description = "Read and write ML models in GGUF for GGML"
 authors = ["GGML <ggml@ggml.ai>"]
 packages = [
