Skip to content

Commit 86646c4

Browse files
committed
Better seqrep word boundary handling.
1 parent cb02274 commit 86646c4

File tree

1 file changed

+14
-17
lines changed

1 file changed

+14
-17
lines changed

llama.cpp

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2691,36 +2691,33 @@ static size_t llama_seqrep_find_match(const llama_token * last_tokens_p, const s
26912691
return matches;
26922692
}
26932693

2694+
// Internal helper macro for sequence matching, used to determine if a CP is a word boundary.
2695+
// 0x2000 through 0x206f is standard Unicode punctuation. The CP is considered a word boundary
2696+
// if it falls in that range but is _not_ RIGHT SINGLE QUOTATION MARK (0x2019) or if it's in
2697+
// low ASCII range, not alphanumeric and also not a single quote.
2698+
// True when code point `cp` counts as a word boundary: either below 127, not an ASCII single quote (39) and not
// alphanumeric, or inside 0x2000..0x206f excluding RIGHT SINGLE QUOTATION MARK (0x2019).
// Every use of `cp` is parenthesized so compound arguments (e.g. `x & 0x7f`, `a ? b : c`) expand with the intended
// precedence. `cp` is still evaluated several times, so the argument must not have side effects.
#define LLAMA_SEQREP_IS_WBOUND(cp) ( ((cp) < 127 && (cp) != 39 && !std::isalnum(static_cast<int>(cp))) || ((cp) != 0x2019 && (cp) >= 0x2000 && (cp) <= 0x206f) )
2699+
26942700
// Internal helper function for sequence matching.
26952701
// Bit 1 set indicates token is a word boundary. NL, " blah", "," - word boundary. "blah", "blah:" - not a word boundary.
26962702
// Bit 2 set indicates token ends on a word boundary. NL, "blah:", "blah " - ends on word boundary. " blah", "blah" - doesn't end on word boundary.
2697-
// Errata: UTF8 safe but only considers ASCII characters. ASCII single quote is treated as a non-boundary which isn't always correct.
2703+
// Errata: Special cases apostrophe and only has partial support for unicode punctuation as word boundaries.
26982704
static uint8_t llama_seqrep_check_word(struct llama_context * ctx, const llama_token token) {
26992705
if (token == llama_token_bos() || token == llama_token_eos() || token == llama_token_nl()) {
27002706
// BOS, EOS, NL are always a boundary.
27012707
return 3;
27022708
}
27032709
const char * token_str = llama_token_to_str(ctx, token);
27042710
assert(token_str != NULL);
2705-
if (token_str[0] == '\0') {
2706-
// 0-length token string, can't be a boundary.
2707-
return 0;
2708-
}
27092711

2710-
const char start_char = token_str[0];
2711-
char end_char;
2712-
for (const char *curr_char = token_str; ; curr_char++) {
2713-
// Guaranteed to iterate at least once since we already checked if the string was 0-length.
2714-
if (*(curr_char + 1) == '\0') {
2715-
end_char = *curr_char;
2716-
break;
2717-
}
2712+
const std::vector<uint32_t> token_cps = decode_utf8(token_str);
2713+
const size_t token_cps_len = token_cps.size();
2714+
if (token_cps_len < 2) {
2715+
// token has no codepoints, can't be a boundary. < 2 here because decode_utf8 terminates with a 0 entry.
2716+
return 0;
27182717
}
2719-
return uint8_t(
2720-
(start_char != '\'' && !isalnum((int)start_char) ? 1 : 0) +
2721-
(end_char != '\'' && !isalnum((int)end_char) ? 2 : 0)
2722-
);
27232718

2719+
const uint32_t start_cp = token_cps[0], end_cp = token_cps[token_cps_len - 2];
2720+
return uint8_t(LLAMA_SEQREP_IS_WBOUND(start_cp)) + uint8_t(LLAMA_SEQREP_IS_WBOUND(end_cp)) * 2;
27242721
}
27252722

27262723
void llama_sample_seqrep_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, size_t min_length, size_t tolerance, float flat_penalty, float length_penalty, float mid_word_scale) {

0 commit comments

Comments
 (0)