8
8
9
9
import numpy as np
10
10
from typing import Any , Mapping , Sequence
11
+ from dataclasses import dataclass
11
12
12
13
# Necessary to load the local gguf package
13
14
if "NO_LOCAL_GGUF" not in os .environ and (Path (__file__ ).parent .parent .parent / 'gguf-py' ).exists ():
18
19
logger = logging .getLogger ("gguf-new-metadata" )
19
20
20
21
22
@dataclass
class MetadataDetails:
    """Typed wrapper for a single GGUF metadata value.

    Carries the GGUF wire type alongside the decoded value so the writer
    can round-trip non-string metadata with `writer.add_val(value, type)`,
    plus an optional human-readable note used only in log messages.
    """
    type: gguf.GGUFValueType  # GGUF value type to write (e.g. STRING, UINT32)
    value: Any                # decoded value; None means "do not write this key"
    description: str = ''     # optional annotation appended to debug/log output
27
+
21
28
def get_byteorder (reader : gguf .GGUFReader ) -> gguf .GGUFEndian :
22
29
if np .uint32 (1 ) == np .uint32 (1 ).newbyteorder ("<" ):
23
30
# Host is little endian
@@ -59,7 +66,16 @@ def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
59
66
return decode_field (field )
60
67
61
68
62
- def copy_with_new_metadata (reader : gguf .GGUFReader , writer : gguf .GGUFWriter , new_metadata : Mapping [str , str ], remove_metadata : Sequence [str ]) -> None :
69
def find_token(token_list: Sequence[str], token: str) -> Sequence[int]:
    """Return the indices of every occurrence of *token* in *token_list*.

    Args:
        token_list: The tokenizer's token strings (one entry per token ID).
        token: The token text to search for.

    Returns:
        All matching token IDs (there may be more than one, e.g. duplicate
        vocab entries); never empty.

    Raises:
        LookupError: If *token* does not occur in *token_list* at all.
    """
    # NOTE(review): the original annotated token_list as Sequence[int], but the
    # elements are compared against a str token — corrected to Sequence[str].
    token_ids = [index for index, value in enumerate(token_list) if value == token]

    if not token_ids:
        raise LookupError(f'Unable to find "{token}" in token list!')

    return token_ids
76
+
77
+
78
+ def copy_with_new_metadata (reader : gguf .GGUFReader , writer : gguf .GGUFWriter , new_metadata : Mapping [str , MetadataDetails ], remove_metadata : Sequence [str ]) -> None :
63
79
for field in reader .fields .values ():
64
80
# Suppress virtual fields and fields written by GGUFWriter
65
81
if field .name == gguf .Keys .General .ARCHITECTURE or field .name .startswith ('GGUF.' ):
@@ -75,29 +91,28 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
75
91
logger .debug (f'Removing { field .name } ' )
76
92
continue
77
93
78
- old_val = decode_field (field )
94
+ old_val = MetadataDetails ( field . types [ 0 ], decode_field (field ) )
79
95
val = new_metadata .get (field .name , old_val )
80
96
81
97
if field .name in new_metadata :
82
- logger .debug (f'Modifying { field .name } : "{ old_val } " -> "{ val } " ' )
98
+ logger .debug (f'Modifying { field .name } : "{ old_val . value } " -> "{ val . value } " { val . description } ' )
83
99
del new_metadata [field .name ]
84
- elif val is not None :
100
+ elif val . value is not None :
85
101
logger .debug (f'Copying { field .name } ' )
86
102
87
- if val is not None :
103
+ if val . value is not None :
88
104
writer .add_key (field .name )
89
- writer .add_val (val , field . types [ 0 ] )
105
+ writer .add_val (val . value , val . type )
90
106
91
107
if gguf .Keys .Tokenizer .CHAT_TEMPLATE in new_metadata :
92
108
logger .debug ('Adding chat template(s)' )
93
- writer .add_chat_template (new_metadata [gguf .Keys .Tokenizer .CHAT_TEMPLATE ])
109
+ writer .add_chat_template (new_metadata [gguf .Keys .Tokenizer .CHAT_TEMPLATE ]. value )
94
110
del new_metadata [gguf .Keys .Tokenizer .CHAT_TEMPLATE ]
95
111
96
- # TODO: Support other types than string?
97
112
for key , val in new_metadata .items ():
98
- logger .debug (f'Adding { key } : { val } ' )
113
+ logger .debug (f'Adding { key } : " { val . value } " { val . description } ' )
99
114
writer .add_key (key )
100
- writer .add_val (val , gguf . GGUFValueType . STRING )
115
+ writer .add_val (val . value , val . type )
101
116
102
117
for tensor in reader .tensors :
103
118
# Dimensions are written in reverse order, so flip them first
@@ -115,6 +130,9 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
115
130
116
131
117
132
def main () -> None :
133
+ tokenizer_metadata = (getattr (gguf .Keys .Tokenizer , n ) for n in gguf .Keys .Tokenizer .__dict__ .keys () if not n .startswith ('_' ))
134
+ token_names = dict ((n .split ('.' )[- 1 ][:- len ('_token_id' )], n ) for n in tokenizer_metadata if n .endswith ('_token_id' ))
135
+
118
136
parser = argparse .ArgumentParser (description = "Make a copy of a GGUF file with new metadata" )
119
137
parser .add_argument ("input" , type = Path , help = "GGUF format model input filename" )
120
138
parser .add_argument ("output" , type = Path , help = "GGUF format model output filename" )
@@ -123,6 +141,8 @@ def main() -> None:
123
141
parser .add_argument ("--chat-template" , type = str , help = "Chat template string (or JSON string containing templates)" )
124
142
parser .add_argument ("--chat-template-config" , type = Path , help = "Config file (tokenizer_config.json) containing chat template(s)" )
125
143
parser .add_argument ("--remove-metadata" , action = "append" , type = str , help = "Remove metadata (by key name) from output model" )
144
+ parser .add_argument ("--special-token" , action = "append" , type = str , help = "Special token by value" , nargs = 2 , metavar = (' | ' .join (token_names .keys ()), '"<token>"' ))
145
+ parser .add_argument ("--special-token-by-id" , action = "append" , type = str , help = "Special token by id" , nargs = 2 , metavar = (' | ' .join (token_names .keys ()), '0' ))
126
146
parser .add_argument ("--force" , action = "store_true" , help = "Bypass warnings without confirmation" )
127
147
parser .add_argument ("--verbose" , action = "store_true" , help = "Increase output verbosity" )
128
148
args = parser .parse_args (None if len (sys .argv ) > 2 else ["--help" ])
@@ -133,20 +153,20 @@ def main() -> None:
133
153
remove_metadata = args .remove_metadata or []
134
154
135
155
if args .general_name :
136
- new_metadata [gguf .Keys .General .NAME ] = args .general_name
156
+ new_metadata [gguf .Keys .General .NAME ] = MetadataDetails ( gguf . GGUFValueType . STRING , args .general_name )
137
157
138
158
if args .general_description :
139
- new_metadata [gguf .Keys .General .DESCRIPTION ] = args .general_description
159
+ new_metadata [gguf .Keys .General .DESCRIPTION ] = MetadataDetails ( gguf . GGUFValueType . STRING , args .general_description )
140
160
141
161
if args .chat_template :
142
- new_metadata [gguf .Keys .Tokenizer .CHAT_TEMPLATE ] = json .loads (args .chat_template ) if args .chat_template .startswith ('[' ) else args .chat_template
162
+ new_metadata [gguf .Keys .Tokenizer .CHAT_TEMPLATE ] = MetadataDetails ( gguf . GGUFValueType . STRING , json .loads (args .chat_template ) if args .chat_template .startswith ('[' ) else args .chat_template )
143
163
144
164
if args .chat_template_config :
145
165
with open (args .chat_template_config , 'r' ) as fp :
146
166
config = json .load (fp )
147
167
template = config .get ('chat_template' )
148
168
if template :
149
- new_metadata [gguf .Keys .Tokenizer .CHAT_TEMPLATE ] = template
169
+ new_metadata [gguf .Keys .Tokenizer .CHAT_TEMPLATE ] = MetadataDetails ( gguf . GGUFValueType . STRING , template )
150
170
151
171
if remove_metadata :
152
172
logger .warning ('*** Warning *** Warning *** Warning **' )
@@ -166,6 +186,32 @@ def main() -> None:
166
186
arch = get_field_data (reader , gguf .Keys .General .ARCHITECTURE )
167
187
endianess = get_byteorder (reader )
168
188
189
+ token_list = get_field_data (reader , gguf .Keys .Tokenizer .LIST ) or []
190
+
191
+ for name , token in args .special_token or []:
192
+ if name not in token_names :
193
+ logger .warning (f'Unknown special token "{ name } ", ignoring...' )
194
+ else :
195
+ ids = find_token (token_list , token )
196
+ new_metadata [token_names [name ]] = MetadataDetails (gguf .GGUFValueType .UINT32 , ids [0 ], f'= { token } ' )
197
+
198
+ if len (ids ) > 1 :
199
+ logger .warning (f'Multiple "{ token } " tokens found, choosing ID { ids [0 ]} , use --special-token-by-id if you want another:' )
200
+ logger .warning (', ' .join (ids ))
201
+
202
+ for name , id_string in args .special_token_by_id or []:
203
+ if name not in token_names :
204
+ logger .warning (f'Unknown special token "{ name } ", ignoring...' )
205
+ elif not id_string .isdecimal ():
206
+ logger .warning (f'Token ID "{ id_string } " is not a valid ID, ignoring...' )
207
+ else :
208
+ id_int = int (id_string )
209
+
210
+ if id_int >= 0 and id_int < len (token_list ):
211
+ new_metadata [token_names [name ]] = MetadataDetails (gguf .GGUFValueType .UINT32 , id_int , f'= { token_list [id_int ]} ' )
212
+ else :
213
+ logger .warning (f'Token ID { id_int } is not within token list, ignoring...' )
214
+
169
215
if os .path .isfile (args .output ) and not args .force :
170
216
logger .warning ('*** Warning *** Warning *** Warning **' )
171
217
logger .warning (f'* The "{ args .output } " GGUF file already exists, it will be overwritten!' )
0 commit comments