Skip to content

Commit 83796e6

Browse files
authored
llama : refactor unicode stuff (#5992)
* llama : refactor unicode stuff
  ggml-ci
* unicode : names
* make : fix c++ compiler
* unicode : names
* unicode : straighten tables
* zig : fix build
* unicode : put nfd normalization behind API
  ggml-ci
* swift : fix build
* unicode : add BOM
* unicode : add <cstdint>
  ggml-ci
* unicode : pass as cpts as const ref
1 parent 828defe commit 83796e6

File tree

9 files changed

+1744
-836
lines changed

9 files changed

+1744
-836
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,6 +1141,8 @@ endif()
11411141
add_library(llama
11421142
llama.cpp
11431143
llama.h
1144+
unicode.h
1145+
unicode.cpp
11441146
)
11451147

11461148
target_include_directories(llama PUBLIC .)

Makefile

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -633,9 +633,12 @@ ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
633633
ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h ggml-common.h
634634
$(CC) $(CFLAGS) -c $< -o $@
635635

636-
OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o
636+
unicode.o: unicode.cpp unicode.h
637+
$(CXX) $(CXXFLAGS) -c $< -o $@
638+
639+
OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o unicode.o
637640

638-
llama.o: llama.cpp ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h
641+
llama.o: llama.cpp unicode.h ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h
639642
$(CXX) $(CXXFLAGS) -c $< -o $@
640643

641644
COMMON_H_DEPS = common/common.h common/sampling.h common/log.h

Package.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ let package = Package(
3131
sources: [
3232
"ggml.c",
3333
"llama.cpp",
34+
"unicode.cpp",
3435
"ggml-alloc.c",
3536
"ggml-backend.c",
3637
"ggml-quants.c",

build.zig

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ pub fn build(b: *std.build.Builder) !void {
115115
const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c");
116116
const ggml_backend = make.obj("ggml-backend", "ggml-backend.c");
117117
const ggml_quants = make.obj("ggml-quants", "ggml-quants.c");
118+
const unicode = make.obj("unicode", "unicode.cpp");
118119
const llama = make.obj("llama", "llama.cpp");
119120
const buildinfo = make.obj("common", "common/build-info.cpp");
120121
const common = make.obj("common", "common/common.cpp");
@@ -125,14 +126,14 @@ pub fn build(b: *std.build.Builder) !void {
125126
const clip = make.obj("clip", "examples/llava/clip.cpp");
126127
const llava = make.obj("llava", "examples/llava/llava.cpp");
127128

128-
_ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo, sampling, console, grammar_parser });
129-
_ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo });
130-
_ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo });
131-
_ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo });
132-
_ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo, train });
133-
_ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo, train });
129+
_ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, sampling, console, grammar_parser });
130+
_ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo });
131+
_ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo });
132+
_ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo });
133+
_ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, train });
134+
_ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, train });
134135

135-
const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo, sampling, grammar_parser, clip, llava });
136+
const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, sampling, grammar_parser, clip, llava });
136137
if (server.target.isWindows()) {
137138
server.linkSystemLibrary("ws2_32");
138139
}

llama.cpp

Lines changed: 37 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -3703,7 +3703,7 @@ static void llm_load_vocab(
37033703

37043704
for (int i = 0; i < n_merges; i++) {
37053705
const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
3706-
GGML_ASSERT(codepoints_from_utf8(word).size() > 0);
3706+
GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
37073707

37083708
std::string first;
37093709
std::string second;
@@ -3748,7 +3748,7 @@ static void llm_load_vocab(
37483748

37493749
for (uint32_t i = 0; i < n_vocab; i++) {
37503750
std::string word = gguf_get_arr_str(ctx, token_idx, i);
3751-
GGML_ASSERT(codepoints_from_utf8(word).size() > 0);
3751+
GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
37523752

37533753
vocab.token_to_id[word] = i;
37543754

@@ -9340,7 +9340,7 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
93409340
}
93419341
case LLAMA_VOCAB_TYPE_BPE: {
93429342
GGML_ASSERT(false);
9343-
return unicode_to_bytes_bpe(token_data.text);
9343+
return unicode_utf8_to_byte(token_data.text);
93449344
}
93459345
case LLAMA_VOCAB_TYPE_WPM: {
93469346
GGML_ASSERT(false);
@@ -9365,7 +9365,7 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
93659365
}
93669366
case LLAMA_VOCAB_TYPE_WPM:
93679367
case LLAMA_VOCAB_TYPE_BPE: {
9368-
return vocab.token_to_id.at(bytes_to_unicode_bpe(ch));
9368+
return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
93699369
}
93709370
default:
93719371
GGML_ASSERT(false);
@@ -9705,9 +9705,9 @@ struct llm_tokenizer_bpe {
97059705
bpe_words.reserve(text.size());
97069706
bpe_encoded_words.reserve(text.size());
97079707

9708-
auto cps = codepoints_from_utf8(text);
9709-
for (size_t i = 0; i < cps.size(); ++i)
9710-
text_utf.emplace_back(codepoint_to_utf8(cps[i]));
9708+
const auto cpts = unicode_cpts_from_utf8(text);
9709+
for (size_t i = 0; i < cpts.size(); ++i)
9710+
text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i]));
97119711

97129712
for (int i = 0; i < (int)text_utf.size(); i++) {
97139713
const std::string & utf_char = text_utf[i];
@@ -9757,40 +9757,40 @@ struct llm_tokenizer_bpe {
97579757
}
97589758

97599759
if (!split_condition && !collecting) {
9760-
if (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
9760+
if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
97619761
collecting_letter = true;
97629762
collecting = true;
97639763
}
9764-
else if (codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
9764+
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
97659765
collecting_numeric = true;
97669766
collecting = true;
97679767
}
97689768
else if (
9769-
((codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (codepoint_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
9770-
(!token.size() && utf_char == " " && codepoint_type(utf_char_next) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && codepoint_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
9769+
((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
9770+
(!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
97719771
) {
97729772
collecting_special = true;
97739773
collecting = true;
97749774
}
9775-
else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && codepoint_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
9775+
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
97769776
collecting_whitespace_lookahead = true;
97779777
collecting = true;
97789778
}
9779-
else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
9779+
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
97809780
split_condition = true;
97819781
}
97829782
}
97839783
else if (!split_condition && collecting) {
9784-
if (collecting_letter && codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER) {
9784+
if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) {
97859785
split_condition = true;
97869786
}
9787-
else if (collecting_numeric && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
9787+
else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
97889788
split_condition = true;
97899789
}
9790-
else if (collecting_special && (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
9790+
else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
97919791
split_condition = true;
97929792
}
9793-
else if (collecting_whitespace_lookahead && (codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
9793+
else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
97949794
split_condition = true;
97959795
}
97969796
}
@@ -9819,7 +9819,7 @@ struct llm_tokenizer_bpe {
98199819
for (std::string & word : bpe_words) {
98209820
std::string encoded_token = "";
98219821
for (char & c : word) {
9822-
encoded_token += bytes_to_unicode_bpe(c);
9822+
encoded_token += unicode_byte_to_utf8(c);
98239823
}
98249824
bpe_encoded_words.emplace_back(encoded_token);
98259825
}
@@ -9893,33 +9893,21 @@ struct llm_tokenizer_wpm {
98939893
}
98949894

98959895
std::vector<std::string> preprocess(const std::string & text) {
9896-
// normalalization form D
9897-
std::vector<uint32_t> codepoints = codepoints_from_utf8(text);
9898-
std::vector<uint32_t> nfd_codepoints;
9899-
for (uint32_t code : codepoints) {
9900-
auto it = nfd_map.equal_range(code);
9901-
if (it.first != it.second) {
9902-
for (auto jt = it.first; jt != it.second; jt++) {
9903-
nfd_codepoints.push_back(jt->second);
9904-
}
9905-
} else {
9906-
nfd_codepoints.push_back(code);
9907-
}
9908-
}
9896+
std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
99099897

99109898
// strip accents, strip control, uniformize whitespace,
99119899
// to lowercase, pad chinese characters, pad punctuation
99129900
std::string new_str = "";
9913-
for (uint32_t code : nfd_codepoints) {
9914-
int type = codepoint_type(code);
9901+
for (uint32_t code : cpts_nfd) {
9902+
int type = unicode_cpt_type(code);
99159903
if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) {
99169904
continue;
99179905
}
99189906
code = to_lower(code);
99199907
if (type == CODEPOINT_TYPE_WHITESPACE) {
99209908
code = ' ';
99219909
}
9922-
std::string s = codepoint_to_utf8(code);
9910+
std::string s = unicode_cpt_to_utf8(code);
99239911
if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) {
99249912
new_str += " ";
99259913
new_str += s;
@@ -9939,8 +9927,7 @@ struct llm_tokenizer_wpm {
99399927
if (r > l) words.push_back(new_str.substr(l, (r - l)));
99409928
l = r + 1;
99419929
r = l;
9942-
}
9943-
else {
9930+
} else {
99449931
r += 1;
99459932
}
99469933
}
@@ -9964,17 +9951,17 @@ struct llm_tokenizer_wpm {
99649951
return code < 256 && ispunct(code);
99659952
}
99669953

9967-
bool is_chinese_char(uint32_t codepoint) {
9968-
if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) ||
9969-
(codepoint >= 0x3400 && codepoint <= 0x4DBF) ||
9970-
(codepoint >= 0x20000 && codepoint <= 0x2A6DF) ||
9971-
(codepoint >= 0x2A700 && codepoint <= 0x2B73F) ||
9972-
(codepoint >= 0x2B740 && codepoint <= 0x2B81F) ||
9973-
(codepoint >= 0x2B920 && codepoint <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
9974-
(codepoint >= 0xF900 && codepoint <= 0xFAFF) ||
9975-
(codepoint >= 0x2F800 && codepoint <= 0x2FA1F) ||
9976-
(codepoint >= 0x3000 && codepoint <= 0x303F) ||
9977-
(codepoint >= 0xFF00 && codepoint <= 0xFFEF)) {
9954+
bool is_chinese_char(uint32_t cpt) {
9955+
if ((cpt >= 0x4E00 && cpt <= 0x9FFF) ||
9956+
(cpt >= 0x3400 && cpt <= 0x4DBF) ||
9957+
(cpt >= 0x20000 && cpt <= 0x2A6DF) ||
9958+
(cpt >= 0x2A700 && cpt <= 0x2B73F) ||
9959+
(cpt >= 0x2B740 && cpt <= 0x2B81F) ||
9960+
(cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
9961+
(cpt >= 0xF900 && cpt <= 0xFAFF) ||
9962+
(cpt >= 0x2F800 && cpt <= 0x2FA1F) ||
9963+
(cpt >= 0x3000 && cpt <= 0x303F) ||
9964+
(cpt >= 0xFF00 && cpt <= 0xFFEF)) {
99789965
return true; // NOLINT
99799966
}
99809967
return false;
@@ -13953,9 +13940,9 @@ int32_t llama_tokenize(
1395313940

1395413941
static std::string llama_decode_text(const std::string & text) {
1395513942
std::string decoded_text;
13956-
auto unicode_sequences = codepoints_from_utf8(text);
13957-
for (auto& unicode_sequence : unicode_sequences) {
13958-
decoded_text += unicode_to_bytes_bpe(codepoint_to_utf8(unicode_sequence));
13943+
auto unicode_sequences = unicode_cpts_from_utf8(text);
13944+
for (auto & unicode_sequence : unicode_sequences) {
13945+
decoded_text += unicode_utf8_to_byte(unicode_cpt_to_utf8(unicode_sequence));
1395913946
}
1396013947

1396113948
return decoded_text;

tests/test-tokenizer-1-bpe.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ int main(int argc, char **argv) {
6464
for (int i = 0; i < n_vocab; ++i) {
6565
std::string str = llama_detokenize_bpe(ctx, std::vector<int>(1, i));
6666
try {
67-
auto cps = codepoints_from_utf8(str);
67+
auto cps = unicode_cpts_from_utf8(str);
6868
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
6969
std::string check = llama_detokenize_bpe(ctx, tokens);
7070
if (check != str) {
@@ -97,7 +97,7 @@ int main(int argc, char **argv) {
9797
continue;
9898
}
9999

100-
std::string str = codepoint_to_utf8(cp);
100+
std::string str = unicode_cpt_to_utf8(cp);
101101
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
102102
std::string check = llama_detokenize_bpe(ctx, tokens);
103103
if (cp != 9601 && str != check) {

tests/test-tokenizer-1-llama.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ int main(int argc, char **argv) {
8585
continue;
8686
}
8787

88-
std::string str = codepoint_to_utf8(cp);
88+
std::string str = unicode_cpt_to_utf8(cp);
8989
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
9090
std::string check = llama_detokenize_spm(ctx, tokens);
9191
if (cp != 9601 && str != check) {

0 commit comments

Comments
 (0)