Skip to content

Commit 14a80d3

Browse files
authored
Support NFC Normalizer (#104)
A simplified version of NFC. Based on what we have for NFD, this PR implements the composition logic. Refer to https://unicode.org/reports/tr15/ for details on how NFC works.
1 parent 20da1ca commit 14a80d3

File tree

5 files changed

+277
-14
lines changed

5 files changed

+277
-14
lines changed

include/pytorch/tokenizers/normalizer.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,4 +171,19 @@ class SequenceNormalizer : public Normalizer {
171171

172172
}; // end class SequenceNormalizer
173173

174+
// -- NFC ----------------------------------------------------------------------
175+
// Used for Unicode NFC (Normalization Form Canonical Composition) normalization
176+
// CITE:
177+
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/normalizers/unicode.rs
178+
179+
class NFCNormalizer : public Normalizer {
180+
public:
181+
/** Default constructor */
182+
explicit NFCNormalizer() = default;
183+
184+
/** Normalize with NFC Unicode normalization */
185+
std::string normalize(const std::string& input) const override;
186+
187+
}; // end class NFCNormalizer
188+
174189
} // namespace tokenizers

src/normalizer.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
// Local
1111
#include <pytorch/tokenizers/normalizer.h>
1212

13+
// Third Party
14+
#include <unicode.h>
15+
1316
// Standard
1417
#include <algorithm>
1518
#include <iterator>
@@ -54,6 +57,9 @@ Normalizer::Ptr NormalizerConfig::create() const {
5457
[](const NormalizerConfig& cfg) { return cfg.create(); });
5558
return Normalizer::Ptr(new SequenceNormalizer(norms));
5659
}
60+
if (type == "NFC") {
61+
return Normalizer::Ptr(new NFCNormalizer());
62+
}
5763
throw std::runtime_error("Unsupported Normalizer type: " + type);
5864
}
5965

@@ -76,6 +82,11 @@ NormalizerConfig& NormalizerConfig::parse_json(const json& json_config) {
7682
for (const auto& entry : json_config.at("normalizers")) {
7783
normalizers->push_back(NormalizerConfig().parse_json(entry));
7884
}
85+
} else if (type == "NFC") {
86+
// NFC normalizer has no additional configuration parameters
87+
TK_LOG(
88+
Info,
89+
"Using NFC normalizer. Please notice that our implementation may not handle all edge cases.");
7990
} else {
8091
throw std::runtime_error("Unsupported Normalizer type: " + type);
8192
}
@@ -119,4 +130,22 @@ std::string SequenceNormalizer::normalize(const std::string& input) const {
119130
return result;
120131
}
121132

133+
// NFCNormalizer ///////////////////////////////////////////////////////////////
134+
135+
std::string NFCNormalizer::normalize(const std::string& input) const {
136+
// Convert UTF-8 string to codepoints
137+
auto codepoints = unicode_cpts_from_utf8(input);
138+
139+
// Apply NFC normalization
140+
auto normalized_cpts = unicode_cpts_normalize_nfc(codepoints);
141+
142+
// Convert back to UTF-8 string
143+
std::string result;
144+
for (uint32_t cpt : normalized_cpts) {
145+
result += unicode_cpt_to_utf8(cpt);
146+
}
147+
148+
return result;
149+
}
150+
122151
} // namespace tokenizers

test/test_hf_tokenizer.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,35 @@
1010
"""
1111

1212
import unittest
13+
import pytest
1314
from pytorch_tokenizers import CppHFTokenizer
1415
from transformers import AutoTokenizer
1516
from tempfile import TemporaryDirectory
1617

1718
PROMPT = "What is the capital of France?"
1819

19-
class TestHfTokenizer(unittest.TestCase):
20-
def setUp(self) -> None:
21-
self.temp_dir = TemporaryDirectory()
22-
super().setUp()
2320

24-
def test_smolLM3(self) -> None:
25-
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")
26-
tokenizer_path = tokenizer.save_pretrained(self.temp_dir.name)[-1]
21+
@pytest.mark.parametrize(
    "model_id",
    [
        "HuggingFaceTB/SmolLM3-3B",
        "Qwen/Qwen2.5-0.5B",
    ],
)
def test_models(model_id: str) -> None:
    """Verify the C++ tokenizer produces the same ids as HF's tokenizer."""
    with TemporaryDirectory() as temp_dir:
        hf_tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer_path = hf_tokenizer.save_pretrained(temp_dir)[-1]

        cpp_tokenizer = CppHFTokenizer()
        cpp_tokenizer.load(tokenizer_path)

        expected_tokens = hf_tokenizer.encode(PROMPT)
        actual_tokens = cpp_tokenizer.encode(PROMPT)
        assert expected_tokens == actual_tokens
36+
37+
38+
class TestHfTokenizer(unittest.TestCase):
39+
def setUp(self) -> None:
40+
self.temp_dir = TemporaryDirectory()
41+
super().setUp()
3442

3543
def test_llama3_2_1b(self) -> None:
3644
tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct")
@@ -42,7 +50,3 @@ def test_llama3_2_1b(self) -> None:
4250
tokens = tokenizer.encode(PROMPT)
4351
cpp_tokens = cpp_tokenizer.encode(PROMPT, bos=1)
4452
self.assertEqual(tokens, cpp_tokens)
45-
46-
47-
async def test_async_DO_NOT_COMMIT(self) -> None:
48-
pass

third-party/llama.cpp-unicode/include/unicode.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string &utf8);
8484
std::vector<uint32_t>
8585
unicode_cpts_normalize_nfd(const std::vector<uint32_t> &cpts);
8686

87+
std::vector<uint32_t>
88+
unicode_cpts_normalize_nfc(const std::vector<uint32_t> &cpts);
89+
8790
codepoint_flags unicode_cpt_flags(const uint32_t cp);
8891
codepoint_flags unicode_cpt_flags(const std::string &utf8);
8992

third-party/llama.cpp-unicode/src/unicode.cpp

Lines changed: 214 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,27 @@ SOFTWARE.
3737
#include <codecvt>
3838
#include <cstddef>
3939
#include <cstdint>
40+
#include <functional>
41+
#include <iterator>
42+
#include <limits>
4043
#include <locale>
4144
#include <map>
4245
#include <regex>
4346
#include <stdexcept>
4447
#include <string>
4548
#include <unordered_map>
46-
#include <unordered_set>
47-
#include <utility>
4849
#include <vector>
4950

51+
// Hash function for std::pair<uint32_t, uint32_t> used in composition table.
// NOTE(review): specializing std::hash for std::pair of builtin types is
// formally outside what [namespace.std] permits (the specialization should
// depend on a program-defined type). It works on major toolchains; a named
// hash functor passed as the map's Hash parameter would be strictly
// conforming — confirm whether that refactor is worth the churn.
namespace std {
template <>
struct hash<std::pair<uint32_t, uint32_t>> {
  std::size_t operator()(const std::pair<uint32_t, uint32_t>& p) const noexcept {
    // Pack the two 32-bit halves into one 64-bit value and hash that.
    return std::hash<uint64_t>{}((static_cast<uint64_t>(p.first) << 32) |
                                 p.second);
  }
};
} // namespace std
60+
5061
size_t unicode_len_utf8(char src) {
5162
const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
5263
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
@@ -928,3 +939,204 @@ std::vector<std::string> unicode_regex_split(
928939

929940
return unicode_byte_encoding_process(bpe_words);
930941
}
942+
943+
// Get canonical combining class for a codepoint using existing flags data
944+
static uint8_t get_combining_class(uint32_t cpt) {
945+
codepoint_flags flags = unicode_cpt_flags(cpt);
946+
947+
// Use the existing flag system to determine combining class
948+
if (flags.is_accent_mark) {
949+
// Most combining marks have class 230, but some have different classes
950+
// This is a simplified mapping based on common Unicode patterns
951+
if (cpt >= 0x0591 && cpt <= 0x05BD) return 220; // Hebrew accents
952+
if (cpt >= 0x05BF && cpt <= 0x05C7) return 230; // Hebrew points
953+
if (cpt >= 0x0610 && cpt <= 0x061A) return 230; // Arabic marks
954+
if (cpt >= 0x064B && cpt <= 0x065F) return 30; // Arabic vowels
955+
if (cpt >= 0x0670 && cpt <= 0x0670) return 35; // Arabic superscript alef
956+
if (cpt >= 0x06D6 && cpt <= 0x06E4) return 230; // Arabic small high marks
957+
if (cpt >= 0x06E7 && cpt <= 0x06E8) return 230; // Arabic small high marks
958+
if (cpt >= 0x06EA && cpt <= 0x06ED) return 220; // Arabic small low marks
959+
960+
// Default combining class for most combining marks
961+
return 230;
962+
}
963+
964+
return 0; // Non-combining character (starter)
965+
}
966+
967+
// Apply canonical ordering using bubble sort (simple but correct)
968+
static void canonical_order(std::vector<uint32_t>& cpts) {
969+
for (size_t i = 1; i < cpts.size(); ++i) {
970+
for (size_t j = i; j > 0; --j) {
971+
uint8_t cc1 = get_combining_class(cpts[j-1]);
972+
uint8_t cc2 = get_combining_class(cpts[j]);
973+
974+
// Only reorder if both have non-zero combining class and are out of order
975+
if (cc1 > cc2 && cc2 != 0) {
976+
std::swap(cpts[j-1], cpts[j]);
977+
} else {
978+
break;
979+
}
980+
}
981+
}
982+
}
983+
984+
// Build composition table by reverse-engineering the NFD data
985+
static std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> build_composition_table() {
986+
std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> composition_map;
987+
988+
// Iterate through all NFD mappings to build reverse composition table
989+
for (const auto& range : unicode_ranges_nfd) {
990+
for (uint32_t cpt = range.first; cpt <= range.last; ++cpt) {
991+
uint32_t base = range.nfd;
992+
993+
// For NFC, we need to figure out what combining character was removed
994+
// This is a simplified approach that works for the most common cases
995+
996+
// Common diacritic mappings based on the composed character
997+
uint32_t combining = 0;
998+
999+
// Determine combining character based on the composed character
1000+
// This is derived from common Unicode patterns
1001+
switch (cpt) {
1002+
// Grave accent (0x0300)
1003+
case 0x00C0: case 0x00E0: // À à
1004+
case 0x00C8: case 0x00E8: // È è
1005+
case 0x00CC: case 0x00EC: // Ì ì
1006+
case 0x00D2: case 0x00F2: // Ò ò
1007+
case 0x00D9: case 0x00F9: // Ù ù
1008+
case 0x01CD: case 0x01CE: // Ǎ ǎ
1009+
case 0x01CF: case 0x01D0: // Ǐ ǐ
1010+
case 0x01D1: case 0x01D2: // Ǒ ǒ
1011+
case 0x01D3: case 0x01D4: // Ǔ ǔ
1012+
combining = 0x0300; break;
1013+
1014+
// Acute accent (0x0301)
1015+
case 0x00C1: case 0x00E1: // Á á
1016+
case 0x00C9: case 0x00E9: // É é
1017+
case 0x00CD: case 0x00ED: // Í í
1018+
case 0x00D3: case 0x00F3: // Ó ó
1019+
case 0x00DA: case 0x00FA: // Ú ú
1020+
case 0x00DD: case 0x00FD: // Ý ý
1021+
combining = 0x0301; break;
1022+
1023+
// Circumflex (0x0302)
1024+
case 0x00C2: case 0x00E2: // Â â
1025+
case 0x00CA: case 0x00EA: // Ê ê
1026+
case 0x00CE: case 0x00EE: // Î î
1027+
case 0x00D4: case 0x00F4: // Ô ô
1028+
case 0x00DB: case 0x00FB: // Û û
1029+
combining = 0x0302; break;
1030+
1031+
// Tilde (0x0303)
1032+
case 0x00C3: case 0x00E3: // Ã ã
1033+
case 0x00D1: case 0x00F1: // Ñ ñ
1034+
case 0x00D5: case 0x00F5: // Õ õ
1035+
combining = 0x0303; break;
1036+
1037+
// Diaeresis (0x0308)
1038+
case 0x00C4: case 0x00E4: // Ä ä
1039+
case 0x00CB: case 0x00EB: // Ë ë
1040+
case 0x00CF: case 0x00EF: // Ï ï
1041+
case 0x00D6: case 0x00F6: // Ö ö
1042+
case 0x00DC: case 0x00FC: // Ü ü
1043+
case 0x00FF: // ÿ
1044+
combining = 0x0308; break;
1045+
1046+
// Ring above (0x030A)
1047+
case 0x00C5: case 0x00E5: // Å å
1048+
combining = 0x030A; break;
1049+
1050+
// Cedilla (0x0327)
1051+
case 0x00C7: case 0x00E7: // Ç ç
1052+
combining = 0x0327; break;
1053+
1054+
default:
1055+
// For other characters, try to infer from Unicode blocks
1056+
if (cpt >= 0x0100 && cpt <= 0x017F) {
1057+
// Extended Latin A - try common patterns
1058+
if ((cpt & 1) == 0) { // Even codepoints (uppercase)
1059+
if (cpt >= 0x0100 && cpt <= 0x0105) combining = 0x0304; // macron
1060+
else if (cpt >= 0x0102 && cpt <= 0x0107) combining = 0x0306; // breve
1061+
else if (cpt >= 0x0104 && cpt <= 0x0119) combining = 0x0328; // ogonek
1062+
else if (cpt >= 0x0106 && cpt <= 0x010D) combining = 0x0301; // acute
1063+
else if (cpt >= 0x0108 && cpt <= 0x010F) combining = 0x0302; // circumflex
1064+
else if (cpt >= 0x010A && cpt <= 0x0111) combining = 0x0307; // dot above
1065+
else if (cpt >= 0x010C && cpt <= 0x0165) combining = 0x030C; // caron
1066+
}
1067+
}
1068+
break;
1069+
}
1070+
1071+
// Only add to composition table if we identified a combining character
1072+
if (combining != 0) {
1073+
composition_map[{base, combining}] = cpt;
1074+
}
1075+
}
1076+
}
1077+
1078+
return composition_map;
1079+
}
1080+
1081+
// Get the composition table (built once, cached)
1082+
static const std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t>& get_composition_table() {
1083+
static const auto composition_table = build_composition_table();
1084+
return composition_table;
1085+
}
1086+
1087+
std::vector<uint32_t> unicode_cpts_normalize_nfc(
1088+
const std::vector<uint32_t>& cpts) {
1089+
1090+
// Step 1: Apply NFD (canonical decomposition) using existing implementation
1091+
std::vector<uint32_t> nfd_result = unicode_cpts_normalize_nfd(cpts);
1092+
1093+
// Step 2: Apply canonical ordering
1094+
canonical_order(nfd_result);
1095+
1096+
// Step 3: Apply canonical composition
1097+
const auto& composition_table = get_composition_table();
1098+
std::vector<uint32_t> result;
1099+
result.reserve(nfd_result.size());
1100+
1101+
size_t i = 0;
1102+
while (i < nfd_result.size()) {
1103+
uint32_t starter = nfd_result[i];
1104+
result.push_back(starter);
1105+
1106+
// Only try to compose if this is a starter (combining class 0)
1107+
if (get_combining_class(starter) == 0) {
1108+
size_t last_starter_pos = result.size() - 1;
1109+
1110+
// Look for composable combining marks after this starter
1111+
size_t j = i + 1;
1112+
while (j < nfd_result.size()) {
1113+
uint32_t combining = nfd_result[j];
1114+
uint8_t cc = get_combining_class(combining);
1115+
1116+
// If we hit another starter, stop
1117+
if (cc == 0) break;
1118+
1119+
// Try to compose with the last starter
1120+
auto key = std::make_pair(result[last_starter_pos], combining);
1121+
auto it = composition_table.find(key);
1122+
1123+
if (it != composition_table.end()) {
1124+
// Compose: replace starter with composed character
1125+
result[last_starter_pos] = it->second;
1126+
// Skip this combining character
1127+
++j;
1128+
continue;
1129+
}
1130+
1131+
// No composition possible, add the combining character
1132+
result.push_back(combining);
1133+
++j;
1134+
}
1135+
i = j;
1136+
} else {
1137+
++i;
1138+
}
1139+
}
1140+
1141+
return result;
1142+
}

0 commit comments

Comments
 (0)