Skip to content

Commit 238005c

Browse files
authored
gguf-py : fix SpecialVocab parsing when post_processor is null (#14330)
1 parent: 66aba7a · commit: 238005c

File tree

1 file changed

+72
-72
lines changed

1 file changed

+72
-72
lines changed

gguf-py/gguf/vocab.py

Lines changed: 72 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -167,81 +167,81 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool:
167167
tokenizer_config['bos_token'] = special_bos = special_cls
168168
if not special_eos and special_sep and tokenizer_config:
169169
tokenizer_config['eos_token'] = special_eos = special_sep
170-
post_processor = tokenizer.get('post_processor', {})
171-
for processor in post_processor.get('processors', [post_processor]):
172-
if processor.get('type') == 'RobertaProcessing':
173-
self.add_special_token['bos'] = True
174-
self.add_special_token['eos'] = True
175-
self.add_special_token['sep'] = True
176-
if not special_cls and tokenizer_config:
177-
special_cls = processor.get('cls', [special_bos])[0]
178-
tokenizer_config['cls_token'] = special_cls
179-
if not special_sep and tokenizer_config:
180-
special_sep = processor.get('sep', [special_eos])[0]
181-
tokenizer_config['sep_token'] = special_sep
182-
continue
183-
# Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
184-
# Only works with simple templates, **will** get it wrong on unusual sequences
185-
if processor.get('type') == 'TemplateProcessing':
186-
tmpl_single = processor.get('single', [])
187-
tmpl_pair = processor.get('pair', [])
188-
special_first = None
189-
special_last = None
190-
if len(tmpl_single) > 1:
191-
if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
192-
if not tokenizer_config:
193-
special_bos = special_first
194-
self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
195-
if special_first not in (special_bos, special_cls):
196-
logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
197-
if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
198-
if not tokenizer_config:
199-
special_eos = special_last
200-
elif special_last != special_eos:
201-
if 'eot' not in self.special_token_types:
202-
self.special_token_types = tuple(self.special_token_types) + ('eot', )
203-
tokenizer_config['eot_token'] = special_eos
204-
elif 'eom' not in self.special_token_types:
205-
self.special_token_types = tuple(self.special_token_types) + ('eom', )
206-
tokenizer_config['eom_token'] = special_eos
207-
else:
208-
logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
209-
tokenizer_config['eos_token'] = special_eos = special_last
210-
self.add_special_token['eos'] = True if special_last == special_eos else False
211-
if special_last != special_eos:
212-
logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
213-
if tmpl_pair:
214-
seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
215-
seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
216-
if (special_first and seq_start == 0) or (special_last and seq_stop is None):
217-
logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
218-
if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
219-
tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
220-
tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
221-
if tmpl_a != 'A' or tmpl_b != 'B':
222-
logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
223-
# A [sep] [eos] B
224-
if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
225-
add_sep = False
226-
if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
227-
if special_entry in (special_sep, special_eos) and not special_last:
228-
add_sep = True
229-
if special_entry not in (special_sep, special_eos):
230-
logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
231-
else:
232-
logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
233-
if len(tmpl_pair) == 2:
234-
if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
235-
if special_entry in (special_sep, special_eos):
170+
if post_processor := tokenizer.get('post_processor'):
171+
for processor in post_processor.get('processors', [post_processor]):
172+
if processor.get('type') == 'RobertaProcessing':
173+
self.add_special_token['bos'] = True
174+
self.add_special_token['eos'] = True
175+
self.add_special_token['sep'] = True
176+
if not special_cls and tokenizer_config:
177+
special_cls = processor.get('cls', [special_bos])[0]
178+
tokenizer_config['cls_token'] = special_cls
179+
if not special_sep and tokenizer_config:
180+
special_sep = processor.get('sep', [special_eos])[0]
181+
tokenizer_config['sep_token'] = special_sep
182+
continue
183+
# Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
184+
# Only works with simple templates, **will** get it wrong on unusual sequences
185+
if processor.get('type') == 'TemplateProcessing':
186+
tmpl_single = processor.get('single', [])
187+
tmpl_pair = processor.get('pair', [])
188+
special_first = None
189+
special_last = None
190+
if len(tmpl_single) > 1:
191+
if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
192+
if not tokenizer_config:
193+
special_bos = special_first
194+
self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
195+
if special_first not in (special_bos, special_cls):
196+
logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
197+
if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
198+
if not tokenizer_config:
199+
special_eos = special_last
200+
elif special_last != special_eos:
201+
if 'eot' not in self.special_token_types:
202+
self.special_token_types = tuple(self.special_token_types) + ('eot', )
203+
tokenizer_config['eot_token'] = special_eos
204+
elif 'eom' not in self.special_token_types:
205+
self.special_token_types = tuple(self.special_token_types) + ('eom', )
206+
tokenizer_config['eom_token'] = special_eos
207+
else:
208+
logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
209+
tokenizer_config['eos_token'] = special_eos = special_last
210+
self.add_special_token['eos'] = True if special_last == special_eos else False
211+
if special_last != special_eos:
212+
logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
213+
if tmpl_pair:
214+
seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
215+
seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
216+
if (special_first and seq_start == 0) or (special_last and seq_stop is None):
217+
logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
218+
if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
219+
tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
220+
tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
221+
if tmpl_a != 'A' or tmpl_b != 'B':
222+
logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
223+
# A [sep] [eos] B
224+
if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
225+
add_sep = False
226+
if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
227+
if special_entry in (special_sep, special_eos) and not special_last:
236228
add_sep = True
237229
if special_entry not in (special_sep, special_eos):
238-
logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
230+
logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
239231
else:
240-
logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
241-
self.add_special_token['sep'] = add_sep
242-
if add_sep and not special_sep and tokenizer_config:
243-
tokenizer_config['sep_token'] = special_eos
244-
continue
232+
logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
233+
if len(tmpl_pair) == 2:
234+
if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
235+
if special_entry in (special_sep, special_eos):
236+
add_sep = True
237+
if special_entry not in (special_sep, special_eos):
238+
logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
239+
else:
240+
logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
241+
self.add_special_token['sep'] = add_sep
242+
if add_sep and not special_sep and tokenizer_config:
243+
tokenizer_config['sep_token'] = special_eos
244+
continue
245245
if not tokenizer_config:
246246
return True
247247
chat_template_alt = None

0 commit comments

Comments
 (0)