@@ -167,81 +167,81 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool:
167
167
tokenizer_config ['bos_token' ] = special_bos = special_cls
168
168
if not special_eos and special_sep and tokenizer_config :
169
169
tokenizer_config ['eos_token' ] = special_eos = special_sep
170
- post_processor = tokenizer .get ('post_processor' , {})
171
- for processor in post_processor .get ('processors' , [post_processor ]):
172
- if processor .get ('type' ) == 'RobertaProcessing' :
173
- self .add_special_token ['bos' ] = True
174
- self .add_special_token ['eos' ] = True
175
- self .add_special_token ['sep' ] = True
176
- if not special_cls and tokenizer_config :
177
- special_cls = processor .get ('cls' , [special_bos ])[0 ]
178
- tokenizer_config ['cls_token' ] = special_cls
179
- if not special_sep and tokenizer_config :
180
- special_sep = processor .get ('sep' , [special_eos ])[0 ]
181
- tokenizer_config ['sep_token' ] = special_sep
182
- continue
183
- # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
184
- # Only works with simple templates, **will** get it wrong on unusual sequences
185
- if processor .get ('type' ) == 'TemplateProcessing' :
186
- tmpl_single = processor .get ('single' , [])
187
- tmpl_pair = processor .get ('pair' , [])
188
- special_first = None
189
- special_last = None
190
- if len (tmpl_single ) > 1 :
191
- if special_first := tmpl_single [0 ].get ('SpecialToken' , {}).get ('id' ):
192
- if not tokenizer_config :
193
- special_bos = special_first
194
- self .add_special_token ['bos' ] = True if special_first in (special_bos , special_cls ) else False
195
- if special_first not in (special_bos , special_cls ):
196
- logger .warning (f'Unknown leading special token { special_first !r} in TemplateProcessing<single>' )
197
- if special_last := tmpl_single [- 1 ].get ('SpecialToken' , {}).get ('id' ):
198
- if not tokenizer_config :
199
- special_eos = special_last
200
- elif special_last != special_eos :
201
- if 'eot' not in self .special_token_types :
202
- self .special_token_types = tuple (self .special_token_types ) + ('eot' , )
203
- tokenizer_config ['eot_token' ] = special_eos
204
- elif 'eom' not in self .special_token_types :
205
- self .special_token_types = tuple (self .special_token_types ) + ('eom' , )
206
- tokenizer_config ['eom_token' ] = special_eos
207
- else :
208
- logger .warning (f'Overriding EOS token { special_eos !r} with { special_last !r} without EOT/EOM fallback!' )
209
- tokenizer_config ['eos_token' ] = special_eos = special_last
210
- self .add_special_token ['eos' ] = True if special_last == special_eos else False
211
- if special_last != special_eos :
212
- logger .warning (f'Unknown trailing special token { special_last !r} in TemplateProcessing<single>' )
213
- if tmpl_pair :
214
- seq_start = 1 if special_first and tmpl_pair [0 ].get ('SpecialToken' , {}).get ('id' ) == special_first else 0
215
- seq_stop = - 1 if special_last and tmpl_pair [- 1 ].get ('SpecialToken' , {}).get ('id' ) == special_last else None
216
- if (special_first and seq_start == 0 ) or (special_last and seq_stop is None ):
217
- logger .warning ('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>' )
218
- if tmpl_pair := tmpl_pair [slice (seq_start , seq_stop )]:
219
- tmpl_a = tmpl_pair [0 ].get ('Sequence' , {}).get ('id' )
220
- tmpl_b = tmpl_pair [- 1 ].get ('Sequence' , {}).get ('id' )
221
- if tmpl_a != 'A' or tmpl_b != 'B' :
222
- logger .warning (f'Unknown sequence { tmpl_a } ...{ tmpl_b } in TemplateProcessing<pair>' )
223
- # A [sep] [eos] B
224
- if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair [1 :- 1 ]):
225
- add_sep = False
226
- if special_entry := tmpl_pair [0 ].get ('SpecialToken' , {}).get ('id' ):
227
- if special_entry in (special_sep , special_eos ) and not special_last :
228
- add_sep = True
229
- if special_entry not in (special_sep , special_eos ):
230
- logger .warning (f'Unknown separator token { special_entry !r} in TemplateProcessing<pair>' )
231
- else :
232
- logger .warning (f'Unknown middle sequence { tmpl_pair [0 ]!r} in TemplateProcessing<pair>' )
233
- if len (tmpl_pair ) == 2 :
234
- if special_entry := tmpl_pair [1 ].get ('SpecialToken' , {}).get ('id' ):
235
- if special_entry in (special_sep , special_eos ):
170
+ if post_processor := tokenizer .get ('post_processor' ):
171
+ for processor in post_processor .get ('processors' , [post_processor ]):
172
+ if processor .get ('type' ) == 'RobertaProcessing' :
173
+ self .add_special_token ['bos' ] = True
174
+ self .add_special_token ['eos' ] = True
175
+ self .add_special_token ['sep' ] = True
176
+ if not special_cls and tokenizer_config :
177
+ special_cls = processor .get ('cls' , [special_bos ])[0 ]
178
+ tokenizer_config ['cls_token' ] = special_cls
179
+ if not special_sep and tokenizer_config :
180
+ special_sep = processor .get ('sep' , [special_eos ])[0 ]
181
+ tokenizer_config ['sep_token' ] = special_sep
182
+ continue
183
+ # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
184
+ # Only works with simple templates, **will** get it wrong on unusual sequences
185
+ if processor .get ('type' ) == 'TemplateProcessing' :
186
+ tmpl_single = processor .get ('single' , [])
187
+ tmpl_pair = processor .get ('pair' , [])
188
+ special_first = None
189
+ special_last = None
190
+ if len (tmpl_single ) > 1 :
191
+ if special_first := tmpl_single [0 ].get ('SpecialToken' , {}).get ('id' ):
192
+ if not tokenizer_config :
193
+ special_bos = special_first
194
+ self .add_special_token ['bos' ] = True if special_first in (special_bos , special_cls ) else False
195
+ if special_first not in (special_bos , special_cls ):
196
+ logger .warning (f'Unknown leading special token { special_first !r} in TemplateProcessing<single>' )
197
+ if special_last := tmpl_single [- 1 ].get ('SpecialToken' , {}).get ('id' ):
198
+ if not tokenizer_config :
199
+ special_eos = special_last
200
+ elif special_last != special_eos :
201
+ if 'eot' not in self .special_token_types :
202
+ self .special_token_types = tuple (self .special_token_types ) + ('eot' , )
203
+ tokenizer_config ['eot_token' ] = special_eos
204
+ elif 'eom' not in self .special_token_types :
205
+ self .special_token_types = tuple (self .special_token_types ) + ('eom' , )
206
+ tokenizer_config ['eom_token' ] = special_eos
207
+ else :
208
+ logger .warning (f'Overriding EOS token { special_eos !r} with { special_last !r} without EOT/EOM fallback!' )
209
+ tokenizer_config ['eos_token' ] = special_eos = special_last
210
+ self .add_special_token ['eos' ] = True if special_last == special_eos else False
211
+ if special_last != special_eos :
212
+ logger .warning (f'Unknown trailing special token { special_last !r} in TemplateProcessing<single>' )
213
+ if tmpl_pair :
214
+ seq_start = 1 if special_first and tmpl_pair [0 ].get ('SpecialToken' , {}).get ('id' ) == special_first else 0
215
+ seq_stop = - 1 if special_last and tmpl_pair [- 1 ].get ('SpecialToken' , {}).get ('id' ) == special_last else None
216
+ if (special_first and seq_start == 0 ) or (special_last and seq_stop is None ):
217
+ logger .warning ('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>' )
218
+ if tmpl_pair := tmpl_pair [slice (seq_start , seq_stop )]:
219
+ tmpl_a = tmpl_pair [0 ].get ('Sequence' , {}).get ('id' )
220
+ tmpl_b = tmpl_pair [- 1 ].get ('Sequence' , {}).get ('id' )
221
+ if tmpl_a != 'A' or tmpl_b != 'B' :
222
+ logger .warning (f'Unknown sequence { tmpl_a } ...{ tmpl_b } in TemplateProcessing<pair>' )
223
+ # A [sep] [eos] B
224
+ if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair [1 :- 1 ]):
225
+ add_sep = False
226
+ if special_entry := tmpl_pair [0 ].get ('SpecialToken' , {}).get ('id' ):
227
+ if special_entry in (special_sep , special_eos ) and not special_last :
236
228
add_sep = True
237
229
if special_entry not in (special_sep , special_eos ):
238
- logger .warning (f'Unknown second separator token { special_entry !r} in TemplateProcessing<pair>' )
230
+ logger .warning (f'Unknown separator token { special_entry !r} in TemplateProcessing<pair>' )
239
231
else :
240
- logger .warning (f'Unknown second middle sequence { tmpl_pair [1 ]!r} in TemplateProcessing<pair>' )
241
- self .add_special_token ['sep' ] = add_sep
242
- if add_sep and not special_sep and tokenizer_config :
243
- tokenizer_config ['sep_token' ] = special_eos
244
- continue
232
+ logger .warning (f'Unknown middle sequence { tmpl_pair [0 ]!r} in TemplateProcessing<pair>' )
233
+ if len (tmpl_pair ) == 2 :
234
+ if special_entry := tmpl_pair [1 ].get ('SpecialToken' , {}).get ('id' ):
235
+ if special_entry in (special_sep , special_eos ):
236
+ add_sep = True
237
+ if special_entry not in (special_sep , special_eos ):
238
+ logger .warning (f'Unknown second separator token { special_entry !r} in TemplateProcessing<pair>' )
239
+ else :
240
+ logger .warning (f'Unknown second middle sequence { tmpl_pair [1 ]!r} in TemplateProcessing<pair>' )
241
+ self .add_special_token ['sep' ] = add_sep
242
+ if add_sep and not special_sep and tokenizer_config :
243
+ tokenizer_config ['sep_token' ] = special_eos
244
+ continue
245
245
if not tokenizer_config :
246
246
return True
247
247
chat_template_alt = None
0 commit comments