Skip to content

Commit 9e4968c

Browse files
authored
Add special token modification capability
To be able to fix or amend special tokens in a GGUF file, this adds two new arguments: `--special-token <name> <value>`, where `<name>` can be bos, eos, prefix, middle, etc., and `<value>` is the token text, e.g. `"<|fim▁begin|>"`; and `--special-token-by-id <name> <id>`, where `<id>` is the numeric ID of the token, e.g. 32006. For example, to add fill-in-middle tokens to a GGUF you would run: ```bash python3 gguf-new-metadata.py input.gguf output.gguf --special-token prefix "<|fim▁begin|>" --special-token middle "<|fim▁hole|>" --special-token suffix "<|fim▁end|>" ```
1 parent 0e4802b commit 9e4968c

File tree

1 file changed

+60
-14
lines changed

1 file changed

+60
-14
lines changed

gguf-py/scripts/gguf-new-metadata.py

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import numpy as np
1010
from typing import Any, Mapping, Sequence
11+
from dataclasses import dataclass
1112

1213
# Necessary to load the local gguf package
1314
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
@@ -18,6 +19,12 @@
1819
logger = logging.getLogger("gguf-new-metadata")
1920

2021

22+
@dataclass
class MetadataDetails:
    """One metadata entry: its GGUF value type, the value itself, and an
    optional human-readable note that is appended to log output."""
    type: gguf.GGUFValueType  # value type tag passed to GGUFWriter.add_val()
    value: Any                # the metadata value to write
    description: str = ''     # extra context shown in debug messages (e.g. '= <token>')
27+
2128
def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian:
2229
if np.uint32(1) == np.uint32(1).newbyteorder("<"):
2330
# Host is little endian
@@ -59,7 +66,16 @@ def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
5966
return decode_field(field)
6067

6168

62-
def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: Mapping[str, str], remove_metadata: Sequence[str]) -> None:
69+
def find_token(token_list: Sequence[str], token: str) -> Sequence[int]:
    """Return the indices of every occurrence of ``token`` in ``token_list``.

    Parameters:
        token_list: the model's token strings, indexed by token ID.
        token: the token text to search for.

    Raises:
        LookupError: if ``token`` does not appear in ``token_list`` at all.
    """
    # NOTE: the annotation previously said Sequence[int], but the values are
    # compared against a str token, so the list must contain strings.
    token_ids = [index for index, value in enumerate(token_list) if value == token]

    if not token_ids:
        raise LookupError(f'Unable to find "{token}" in token list!')

    return token_ids
76+
77+
78+
def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: Mapping[str, MetadataDetails], remove_metadata: Sequence[str]) -> None:
6379
for field in reader.fields.values():
6480
# Suppress virtual fields and fields written by GGUFWriter
6581
if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
@@ -75,29 +91,28 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
7591
logger.debug(f'Removing {field.name}')
7692
continue
7793

78-
old_val = decode_field(field)
94+
old_val = MetadataDetails(field.types[0], decode_field(field))
7995
val = new_metadata.get(field.name, old_val)
8096

8197
if field.name in new_metadata:
82-
logger.debug(f'Modifying {field.name}: "{old_val}" -> "{val}"')
98+
logger.debug(f'Modifying {field.name}: "{old_val.value}" -> "{val.value}" {val.description}')
8399
del new_metadata[field.name]
84-
elif val is not None:
100+
elif val.value is not None:
85101
logger.debug(f'Copying {field.name}')
86102

87-
if val is not None:
103+
if val.value is not None:
88104
writer.add_key(field.name)
89-
writer.add_val(val, field.types[0])
105+
writer.add_val(val.value, val.type)
90106

91107
if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
92108
logger.debug('Adding chat template(s)')
93-
writer.add_chat_template(new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE])
109+
writer.add_chat_template(new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE].value)
94110
del new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE]
95111

96-
# TODO: Support other types than string?
97112
for key, val in new_metadata.items():
98-
logger.debug(f'Adding {key}: {val}')
113+
logger.debug(f'Adding {key}: "{val.value}" {val.description}')
99114
writer.add_key(key)
100-
writer.add_val(val, gguf.GGUFValueType.STRING)
115+
writer.add_val(val.value, val.type)
101116

102117
for tensor in reader.tensors:
103118
# Dimensions are written in reverse order, so flip them first
@@ -115,6 +130,9 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
115130

116131

117132
def main() -> None:
133+
# Map short special-token names (e.g. 'bos', 'eos', 'prefix') to their full
# metadata key names (the gguf.Keys.Tokenizer attributes ending in '_token_id').
tokenizer_metadata = (getattr(gguf.Keys.Tokenizer, n) for n in vars(gguf.Keys.Tokenizer) if not n.startswith('_'))
token_names = {n.split('.')[-1][:-len('_token_id')]: n for n in tokenizer_metadata if n.endswith('_token_id')}
135+
118136
parser = argparse.ArgumentParser(description="Make a copy of a GGUF file with new metadata")
119137
parser.add_argument("input", type=Path, help="GGUF format model input filename")
120138
parser.add_argument("output", type=Path, help="GGUF format model output filename")
@@ -123,6 +141,8 @@ def main() -> None:
123141
parser.add_argument("--chat-template", type=str, help="Chat template string (or JSON string containing templates)")
124142
parser.add_argument("--chat-template-config", type=Path, help="Config file (tokenizer_config.json) containing chat template(s)")
125143
parser.add_argument("--remove-metadata", action="append", type=str, help="Remove metadata (by key name) from output model")
144+
parser.add_argument("--special-token", action="append", type=str, help="Special token by value", nargs=2, metavar=(' | '.join(token_names.keys()), '"<token>"'))
145+
parser.add_argument("--special-token-by-id", action="append", type=str, help="Special token by id", nargs=2, metavar=(' | '.join(token_names.keys()), '0'))
126146
parser.add_argument("--force", action="store_true", help="Bypass warnings without confirmation")
127147
parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
128148
args = parser.parse_args(None if len(sys.argv) > 2 else ["--help"])
@@ -133,20 +153,20 @@ def main() -> None:
133153
remove_metadata = args.remove_metadata or []
134154

135155
if args.general_name:
136-
new_metadata[gguf.Keys.General.NAME] = args.general_name
156+
new_metadata[gguf.Keys.General.NAME] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_name)
137157

138158
if args.general_description:
139-
new_metadata[gguf.Keys.General.DESCRIPTION] = args.general_description
159+
new_metadata[gguf.Keys.General.DESCRIPTION] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_description)
140160

141161
if args.chat_template:
142-
new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template
162+
new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template)
143163

144164
if args.chat_template_config:
145165
with open(args.chat_template_config, 'r') as fp:
146166
config = json.load(fp)
147167
template = config.get('chat_template')
148168
if template:
149-
new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = template
169+
new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template)
150170

151171
if remove_metadata:
152172
logger.warning('*** Warning *** Warning *** Warning **')
@@ -166,6 +186,32 @@ def main() -> None:
166186
arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE)
167187
endianess = get_byteorder(reader)
168188

189+
token_list = get_field_data(reader, gguf.Keys.Tokenizer.LIST) or []
190+
191+
# Resolve --special-token <name> <value> pairs: look the token text up in the
# model's token list and record the resulting ID as new metadata.
for name, token in args.special_token or []:
    if name not in token_names:
        logger.warning(f'Unknown special token "{name}", ignoring...')
    else:
        # find_token() raises LookupError if the token text is absent, which
        # aborts with a clear message rather than writing a bogus ID.
        ids = find_token(token_list, token)
        new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, ids[0], f'= {token}')

        if len(ids) > 1:
            logger.warning(f'Multiple "{token}" tokens found, choosing ID {ids[0]}, use --special-token-by-id if you want another:')
            # BUG FIX: str.join() requires strings; `ids` holds ints, so the
            # original ', '.join(ids) raised TypeError — stringify first.
            logger.warning(', '.join(str(token_id) for token_id in ids))
201+
202+
# Resolve --special-token-by-id <name> <id> pairs: validate the name and the
# numeric ID, then record the metadata update alongside the token it names.
for name, id_string in args.special_token_by_id or []:
    if name not in token_names:
        logger.warning(f'Unknown special token "{name}", ignoring...')
        continue
    if not id_string.isdecimal():
        logger.warning(f'Token ID "{id_string}" is not a valid ID, ignoring...')
        continue

    token_id = int(id_string)

    if not (0 <= token_id < len(token_list)):
        logger.warning(f'Token ID {token_id} is not within token list, ignoring...')
        continue

    new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, token_id, f'= {token_list[token_id]}')
214+
169215
if os.path.isfile(args.output) and not args.force:
170216
logger.warning('*** Warning *** Warning *** Warning **')
171217
logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!')

0 commit comments

Comments
 (0)