@@ -79,7 +79,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         if not self.is_safetensors:
             self.part_names = Model.get_model_part_names(self.dir_model, ".bin")
         self.hparams = Model.load_hparams(self.dir_model)
-        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
         if self.ftype == gguf.LlamaFileType.GUESSED:
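
The only functional change in this first hunk is the extra `"num_layers"` key: ChatGLM's `config.json` stores its layer count under that name rather than `num_hidden_layers`, so the generic `Model.__init__` would otherwise fail to resolve `block_count`. As a minimal sketch of the behaviour being extended (assuming `find_hparam` is a simple first-match lookup over the loaded `config.json`; `find_hparam_sketch` and the example values below are illustrative, not the real helper):

```python
from typing import Any


def find_hparam_sketch(hparams: dict[str, Any], keys: list[str]) -> Any:
    # Return the value of the first candidate key present in config.json.
    for key in keys:
        if key in hparams:
            return hparams[key]
    raise KeyError(f"none of {keys} found in hparams")


# A ChatGLM-style config only exposes "num_layers", so the extended key list is needed:
chatglm_hparams = {"num_layers": 28, "hidden_size": 4096}
assert find_hparam_sketch(chatglm_hparams, ["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) == 28
```
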
@@ -2620,6 +2620,167 @@ def write_tensors(self):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@Model.register("ChatGLMModel")
+class ChatGLMModel(Model):
+    model_arch = gguf.MODEL_ARCH.CHATGLM
+
+    def set_vocab(self):
+        dir_model = self.dir_model
+        hparams = self.hparams
+        tokens: list[bytearray] = []
+        toktypes: list[int] = []
+        scores: list[float] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+        vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
+        assert max(tokenizer.get_vocab().values()) < vocab_size
+
+        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.get_vocab().items()}
+
+        for token_id in range(vocab_size):
+            piece = tokenizer._convert_id_to_token(token_id)
+            if token_id == 0:
+                piece = "<unk>"
+            elif token_id == 1:
+                piece = "<bos>"
+            elif token_id == 2:
+                piece = "<eos>"
+
+            text = piece.encode("utf-8")
+            score = 0.0
+            if len(piece) != 0 and token_id < 64789:
+                score = tokenizer.tokenizer.sp_model.get_score(token_id)
+
+            if len(piece) == 0:
+                text = f"[PAD{token_id}]".encode("utf-8")
+
+            if token_id >= 64789:
+                toktype = SentencePieceTokenTypes.UNKNOWN
+                tokens.append(text)
+                scores.append(score)
+                toktypes.append(toktype)
+                continue
+
+            toktype = SentencePieceTokenTypes.NORMAL
+            if tokenizer.tokenizer.sp_model.is_unknown(token_id):
+                toktype = SentencePieceTokenTypes.UNKNOWN
+            elif tokenizer.tokenizer.sp_model.is_control(token_id):
+                toktype = SentencePieceTokenTypes.CONTROL
+            elif tokenizer.tokenizer.sp_model.is_unused(token_id):
+                toktype = SentencePieceTokenTypes.UNUSED
+            elif tokenizer.tokenizer.sp_model.is_byte(token_id):
+                toktype = SentencePieceTokenTypes.BYTE
+
+            tokens.append(text)
+            scores.append(score)
+            toktypes.append(toktype)
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_name("ChatGLM-6b-chat")
+        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
+        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
+        n_head_kv = self.hparams.get("multi_query_group_num", n_head)
+        self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
+        self.gguf_writer.add_embedding_length(n_embed)
+        self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", 4 * n_embed))
+        self.gguf_writer.add_block_count(self.hparams["num_layers"])
+        self.gguf_writer.add_head_count(n_head)
+        self.gguf_writer.add_head_count_kv(n_head_kv)
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layernorm_epsilon"])
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_rope_dimension_count(64)
+        self.gguf_writer.add_add_bos_token(False)
+
+    def write_tensors(self):
+        block_count = self.hparams["num_layers"]
+        tensors = dict(self.get_tensors())
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        has_lm_head = True
+        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
+        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
+
+        for name, data_torch in tensors.items():
+            if name.endswith(".rotary_pos_emb.inv_freq"):
+                continue
+
+            if "lm_head.weight" not in tensors.keys() and "output.weight" not in tensors.keys():
+                has_lm_head = False
+
+            name = re.sub(r'transformer\.', '', name)
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name):
+                # Map bloom-style qkv_linear to gpt-style qkv_linear
+                # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252  # noqa
+                # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312  # noqa
+                qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed))
+                data = np.concatenate(
+                    (
+                        qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
+                        qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
+                        qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
+                    ),
+                    axis=0,
+                )
+                print("re-format attention.linear_qkv.weight")
+            elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
+                qkv_bias = data.reshape((n_head, 3, n_embed // n_head))
+                data = np.concatenate(
+                    (
+                        qkv_bias[:, 0, :].reshape((n_embed,)),
+                        qkv_bias[:, 1, :].reshape((n_embed,)),
+                        qkv_bias[:, 2, :].reshape((n_embed,)),
+                    ),
+                    axis=0,
+                )
+                print("re-format attention.linear_qkv.bias")
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+            if not has_lm_head and name == "word_embeddings.weight":
+                self.gguf_writer.add_tensor("output.weight", data)
+                print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
+
+
 
 ###### CONVERSION LOGIC ######
 
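
For context on the qkv re-format in `ChatGLMModel.write_tensors`: the fused `query_key_value` tensor is treated as bloom-style, with each head's Q, K and V rows stored back to back, and the conversion regroups the rows into contiguous Q, then K, then V blocks as in the gpt-style layout. A toy NumPy illustration with made-up dimensions (n_head=2, n_embed=6; not real model sizes, and it mirrors only the weight branch of the loop):

```python
import numpy as np

n_head, n_embed = 2, 6
head_dim = n_embed // n_head

# Fused bloom-style layout: for each head, its Q, K and V rows are adjacent.
fused = np.arange(3 * n_embed * n_embed, dtype=np.float32).reshape((3 * n_embed, n_embed))

qkv = fused.reshape((n_head, 3, head_dim, n_embed))
regrouped = np.concatenate(
    (
        qkv[:, 0, :, :].reshape((-1, n_embed)),  # all Q rows first
        qkv[:, 1, :, :].reshape((-1, n_embed)),  # then all K rows
        qkv[:, 2, :, :].reshape((-1, n_embed)),  # then all V rows
    ),
    axis=0,
)
assert regrouped.shape == fused.shape  # same rows, reordered into Q | K | V blocks
```
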