Skip to content

Commit e894ef4

Browse files
authored
[hf] Add features to HF tokenizer (#87)
1 parent 6eb0839 commit e894ef4

File tree

14 files changed

+2380351
-11
lines changed

14 files changed

+2380351
-11
lines changed

.github/workflows/pull.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,6 @@ jobs:
2424
timeout: 90
2525
script: |
2626
set -ex
27-
cmake -DTOKENIZERS_BUILD_TEST=ON -DCMAKE_BUILD_TYPE=Debug . -Bbuild
28-
cmake --build build -j9 --config Debug
29-
cd build && ctest
27+
cmake -DCMAKE_BUILD_TYPE=Debug test -Bbuild/test
28+
cmake --build build/test -j9 --config Debug
29+
cd build/test && ctest

.github/workflows/trunk.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,6 @@ jobs:
3131
timeout: 90
3232
script: |
3333
set -ex
34-
cmake -DTOKENIZERS_BUILD_TEST=ON -DCMAKE_BUILD_TYPE=Debug . -Bbuild
35-
cmake --build build -j9 --config Debug
36-
cd build && ctest
34+
cmake -DCMAKE_BUILD_TYPE=Debug test -Bbuild/test
35+
cmake --build build/test -j9 --config Debug
36+
cd build/test && ctest

include/pytorch/tokenizers/pre_tokenizer.h

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,21 @@ class PreTokenizerConfig {
104104
*/
105105
CONFIG_MEMBER(bool, add_prefix_space)
106106

107+
/**
108+
* Used by RegexPreTokenizer
109+
*/
110+
CONFIG_MEMBER(bool, is_delimiter)
111+
112+
/**
113+
* Used by RegexPreTokenizer - Split behavior
114+
*/
115+
CONFIG_MEMBER(std::string, behavior)
116+
117+
/**
118+
* Used by RegexPreTokenizer - Split invert flag
119+
*/
120+
CONFIG_MEMBER(bool, invert)
121+
107122
/**
108123
* Used by: SequencePreTokenizer
109124
*/
@@ -141,8 +156,29 @@ class PreTokenizerConfig {
141156

142157
class RegexPreTokenizer : public PreTokenizer {
143158
public:
144-
explicit RegexPreTokenizer(const std::string& pattern)
145-
: regex_(RegexPreTokenizer::create_regex_(pattern)) {}
159+
/**
160+
* @param pattern: The regex pattern to use for token splitting
161+
* @param is_delimiter: Whether to treat `pattern` as delimiter characters, or
162+
* use `pattern` as a regex pattern.
163+
* @param behavior: Split behavior for delimiters ("MergedWithPrevious", or empty to drop the delimiters; nothing else is supported)
164+
* For example:
165+
* "pre_tokenizer": {
166+
* "type": "Split",
167+
* "pattern": {
168+
* "String": " "
169+
* },
170+
* "behavior": "MergedWithPrevious",
171+
* "invert": false
172+
* },
173+
* Notice that the `invert` option is not supported.
174+
*/
175+
explicit RegexPreTokenizer(
176+
const std::string& pattern,
177+
bool is_delimiter = false,
178+
const std::string& behavior = "")
179+
: regex_(RegexPreTokenizer::create_regex_(pattern)),
180+
is_delimiter_(is_delimiter),
181+
behavior_(behavior) {}
146182

147183
/** Pre-tokenize with the stored regex */
148184
std::vector<std::string> pre_tokenize(const std::string& input) const;
@@ -151,6 +187,8 @@ class RegexPreTokenizer : public PreTokenizer {
151187
static std::unique_ptr<IRegex> create_regex_(const std::string& pattern);
152188

153189
std::unique_ptr<IRegex> regex_;
190+
const bool is_delimiter_;
191+
const std::string behavior_;
154192

155193
}; // end class RegexPreTokenizer
156194

include/pytorch/tokenizers/regex.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ class IRegex {
4242
* @return A vector of strings containing all matched substrings.
4343
*/
4444
virtual std::vector<Match> find_all(const std::string& text) const = 0;
45+
46+
/**
47+
* @brief Escape special regex characters in a string to treat it as literal.
48+
*
49+
* @param input The input string to escape.
50+
* @return The escaped string that can be used as a literal pattern in regex.
51+
*/
52+
static std::string escape(const std::string& input);
4553
};
4654

4755
// Function pointer type for create_fallback_regex implementations

include/pytorch/tokenizers/token_decoder.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@ class TokenDecoderConfig {
6565
*/
6666
std::string type;
6767

68+
// Parameters for Replace decoder
69+
std::string replace_pattern;
70+
std::string replace_content;
71+
72+
// Parameters for Sequence decoder
73+
std::vector<nlohmann::json> sequence_decoders;
74+
6875
/*----------------*/
6976
/* Public methods */
7077
/*----------------*/
@@ -96,4 +103,49 @@ class ByteLevelTokenDecoder : public TokenDecoder {
96103

97104
}; // end class ByteLevelTokenDecoder
98105

106+
// -- Replace ------------------------------------------------------------------
107+
// Replaces a pattern with a replacement string
108+
109+
// TokenDecoder subclass constructed from a `pattern` and a `content` string.
// Per the section comment above, it replaces a pattern with a replacement
// string. Only the declaration is visible here; the constructor and decode()
// are defined out of line.
// NOTE(review): whether `pattern` is matched literally or as a regex is not
// visible from this header - confirm against the implementation.
class ReplaceTokenDecoder : public TokenDecoder {
110+
public:
111+
// @param pattern: the pattern to be replaced
// @param content: the replacement string
explicit ReplaceTokenDecoder(
112+
const std::string& pattern,
113+
const std::string& content);
114+
// Decodes one token string and returns the decoded string by value.
std::string decode(const std::string& token) const override;
115+
116+
private:
117+
// presumably a copy of the constructor's `pattern` - TODO confirm
std::string pattern_;
118+
// presumably a copy of the constructor's `content` - TODO confirm
std::string content_;
119+
}; // end class ReplaceTokenDecoder
120+
121+
// -- ByteFallback -------------------------------------------------------------
122+
// Handles byte fallback decoding
123+
124+
// TokenDecoder subclass for byte fallback decoding (per the section comment
// above). It declares no constructor and no data members of its own.
// NOTE(review): the token format it recognises (presumably byte tokens such
// as "<0x41>", as in HF tokenizers) is not visible from this header - confirm
// against the implementation.
class ByteFallbackTokenDecoder : public TokenDecoder {
125+
public:
126+
// Decodes one token string and returns the decoded string by value.
std::string decode(const std::string& token) const override;
127+
128+
}; // end class ByteFallbackTokenDecoder
129+
130+
// -- Fuse --------------------------------------------------------------------
131+
// Fuses tokens together
132+
133+
// TokenDecoder subclass described by the section comment above as fusing
// tokens together. It declares no constructor and no data members of its own.
// NOTE(review): decode() receives a single token string, so what "fusing"
// means at this granularity is not visible from this header - confirm against
// the implementation.
class FuseTokenDecoder : public TokenDecoder {
134+
public:
135+
// Decodes one token string and returns the decoded string by value.
std::string decode(const std::string& token) const override;
136+
137+
}; // end class FuseTokenDecoder
138+
139+
// -- Sequence -----------------------------------------------------------------
140+
// Applies a sequence of decoders in order
141+
142+
// TokenDecoder subclass that holds a list of other decoders. Per the section
// comment above, it applies that sequence of decoders in order. The
// constructor and decode() are defined out of line.
class SequenceTokenDecoder : public TokenDecoder {
143+
public:
144+
// @param decoders: the decoders making up the sequence, taken by value
explicit SequenceTokenDecoder(std::vector<TokenDecoder::Ptr> decoders);
145+
// Decodes one token string and returns the decoded string by value.
// NOTE(review): presumably each decoder's output is fed to the next one -
// confirm against the implementation.
std::string decode(const std::string& token) const override;
146+
147+
private:
148+
// presumably the decoders handed to the constructor - TODO confirm
std::vector<TokenDecoder::Ptr> decoders_;
149+
}; // end class SequenceTokenDecoder
150+
99151
} // namespace tokenizers

src/bpe_tokenizer_base.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ Result<std::vector<uint64_t>> BPETokenizerBase::byte_pair_encode_(
194194
return std::vector<uint64_t>(*result);
195195
} else {
196196
// TODO: is it possible?
197+
TK_LOG(Error, "unknown token: '%s'", piece.c_str());
197198
return Error::EncodeFailure;
198199
}
199200
}

src/pre_tokenizer.cpp

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,24 @@ PreTokenizer::Ptr PreTokenizerConfig::create() const {
3737
throw std::runtime_error(
3838
"Missing pattern for PreTokenizer of type Split");
3939
}
40-
return PreTokenizer::Ptr(new RegexPreTokenizer(*pattern));
40+
41+
// Validate behavior parameter
42+
std::string behavior_str = behavior ? *behavior : "";
43+
if (!behavior_str.empty() && behavior_str != "MergedWithPrevious") {
44+
throw std::runtime_error(
45+
"Unsupported behavior '" + behavior_str +
46+
"' for Split PreTokenizer. Only 'MergedWithPrevious' is supported.");
47+
}
48+
49+
// Validate invert parameter
50+
bool invert_flag = invert ? *invert : false;
51+
if (invert_flag) {
52+
throw std::runtime_error(
53+
"invert=true is not supported for Split PreTokenizer. Only invert=false is supported.");
54+
}
55+
56+
return PreTokenizer::Ptr(new RegexPreTokenizer(
57+
*pattern, is_delimiter ? *is_delimiter : false, behavior_str));
4158
}
4259
if (type == "Digits") {
4360
if (individual_digits) {
@@ -79,7 +96,27 @@ PreTokenizerConfig& PreTokenizerConfig::parse_json(const json& json_config) {
7996
if (type == "Split") {
8097
try {
8198
pattern = json_config.at("pattern").at("Regex");
99+
is_delimiter = false;
100+
} catch (json::out_of_range&) {
101+
// "Regex" is not there, check "String", which is a delimiter
102+
std::string delimiter = json_config.at("pattern").at("String");
103+
// For string patterns, escape regex special characters to treat them as
104+
// literal strings (same as Rust's regex::escape)
105+
pattern = IRegex::escape(delimiter);
106+
is_delimiter = true;
107+
}
108+
109+
// Parse behavior and invert fields
110+
try {
111+
behavior = json_config.at("behavior");
112+
} catch (json::out_of_range&) {
113+
// behavior is optional, default to empty string
114+
}
115+
116+
try {
117+
invert = json_config.at("invert");
82118
} catch (json::out_of_range&) {
119+
// invert is optional, default to false
83120
}
84121
} else if (type == "Digits") {
85122
try {
@@ -115,9 +152,66 @@ std::vector<std::string> RegexPreTokenizer::pre_tokenize(
115152
const std::string& input) const {
116153
if (!regex_)
117154
return {};
155+
118156
std::vector<std::string> results;
119-
for (const auto& match : regex_->find_all(input)) {
120-
results.push_back(input.substr(match.start, match.end - match.start));
157+
auto matches = regex_->find_all(input);
158+
159+
if (!is_delimiter_) {
160+
// Original behavior: return the matches themselves
161+
for (const auto& match : matches) {
162+
results.push_back(input.substr(match.start, match.end - match.start));
163+
}
164+
} else {
165+
// Delimiter behavior
166+
if (matches.empty()) {
167+
// No matches found, return the entire input
168+
results.push_back(input);
169+
return results;
170+
}
171+
172+
if (behavior_ == "MergedWithPrevious") {
173+
// MergedWithPrevious: Include delimiter with previous token
174+
// Example: "the-final--countdown" with delimiter "-"
175+
// -> ["the-", "final-", "-", "countdown"]
176+
size_t last_end = 0;
177+
178+
for (size_t i = 0; i < matches.size(); ++i) {
179+
const auto& match = matches[i];
180+
181+
// Add text before the match plus the delimiter
182+
if (match.start > last_end) {
183+
std::string token = input.substr(last_end, match.end - last_end);
184+
results.push_back(token);
185+
} else {
186+
// Only delimiter, no preceding text
187+
std::string delimiter =
188+
input.substr(match.start, match.end - match.start);
189+
results.push_back(delimiter);
190+
}
191+
192+
last_end = match.end;
193+
}
194+
195+
// Add remaining text after the last match (if any)
196+
if (last_end < input.length()) {
197+
results.push_back(input.substr(last_end));
198+
}
199+
} else {
200+
// Default delimiter behavior (split on delimiters)
201+
size_t last_end = 0;
202+
for (const auto& match : matches) {
203+
// Add text before the match (if any)
204+
if (match.start > last_end) {
205+
results.push_back(input.substr(last_end, match.start - last_end));
206+
}
207+
last_end = match.end;
208+
}
209+
210+
// Add remaining text after the last match (if any)
211+
if (last_end < input.length()) {
212+
results.push_back(input.substr(last_end));
213+
}
214+
}
121215
}
122216
return results;
123217
}

src/regex.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,23 @@ FallbackRegexFn get_fallback_regex() {
3333
return fallback_regex;
3434
}
3535

36+
/**
 * Build a copy of `input` in which each regex metacharacter is preceded by a
 * backslash, so the result can be used as a literal pattern.
 *
 * The characters treated as special are: \ ^ $ . | ? * + ( ) [ ] { }
 * Every other character is copied through unchanged.
 *
 * @param input The text to escape.
 * @return The escaped text.
 */
std::string IRegex::escape(const std::string& input) {
  std::string escaped;
  // Worst case: every character needs a leading backslash.
  escaped.reserve(input.size() * 2);

  for (const char ch : input) {
    switch (ch) {
      case '\\':
      case '^':
      case '$':
      case '.':
      case '|':
      case '?':
      case '*':
      case '+':
      case '(':
      case ')':
      case '[':
      case ']':
      case '{':
      case '}':
        // Metacharacter: emit the escaping backslash first.
        escaped.push_back('\\');
        break;
      default:
        break;
    }
    escaped.push_back(ch);
  }

  return escaped;
}
52+
3653
Result<std::unique_ptr<IRegex>> create_regex(const std::string& pattern) {
3754
// Try RE2 first
3855
auto re2 = std::make_unique<Re2Regex>();

0 commit comments

Comments
 (0)