Skip to content

Commit dd1e26e

Browse files
CISCqnixsynapse
authored and committed
llama : improve sep token handling (ggml-org#14272)
1 parent 7982148 commit dd1e26e

File tree

8 files changed

+88
-89
lines changed

8 files changed

+88
-89
lines changed

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ struct common_params {
358358
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
359359
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
360360
std::string embd_sep = "\n"; // separator of embeddings
361+
std::string cls_sep = "\t"; // separator of classification sequences
361362

362363
// server params
363364
int32_t port = 8080; // server listens on this network port

convert_hf_to_gguf.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2145,7 +2145,6 @@ def __init__(self, *args, **kwargs):
21452145

21462146
def set_vocab(self):
21472147
self._set_vocab_gpt2()
2148-
self.gguf_writer.add_add_bos_token(True)
21492148

21502149
def set_gguf_parameters(self):
21512150
super().set_gguf_parameters()
@@ -3918,9 +3917,6 @@ def _xlmroberta_set_vocab(self) -> None:
39183917
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
39193918
special_vocab.add_to_gguf(self.gguf_writer)
39203919

3921-
self.gguf_writer.add_add_bos_token(True)
3922-
self.gguf_writer.add_add_eos_token(True)
3923-
39243920

39253921
@ModelBase.register("DistilBertModel", "DistilBertForMaskedLM", "DistilBertForSequenceClassification")
39263922
class DistilBertModel(BertModel):
@@ -3962,8 +3958,6 @@ def set_vocab(self):
39623958
bpe_tok_path = self.dir_model / "tokenizer.json"
39633959
if bpe_tok_path.exists():
39643960
self._set_vocab_gpt2()
3965-
self.gguf_writer.add_add_bos_token(True)
3966-
self.gguf_writer.add_add_eos_token(True)
39673961

39683962
# we need this to validate the size of the token_type embeddings
39693963
# though currently we are passing all zeros to the token_type embeddings
@@ -4848,8 +4842,6 @@ def set_vocab(self):
48484842
self.gguf_writer.add_token_type_count(2)
48494843
else:
48504844
raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
4851-
self.gguf_writer.add_add_bos_token(True)
4852-
self.gguf_writer.add_add_eos_token(True)
48534845

48544846

48554847
@ModelBase.register("OpenELMForCausalLM")
@@ -5451,9 +5443,6 @@ def set_vocab(self):
54515443
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
54525444
special_vocab.add_to_gguf(self.gguf_writer)
54535445

5454-
self.gguf_writer.add_add_bos_token(False)
5455-
self.gguf_writer.add_add_eos_token(True)
5456-
54575446
def set_gguf_parameters(self):
54585447
if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
54595448
logger.warning("Couldn't find context length in config.json, assuming default value of 512")
@@ -5591,9 +5580,6 @@ def set_vocab(self):
55915580
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
55925581
special_vocab.add_to_gguf(self.gguf_writer)
55935582

5594-
self.gguf_writer.add_add_bos_token(False)
5595-
self.gguf_writer.add_add_eos_token(True)
5596-
55975583
def set_gguf_parameters(self):
55985584
if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
55995585
logger.warning("Couldn't find context length in config.json, assuming default value of 512")

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ class Tokenizer:
198198
MASK_ID = "tokenizer.ggml.mask_token_id"
199199
ADD_BOS = "tokenizer.ggml.add_bos_token"
200200
ADD_EOS = "tokenizer.ggml.add_eos_token"
201+
ADD_SEP = "tokenizer.ggml.add_sep_token"
201202
ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
202203
REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces"
203204
PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap"

gguf-py/gguf/vocab.py

Lines changed: 62 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -167,81 +167,71 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool:
167167
tokenizer_config['bos_token'] = special_bos = special_cls
168168
if not special_eos and special_sep and tokenizer_config:
169169
tokenizer_config['eos_token'] = special_eos = special_sep
170-
if post_processor := tokenizer.get('post_processor'):
171-
for processor in post_processor.get('processors', [post_processor]):
172-
if processor.get('type') == 'RobertaProcessing':
173-
self.add_special_token['bos'] = True
174-
self.add_special_token['eos'] = True
175-
self.add_special_token['sep'] = True
176-
if not special_cls and tokenizer_config:
177-
special_cls = processor.get('cls', [special_bos])[0]
178-
tokenizer_config['cls_token'] = special_cls
179-
if not special_sep and tokenizer_config:
180-
special_sep = processor.get('sep', [special_eos])[0]
181-
tokenizer_config['sep_token'] = special_sep
182-
continue
183-
# Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
184-
# Only works with simple templates, **will** get it wrong on unusual sequences
185-
if processor.get('type') == 'TemplateProcessing':
186-
tmpl_single = processor.get('single', [])
187-
tmpl_pair = processor.get('pair', [])
188-
special_first = None
189-
special_last = None
190-
if len(tmpl_single) > 1:
191-
if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
192-
if not tokenizer_config:
193-
special_bos = special_first
194-
self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
195-
if special_first not in (special_bos, special_cls):
196-
logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
197-
if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
198-
if not tokenizer_config:
199-
special_eos = special_last
200-
elif special_last != special_eos:
201-
if 'eot' not in self.special_token_types:
202-
self.special_token_types = tuple(self.special_token_types) + ('eot', )
203-
tokenizer_config['eot_token'] = special_eos
204-
elif 'eom' not in self.special_token_types:
205-
self.special_token_types = tuple(self.special_token_types) + ('eom', )
206-
tokenizer_config['eom_token'] = special_eos
207-
else:
208-
logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
209-
tokenizer_config['eos_token'] = special_eos = special_last
210-
self.add_special_token['eos'] = True if special_last == special_eos else False
211-
if special_last != special_eos:
212-
logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
213-
if tmpl_pair:
214-
seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
215-
seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
216-
if (special_first and seq_start == 0) or (special_last and seq_stop is None):
217-
logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
218-
if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
219-
tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
220-
tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
221-
if tmpl_a != 'A' or tmpl_b != 'B':
222-
logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
223-
# A [sep] [eos] B
224-
if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
225-
add_sep = False
226-
if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
227-
if special_entry in (special_sep, special_eos) and not special_last:
170+
post_processor = tokenizer.get('post_processor', {})
171+
for processor in post_processor.get('processors', [post_processor]):
172+
if processor.get('type') == 'RobertaProcessing':
173+
self.add_special_token['bos'] = True
174+
self.add_special_token['eos'] = True
175+
self.add_special_token['sep'] = True
176+
if not special_cls and tokenizer_config:
177+
special_cls = processor.get('cls', [special_bos])[0]
178+
tokenizer_config['cls_token'] = special_cls
179+
if not special_sep and tokenizer_config:
180+
special_sep = processor.get('sep', [special_eos])[0]
181+
tokenizer_config['sep_token'] = special_sep
182+
continue
183+
# Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
184+
# Only works with simple templates, **will** get it wrong on unusual sequences
185+
if processor.get('type') == 'TemplateProcessing':
186+
tmpl_single = processor.get('single', [])
187+
tmpl_pair = processor.get('pair', [])
188+
special_first = None
189+
special_last = None
190+
if len(tmpl_single) > 1:
191+
if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
192+
if not tokenizer_config:
193+
special_bos = special_first
194+
self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
195+
if special_first not in (special_bos, special_cls):
196+
logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
197+
if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
198+
if not tokenizer_config:
199+
special_eos = special_last
200+
self.add_special_token['eos'] = True if special_last == special_eos else False
201+
if special_last != special_eos:
202+
logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
203+
if tmpl_pair:
204+
seq_start = 1 if tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
205+
seq_stop = -1 if tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
206+
if seq_start == 0 or seq_stop is None:
207+
logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
208+
if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
209+
tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
210+
tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
211+
if tmpl_a != 'A' or tmpl_b != 'B':
212+
logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
213+
# A [sep] [eos] B
214+
if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
215+
add_sep = False
216+
if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
217+
if special_entry in (special_sep, special_eos) and not special_last:
218+
add_sep = True
219+
if special_entry not in (special_sep, special_eos):
220+
logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
221+
else:
222+
logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
223+
if len(tmpl_pair) == 2:
224+
if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
225+
if special_entry in (special_sep, special_eos):
228226
add_sep = True
229227
if special_entry not in (special_sep, special_eos):
230-
logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
228+
logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
231229
else:
232-
logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
233-
if len(tmpl_pair) == 2:
234-
if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
235-
if special_entry in (special_sep, special_eos):
236-
add_sep = True
237-
if special_entry not in (special_sep, special_eos):
238-
logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
239-
else:
240-
logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
241-
self.add_special_token['sep'] = add_sep
242-
if add_sep and not special_sep and tokenizer_config:
243-
tokenizer_config['sep_token'] = special_eos
244-
continue
230+
logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
231+
self.add_special_token['sep'] = add_sep
232+
if add_sep and not special_sep and tokenizer_config:
233+
tokenizer_config['sep_token'] = special_eos
234+
continue
245235
if not tokenizer_config:
246236
return True
247237
chat_template_alt = None

include/llama.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,6 +1044,7 @@ extern "C" {
10441044

10451045
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
10461046
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
1047+
LLAMA_API bool llama_vocab_get_add_sep(const struct llama_vocab * vocab);
10471048

10481049
LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
10491050
LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
198198
{ LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
199199
{ LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
200200
{ LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
201+
{ LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" },
201202
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
202203
{ LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" },
203204
{ LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ enum llm_kv {
194194
LLM_KV_TOKENIZER_MASK_ID,
195195
LLM_KV_TOKENIZER_ADD_BOS,
196196
LLM_KV_TOKENIZER_ADD_EOS,
197+
LLM_KV_TOKENIZER_ADD_SEP,
197198
LLM_KV_TOKENIZER_ADD_PREFIX,
198199
LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
199200
LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,

src/llama-vocab.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,6 +1269,7 @@ struct llama_vocab::impl {
12691269
bool add_space_prefix = false;
12701270
bool add_bos = false;
12711271
bool add_eos = false;
1272+
bool add_sep = false;
12721273
bool ignore_merges = false;
12731274
bool clean_spaces = false; // clean_up_tokenization_spaces
12741275
bool remove_extra_whitespaces = false;
@@ -1421,6 +1422,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
14211422
special_sep_id = 102;
14221423
special_pad_id = 0;
14231424
special_mask_id = 103;
1425+
1426+
add_sep = true;
14241427
} else if (tokenizer_model == "gpt2") {
14251428
type = LLAMA_VOCAB_TYPE_BPE;
14261429

@@ -1550,12 +1553,15 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
15501553
tokenizer_pre == "jina-es" ||
15511554
tokenizer_pre == "jina-de" ||
15521555
tokenizer_pre == "gigachat" ||
1553-
tokenizer_pre == "jina-v1-en" ||
15541556
tokenizer_pre == "jina-v2-es" ||
1555-
tokenizer_pre == "jina-v2-de" ||
1557+
tokenizer_pre == "jina-v2-de") {
1558+
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
1559+
} else if (
1560+
tokenizer_pre == "jina-v1-en" ||
15561561
tokenizer_pre == "jina-v2-code" ||
15571562
tokenizer_pre == "roberta-bpe") {
15581563
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
1564+
add_sep = true;
15591565
} else if (
15601566
tokenizer_pre == "refact") {
15611567
pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT;
@@ -1665,6 +1671,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
16651671
clean_spaces = true;
16661672
add_bos = true;
16671673
add_eos = false;
1674+
add_sep = true;
16681675
} else if (type == LLAMA_VOCAB_TYPE_UGM) {
16691676
pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
16701677
add_bos = false;
@@ -1801,7 +1808,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
18011808
}
18021809
}
18031810

1804-
// Handle add_bos and add_eos
1811+
// Handle add_bos, add_eos and add_sep
18051812
{
18061813
bool temp = true;
18071814

@@ -1811,6 +1818,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
18111818
if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
18121819
add_eos = temp;
18131820
}
1821+
if (ml.get_key(LLM_KV_TOKENIZER_ADD_SEP, temp, false)) {
1822+
add_sep = temp;
1823+
}
18141824
}
18151825

18161826
// auto-detect special tokens by text
@@ -3000,6 +3010,10 @@ bool llama_vocab::get_add_eos() const {
30003010
return pimpl->add_eos;
30013011
}
30023012

3013+
bool llama_vocab::get_add_sep() const {
3014+
return pimpl->add_sep;
3015+
}
3016+
30033017
bool llama_vocab::get_ignore_merges() const {
30043018
return pimpl->ignore_merges;
30053019
}
@@ -3191,6 +3205,10 @@ bool llama_vocab_get_add_eos(const struct llama_vocab * vocab) {
31913205
return vocab->get_add_eos();
31923206
}
31933207

3208+
bool llama_vocab_get_add_sep(const struct llama_vocab * vocab) {
3209+
return vocab->get_add_sep();
3210+
}
3211+
31943212
llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab) {
31953213
return vocab->token_fim_pre();
31963214
}

0 commit comments

Comments
 (0)