Skip to content

Commit 5747a0c

Browse files
committed
add chatglm3-6b model support huggingface model: https://hf-mirror.com/THUDM/chatglm3-6b
Signed-off-by: XingXing Qiao <qiaoxx@dingdao.com>
1 parent 152da28 commit 5747a0c

File tree

6 files changed

+413
-7
lines changed

6 files changed

+413
-7
lines changed

convert-hf-to-gguf.py

Lines changed: 161 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
7979
if not self.is_safetensors:
8080
self.part_names = Model.get_model_part_names(self.dir_model, ".bin")
8181
self.hparams = Model.load_hparams(self.dir_model)
82-
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
82+
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
8383
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
8484
self.tensor_names = None
8585
if self.ftype == gguf.LlamaFileType.GUESSED:
@@ -2427,6 +2427,166 @@ def set_vocab(self, *args, **kwargs):
24272427
self.gguf_writer.add_add_bos_token(True)
24282428
self.gguf_writer.add_add_eos_token(True)
24292429

2430+
@Model.register("ChatGLMModel")
2431+
class ChatGLMModel(Model):
2432+
model_arch = gguf.MODEL_ARCH.CHATGLM
2433+
2434+
def set_vocab(self):
2435+
dir_model = self.dir_model
2436+
hparams = self.hparams
2437+
tokens: list[bytearray] = []
2438+
toktypes: list[int] = []
2439+
scores: list[float] = []
2440+
2441+
from transformers import AutoTokenizer
2442+
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
2443+
vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
2444+
assert max(tokenizer.get_vocab().values()) < vocab_size
2445+
2446+
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.get_vocab().items()}
2447+
2448+
for token_id in range(vocab_size):
2449+
piece = tokenizer._convert_id_to_token(token_id)
2450+
if token_id == 0:
2451+
piece = "<unk>"
2452+
elif token_id == 1:
2453+
piece = "<bos>"
2454+
elif token_id == 2:
2455+
piece = "<eos>"
2456+
2457+
text = piece.encode("utf-8")
2458+
score = 0.0
2459+
if len(piece) != 0 and token_id < 64789:
2460+
score = tokenizer.tokenizer.sp_model.get_score(token_id)
2461+
2462+
if len(piece) == 0:
2463+
text = f"[PAD{token_id}]".encode("utf-8")
2464+
2465+
if token_id >= 64789:
2466+
toktype = SentencePieceTokenTypes.UNKNOWN
2467+
tokens.append(text)
2468+
scores.append(score)
2469+
toktypes.append(toktype)
2470+
continue
2471+
2472+
toktype = SentencePieceTokenTypes.NORMAL
2473+
if tokenizer.tokenizer.sp_model.is_unknown(token_id):
2474+
toktype = SentencePieceTokenTypes.UNKNOWN
2475+
elif tokenizer.tokenizer.sp_model.is_control(token_id):
2476+
toktype = SentencePieceTokenTypes.CONTROL
2477+
elif tokenizer.tokenizer.sp_model.is_unused(token_id):
2478+
toktype = SentencePieceTokenTypes.UNUSED
2479+
elif tokenizer.tokenizer.sp_model.is_byte(token_id):
2480+
toktype = SentencePieceTokenTypes.BYTE
2481+
2482+
tokens.append(text)
2483+
scores.append(score)
2484+
toktypes.append(toktype)
2485+
2486+
self.gguf_writer.add_tokenizer_model("llama")
2487+
self.gguf_writer.add_token_list(tokens)
2488+
self.gguf_writer.add_token_scores(scores)
2489+
self.gguf_writer.add_token_types(toktypes)
2490+
2491+
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
2492+
special_vocab.add_to_gguf(self.gguf_writer)
2493+
2494+
def set_gguf_parameters(self):
2495+
self.gguf_writer.add_name("ChatGLM-6b-chat")
2496+
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
2497+
n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
2498+
n_head_kv = self.hparams.get("multi_query_group_num", n_head)
2499+
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
2500+
self.gguf_writer.add_embedding_length(n_embed)
2501+
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", 4 * n_embed))
2502+
self.gguf_writer.add_block_count(self.hparams["num_layers"])
2503+
self.gguf_writer.add_head_count(n_head)
2504+
self.gguf_writer.add_head_count_kv(n_head_kv)
2505+
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layernorm_epsilon"])
2506+
self.gguf_writer.add_file_type(self.ftype)
2507+
self.gguf_writer.add_rope_dimension_count(64)
2508+
self.gguf_writer.add_add_bos_token(False)
2509+
2510+
def write_tensors(self):
2511+
block_count = self.hparams["num_layers"]
2512+
tensors = dict(self.get_tensors())
2513+
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
2514+
has_lm_head = True
2515+
n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
2516+
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
2517+
2518+
for name, data_torch in tensors.items():
2519+
if name.endswith(".rotary_pos_emb.inv_freq"):
2520+
continue
2521+
2522+
if "lm_head.weight" not in tensors.keys() and "output.weight" not in tensors.keys():
2523+
has_lm_head = False
2524+
2525+
name = re.sub(r'transformer\.', '', name)
2526+
2527+
old_dtype = data_torch.dtype
2528+
2529+
# convert any unsupported data types to float32
2530+
if data_torch.dtype not in (torch.float16, torch.float32):
2531+
data_torch = data_torch.to(torch.float32)
2532+
2533+
data = data_torch.squeeze().numpy()
2534+
2535+
if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name):
2536+
# Map bloom-style qkv_linear to gpt-style qkv_linear
2537+
# bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa
2538+
# gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa
2539+
qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed))
2540+
data = np.concatenate(
2541+
(
2542+
qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
2543+
qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
2544+
qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
2545+
),
2546+
axis=0,
2547+
)
2548+
print("re-format attention.linear_qkv.weight")
2549+
elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
2550+
qkv_bias = data.reshape((n_head, 3, n_embed // n_head))
2551+
data = np.concatenate(
2552+
(
2553+
qkv_bias[:, 0, :].reshape((n_embed,)),
2554+
qkv_bias[:, 1, :].reshape((n_embed,)),
2555+
qkv_bias[:, 2, :].reshape((n_embed,)),
2556+
),
2557+
axis=0,
2558+
)
2559+
print("re-format attention.linear_qkv.bias")
2560+
2561+
# map tensor names
2562+
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
2563+
if new_name is None:
2564+
print(f"Can not map tensor {name!r}")
2565+
sys.exit()
2566+
2567+
n_dims = len(data.shape)
2568+
data_dtype = data.dtype
2569+
2570+
# if f32 desired, convert any float16 to float32
2571+
if self.ftype == 0 and data_dtype == np.float16:
2572+
data = data.astype(np.float32)
2573+
2574+
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
2575+
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
2576+
data = data.astype(np.float32)
2577+
2578+
# if f16 desired, convert any float32 2-dim weight tensors to float16
2579+
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
2580+
data = data.astype(np.float16)
2581+
2582+
print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
2583+
2584+
self.gguf_writer.add_tensor(new_name, data)
2585+
2586+
if not has_lm_head and name == "word_embeddings.weight":
2587+
self.gguf_writer.add_tensor("output.weight", data)
2588+
print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
2589+
24302590

24312591
###### CONVERSION LOGIC ######
24322592

gguf-py/gguf/constants.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ class MODEL_ARCH(IntEnum):
139139
COMMAND_R = auto()
140140
DBRX = auto()
141141
OLMO = auto()
142+
CHATGLM = auto()
142143

143144

144145
class MODEL_TENSOR(IntEnum):
@@ -218,6 +219,7 @@ class MODEL_TENSOR(IntEnum):
218219
MODEL_ARCH.COMMAND_R: "command-r",
219220
MODEL_ARCH.DBRX: "dbrx",
220221
MODEL_ARCH.OLMO: "olmo",
222+
MODEL_ARCH.CHATGLM: "chatglm",
221223
}
222224

223225
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -732,6 +734,18 @@ class MODEL_TENSOR(IntEnum):
732734
MODEL_TENSOR.FFN_DOWN,
733735
MODEL_TENSOR.FFN_UP,
734736
],
737+
MODEL_ARCH.CHATGLM : [
738+
MODEL_TENSOR.TOKEN_EMBD,
739+
MODEL_TENSOR.ROPE_FREQS,
740+
MODEL_TENSOR.OUTPUT_NORM,
741+
MODEL_TENSOR.OUTPUT,
742+
MODEL_TENSOR.ATTN_NORM,
743+
MODEL_TENSOR.ATTN_QKV,
744+
MODEL_TENSOR.ATTN_OUT,
745+
MODEL_TENSOR.FFN_NORM,
746+
MODEL_TENSOR.FFN_DOWN,
747+
MODEL_TENSOR.FFN_UP,
748+
],
735749
# TODO
736750
}
737751

@@ -765,6 +779,9 @@ class MODEL_TENSOR(IntEnum):
765779
MODEL_TENSOR.ROPE_FREQS,
766780
MODEL_TENSOR.ATTN_ROT_EMBD,
767781
],
782+
MODEL_ARCH.CHATGLM: [
783+
MODEL_TENSOR.ROPE_FREQS,
784+
],
768785
}
769786

770787
#

gguf-py/gguf/tensor_mapping.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class TensorNameMap:
2424
"backbone.embedding", # mamba
2525
"backbone.embeddings", # mamba-hf
2626
"transformer.in_out_embed", # Grok
27+
"embedding.word_embeddings", # chatglm
2728
),
2829

2930
# Token type embeddings
@@ -52,6 +53,7 @@ class TensorNameMap:
5253
"output", # llama-pth bloom internlm2
5354
"word_embeddings_for_head", # persimmon
5455
"lm_head.linear", # phi2
56+
"output_layer", # chatglm
5557
),
5658

5759
# Output norm
@@ -68,11 +70,13 @@ class TensorNameMap:
6870
"model.norm_f", # mamba-qbert
6971
"backbone.norm_f", # mamba
7072
"transformer.rms_norm", # Grok
73+
"encoder.final_layernorm", # chatglm
7174
),
7275

7376
# Rope frequencies
7477
MODEL_TENSOR.ROPE_FREQS: (
7578
"rope.freqs", # llama-pth
79+
"rotary_pos_emb.inv_freq", # chatglm
7680
),
7781
}
7882

@@ -97,6 +101,7 @@ class TensorNameMap:
97101
"backbone.layers.{bid}.norm", # mamba
98102
"transformer.decoder_layer.{bid}.rms_norm", # Grok
99103
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
104+
"encoder.layers.{bid}.input_layernorm", # chatglm
100105
),
101106

102107
# Attention norm 2
@@ -117,7 +122,8 @@ class TensorNameMap:
117122
"h.{bid}.attn.c_attn", # gpt2
118123
"transformer.h.{bid}.mixer.Wqkv", # phi2
119124
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
120-
"model.layers.{bid}.self_attn.qkv_proj" # phi3
125+
"model.layers.{bid}.self_attn.qkv_proj", # phi3
126+
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
121127
),
122128

123129
# Attention query
@@ -128,7 +134,7 @@ class TensorNameMap:
128134
"transformer.h.{bid}.attn.q_proj", # gpt-j
129135
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
130136
"model.layers.{bid}.attention.wq", # internlm2
131-
"transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok
137+
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
132138
),
133139

134140
# Attention key
@@ -140,7 +146,7 @@ class TensorNameMap:
140146
"transformer.h.{bid}.attn.k", # refact
141147
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
142148
"model.layers.{bid}.attention.wk", # internlm2
143-
"transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
149+
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
144150
),
145151

146152
# Attention value
@@ -152,7 +158,7 @@ class TensorNameMap:
152158
"transformer.h.{bid}.attn.v", # refact
153159
"model.layers.layers.{bid}.self_attn.v_proj", # plamo
154160
"model.layers.{bid}.attention.wv", # internlm2
155-
"transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
161+
"transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
156162
),
157163

158164
# Attention output
@@ -175,6 +181,7 @@ class TensorNameMap:
175181
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
176182
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
177183
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
184+
"encoder.layers.{bid}.self_attention.dense", # chatglm
178185
),
179186

180187
# Attention output norm
@@ -206,6 +213,7 @@ class TensorNameMap:
206213
"h.{bid}.ln_2", # gpt2
207214
"model.layers.{bid}.ffn_norm", # internlm2
208215
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
216+
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
209217
),
210218

211219
MODEL_TENSOR.FFN_GATE_INP: (
@@ -244,6 +252,7 @@ class TensorNameMap:
244252
"encoder.layers.{bid}.mlp.fc11", # nomic-bert
245253
"model.layers.{bid}.mlp.c_fc", # starcoder2
246254
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
255+
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
247256
),
248257

249258
MODEL_TENSOR.FFN_UP_EXP: (
@@ -306,6 +315,7 @@ class TensorNameMap:
306315
"encoder.layers.{bid}.mlp.fc2", # nomic-bert
307316
"model.layers.{bid}.mlp.c_proj", # starcoder2
308317
"encoder.layer.{bid}.mlp.wo", # jina-bert-v2
318+
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
309319
),
310320

311321
MODEL_TENSOR.FFN_DOWN_EXP: (

gguf-py/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "gguf"
3-
version = "0.9.0"
3+
version = "0.9.1"
44
description = "Read and write ML models in GGUF for GGML"
55
authors = ["GGML <ggml@ggml.ai>"]
66
packages = [

0 commit comments

Comments
 (0)