
Commit cbfb2ef

seqrep: Fix token text content handling.
1 parent e52cf05 commit cbfb2ef

1 file changed: 16 additions, 5 deletions

llama.cpp

Lines changed: 16 additions & 5 deletions
@@ -4165,14 +4165,24 @@ static size_t llama_seqrep_find_match(const llama_token * last_tokens_p, const s
 // Bit 1 set indicates token is a word boundary. NL, " blah", "," - word boundary. "blah", "blah:" - not a word boundary.
 // Bit 2 set indicates token ends on a word boundary. NL, "blah:", "blah " - ends on word boundary. " blah", "blah" - doesn't end on word boundary.
 // Errata: Special cases apostrophe and only has partial support for unicode punctuation as word boundaries.
-static uint8_t llama_seqrep_check_word(struct llama_context * ctx, const llama_token token) {
+static uint8_t llama_seqrep_check_word(struct llama_context * ctx, const llama_token token, std::vector<char> & buf) {
     if (token == llama_token_bos(ctx) || token == llama_token_eos(ctx) || token == llama_token_nl(ctx)) {
         // BOS, EOS, NL are always a boundary.
         return 3;
     }
 
-    const char * token_str = llama_token_get_text(ctx, token);
-    auto decoded = decode_utf8(token_str, llama_partial_utf8{ 0, 0 });
+    if (buf.size() < 128) {
+        buf.resize(128);
+    }
+    int n_tokens = llama_token_to_piece(ctx, token, buf.data(), buf.size() - 1);
+    if (n_tokens < 0) {
+        buf.resize(size_t(-n_tokens) + 128);
+        const int check = llama_token_to_piece(ctx, token, buf.data(), buf.size() - 1);
+        GGML_ASSERT(check == -n_tokens);
+        n_tokens = check;
+    }
+    buf[n_tokens] = 0;
+    auto decoded = decode_utf8(buf.data(), llama_partial_utf8{ 0, 0 });
     const std::vector<uint32_t> & token_cps = decoded.first;
     const size_t token_cps_len = token_cps.size();
 
@@ -4259,8 +4269,9 @@ void llama_sample_seqrep_penalty(struct llama_context * ctx, llama_token_data_ar
         }
     }
 
+    std::vector<char> buf(128, 0);
     const bool ends_on_word = params->mid_word_scale == 1.0f
-        || (llama_seqrep_check_word(ctx, last_tokens_p[last_tokens_size - 1]) & 2) != 0;
+        || (llama_seqrep_check_word(ctx, last_tokens_p[last_tokens_size - 1], buf) & 2) != 0;
 
     for (size_t i = 0; i < candidates->size; ++i) {
         auto pt_iter = penalize_tokens.find(candidates->data[i].id);
@@ -4270,7 +4281,7 @@ void llama_sample_seqrep_penalty(struct llama_context * ctx, llama_token_data_ar
 
         const size_t count = pt_iter->second;
         const bool pt_starts_word = params->mid_word_scale == 1.0f ||
-            (llama_seqrep_check_word(ctx, candidates->data[i].id) & 1) != 0;
+            (llama_seqrep_check_word(ctx, candidates->data[i].id, buf) & 1) != 0;
         float penalty_scale = ends_on_word || pt_starts_word ? 1.0f : params->mid_word_scale;
         float logit = candidates->data[i].logit;
 
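Two notes on the diff, for readers unfamiliar with the conventions it relies on.

First, llama_token_to_piece reports an undersized destination buffer by returning the negative of the required byte count, which is why the new code resizes and retries (and asserts that the second call agrees). Below is a minimal standalone sketch of the same convention, assuming the 2023-era signature llama_token_to_piece(ctx, token, buf, length); the helper name token_to_string is illustrative and not part of this commit, and plain assert stands in for GGML_ASSERT so the sketch is self-contained.

#include <cassert>
#include <string>
#include <vector>

#include "llama.h"

// Detokenize one token. If the first call returns a negative value, the
// buffer was too small and -n is the number of bytes actually needed.
static std::string token_to_string(struct llama_context * ctx, llama_token token) {
    std::vector<char> buf(128, 0);
    int n = llama_token_to_piece(ctx, token, buf.data(), buf.size() - 1);
    if (n < 0) {
        buf.resize(size_t(-n) + 1);  // make room for exactly -n bytes
        n = llama_token_to_piece(ctx, token, buf.data(), buf.size() - 1);
        assert(n >= 0);              // the second call must now succeed
    }
    return std::string(buf.data(), size_t(n));
}

Second, llama_seqrep_check_word packs its result into a bit mask: bit 1 (value 1) set means the token starts on a word boundary, bit 2 (value 2) set means it ends on one, and BOS/EOS/NL return 3 (both). The callers above test (... & 2) for ends_on_word and (... & 1) for pt_starts_word. Passing buf in from the caller lets the sampling loop reuse one allocation across all candidate tokens instead of detokenizing into a fresh buffer each time.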