@@ -4165,14 +4165,24 @@ static size_t llama_seqrep_find_match(const llama_token * last_tokens_p, const s
 // Bit 1 set indicates token is a word boundary. NL, " blah", "," - word boundary. "blah", "blah:" - not a word boundary.
 // Bit 2 set indicates token ends on a word boundary. NL, "blah:", "blah " - ends on word boundary. " blah", "blah" - doesn't end on word boundary.
 // Errata: Special cases apostrophe and only has partial support for unicode punctuation as word boundaries.
-static uint8_t llama_seqrep_check_word(struct llama_context * ctx, const llama_token token) {
+static uint8_t llama_seqrep_check_word(struct llama_context * ctx, const llama_token token, std::vector<char> & buf) {
     if (token == llama_token_bos(ctx) || token == llama_token_eos(ctx) || token == llama_token_nl(ctx)) {
         // BOS, EOS, NL are always a boundary.
         return 3;
     }

-    const char * token_str = llama_token_get_text(ctx, token);
-    auto decoded = decode_utf8(token_str, llama_partial_utf8{ 0, 0 });
+    if (buf.size() < 128) {
+        buf.resize(128);
+    }
+    int n_tokens = llama_token_to_piece(ctx, token, buf.data(), buf.size() - 1);
+    if (n_tokens < 0) {
+        buf.resize(size_t(-n_tokens) + 128);
+        const int check = llama_token_to_piece(ctx, token, buf.data(), buf.size() - 1);
+        GGML_ASSERT(check == -n_tokens);
+        n_tokens = check;
+    }
+    buf[n_tokens] = 0;
+    auto decoded = decode_utf8(buf.data(), llama_partial_utf8{ 0, 0 });
     const std::vector<uint32_t> & token_cps = decoded.first;
     const size_t token_cps_len = token_cps.size();
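
The substantive change in this hunk is replacing llama_token_get_text() with llama_token_to_piece(), which detokenizes into a caller-supplied buffer and follows llama.cpp's convention of returning the negative of the required length when the buffer is too small. A minimal standalone sketch of that convention, assuming the llama_token_to_piece(ctx, token, buf, length) signature used here; the helper name token_to_piece_str is hypothetical and not part of the patch:

    #include <cassert>
    #include <string>
    #include <vector>
    #include "llama.h" // llama_context, llama_token, llama_token_to_piece

    static std::string token_to_piece_str(struct llama_context * ctx, llama_token token) {
        std::vector<char> buf(8, 0); // deliberately tiny so the retry path is exercised
        int n = llama_token_to_piece(ctx, token, buf.data(), int(buf.size()));
        if (n < 0) {
            // A negative result is the size the call needed; grow and retry once.
            buf.resize(size_t(-n));
            n = llama_token_to_piece(ctx, token, buf.data(), int(buf.size()));
        }
        assert(n >= 0); // the second call must fit
        return std::string(buf.data(), size_t(n));
    }

In the patch itself the scratch vector is allocated once per sampling call and threaded through every llama_seqrep_check_word() invocation, so the resize cost is paid at most a few times rather than once per candidate.
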
@@ -4259,8 +4269,9 @@ void llama_sample_seqrep_penalty(struct llama_context * ctx, llama_token_data_ar
         }
     }

+    std::vector<char> buf(128, 0);
     const bool ends_on_word = params->mid_word_scale == 1.0f
-        || (llama_seqrep_check_word(ctx, last_tokens_p[last_tokens_size - 1]) & 2) != 0;
+        || (llama_seqrep_check_word(ctx, last_tokens_p[last_tokens_size - 1], buf) & 2) != 0;

     for (size_t i = 0; i < candidates->size; ++i) {
         auto pt_iter = penalize_tokens.find(candidates->data[i].id);
@@ -4270,7 +4281,7 @@ void llama_sample_seqrep_penalty(struct llama_context * ctx, llama_token_data_ar

         const size_t count = pt_iter->second;
         const bool pt_starts_word = params->mid_word_scale == 1.0f ||
-            (llama_seqrep_check_word(ctx, candidates->data[i].id) & 1) != 0;
+            (llama_seqrep_check_word(ctx, candidates->data[i].id, buf) & 1) != 0;
         float penalty_scale = ends_on_word || pt_starts_word ? 1.0f : params->mid_word_scale;
         float logit = candidates->data[i].logit;
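
Taken together, the flag bits documented above drive the penalty scaling: the full penalty applies when the previous token ends on a word boundary (bit 2) or the candidate token starts on one (bit 1), and mid_word_scale damps it otherwise. A hedged sketch of just that interpretation, with the helper name seqrep_word_scale being illustrative rather than part of the patch:

    #include <cstdint>

    // Bit 1 (0x1): token begins at a word boundary; bit 2 (0x2): token ends at one.
    // BOS, EOS and NL return 3, i.e. both bits set.
    static float seqrep_word_scale(uint8_t last_flags, uint8_t cand_flags, float mid_word_scale) {
        const bool ends_on_word   = (last_flags & 2) != 0; // previous token ended a word
        const bool pt_starts_word = (cand_flags & 1) != 0; // candidate starts a new word
        // Repeats that cross a word boundary get the full penalty; mid-word repeats are damped.
        return (ends_on_word || pt_starts_word) ? 1.0f : mid_word_scale;
    }

Setting mid_word_scale to 1.0f short-circuits both checks in the patch, so the llama_seqrep_check_word() calls are skipped entirely when no mid-word damping is requested.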