Commit 6a6e24f

TikToken uses common bpe base functions (#45)

1 parent 295ee78 commit 6a6e24f

3 files changed: +13 -87 lines

include/pytorch/tokenizers/tiktoken.h

Lines changed: 0 additions & 11 deletions

@@ -76,24 +76,13 @@ class Tiktoken : public detail::BPETokenizerBase {
     return special_tokens;
   }
 
-  template <typename T>
-  std::pair<std::optional<std::string>, re2::StringPiece>
-  _split_with_allowed_special_token(
-      re2::StringPiece& input,
-      const T& allowed_special) const;
-
   Error _encode(
       re2::StringPiece& input,
       std::vector<uint64_t>& ret,
       uint64_t& last_piece_token_len) const override;
 
   void _decode(re2::StringPiece input, std::string& ret) const override;
 
-  template <typename T>
-  Result<std::pair<std::vector<uint64_t>, uint64_t>> _encode_with_special_token(
-      const std::string& text,
-      const T& allowed_special) const;
-
   detail::TokenMap _build_special_token_map(ssize_t num_base_tokens) const;
 
   std::unique_ptr<std::vector<std::string>> _special_tokens;
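
With this change, Tiktoken declares only the _encode/_decode hooks and inherits the special-token splitting and encoding loop from detail::BPETokenizerBase. Below is a minimal, self-contained sketch of that shape; the class and method names echo the diff, but the simplified signatures and the toy byte-per-token encoding are illustrative assumptions, not the library's actual declarations.

#include <cstdint>
#include <iostream>
#include <string_view>
#include <vector>

// Sketch of the base class: it owns the shared driver logic and calls into a
// virtual hook supplied by each concrete tokenizer.
class BPETokenizerBaseSketch {
 public:
  virtual ~BPETokenizerBaseSketch() = default;

  std::vector<uint64_t> encode(std::string_view text) const {
    std::vector<uint64_t> tokens;
    // The real base class first splits `text` around allowed special tokens;
    // this sketch simply forwards the whole input to the hook.
    encodeHook(text, tokens);
    return tokens;
  }

 protected:
  // Hook implemented by each concrete tokenizer (e.g. Tiktoken).
  virtual void encodeHook(std::string_view piece, std::vector<uint64_t>& out)
      const = 0;
};

// Sketch of the derived class: only the hook is overridden; the special-token
// machinery lives entirely in the base class.
class TiktokenSketch : public BPETokenizerBaseSketch {
 protected:
  void encodeHook(std::string_view piece, std::vector<uint64_t>& out)
      const override {
    for (unsigned char c : piece) {
      out.push_back(c);  // toy stand-in for real BPE: one token per byte
    }
  }
};

int main() {
  TiktokenSketch tok;
  std::cout << tok.encode("hi").size() << " tokens\n";  // prints: 2 tokens
  return 0;
}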

src/bpe_tokenizer_base.cpp

Lines changed: 13 additions & 1 deletion
@@ -138,7 +138,12 @@ BPETokenizerBase::split_with_allowed_special_token_(
     return std::make_pair(std::nullopt, input);
   }
 
+#if __cplusplus >= 202002L
   auto start = input.begin();
+#else
+  const char* start = input.data();
+#endif
+
   std::string special;
   while (true) {
     if (!re2::RE2::FindAndConsume(&input, *special_token_regex_, &special)) {
@@ -148,9 +153,15 @@ BPETokenizerBase::split_with_allowed_special_token_(
 
     if (allowed_special.tryGetInteger(special).has_value()) {
       // Found an allowed special token, split the text with it.
+#if __cplusplus >= 202002L
       return std::make_pair(
           special,
           re2::StringPiece(start, input.begin() - start - special.size()));
+#else
+      return std::make_pair(
+          special,
+          re2::StringPiece(start, (input.data() - start) - special.size()));
+#endif
     } // else try to find the next special token
   }

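Both branches of the conditional in the hunks above compute the same quantity: the length of the text that precedes the matched special token, i.e. how far the regex engine has consumed from the original start, minus the length of the token itself. Below is a minimal sketch of that arithmetic, using std::string_view as a stand-in for re2::StringPiece and a hypothetical find_special() in place of RE2::FindAndConsume (both stand-ins are assumptions made to keep the example self-contained):

#include <cassert>
#include <string>
#include <string_view>

// Hypothetical stand-in for RE2::FindAndConsume: locate `token` in `input`,
// copy it to `out`, and advance `input` past the end of the match.
static bool find_special(
    std::string_view& input, std::string_view token, std::string& out) {
  const auto pos = input.find(token);
  if (pos == std::string_view::npos) {
    return false;
  }
  out = std::string(token);
  input.remove_prefix(pos + token.size());
  return true;
}

int main() {
  std::string_view input = "hello <|eot|> world";
  const char* start = input.data();  // pre-C++20 branch: raw pointer

  std::string special;
  if (find_special(input, "<|eot|>", special)) {
    // Bytes consumed so far minus the token length gives the prefix length,
    // exactly as in the StringPiece construction above.
    std::string_view prefix(start, (input.data() - start) - special.size());
    assert(prefix == "hello ");
  }
  return 0;
}
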
@@ -168,7 +179,8 @@ BPETokenizerBase::encode_with_special_token_(
     auto [special, sub_input] =
         split_with_allowed_special_token_(input, allowed_special);
 
-    _encode(sub_input, tokens, last_piece_token_len);
+    TK_CHECK_OK_OR_RETURN_ERROR(
+        _encode(sub_input, tokens, last_piece_token_len));
 
     if (special) {
       const auto result = special_token_map_->tryGetInteger(*special);
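
The functional change in this hunk is that the status returned by _encode is no longer discarded: a failed encode now propagates to the caller of encode_with_special_token_. The sketch below shows the general early-return shape such check macros implement; it is a generic illustration, not the repository's actual TK_CHECK_OK_OR_RETURN_ERROR definition.

#include <cstdio>

enum class Error { Ok, EncodeFailure };  // illustrative error type

// Illustrative early-return guard: evaluate an Error-returning expression and
// propagate any non-Ok result from the enclosing function.
#define CHECK_OK_OR_RETURN_ERROR(expr)  \
  do {                                  \
    const Error err_ = (expr);          \
    if (err_ != Error::Ok) {            \
      return err_;                      \
    }                                   \
  } while (0)

static Error encode_piece(bool fail) {
  return fail ? Error::EncodeFailure : Error::Ok;
}

static Error encode_all() {
  CHECK_OK_OR_RETURN_ERROR(encode_piece(false));  // Ok: execution continues
  CHECK_OK_OR_RETURN_ERROR(encode_piece(true));   // failure: returns early
  return Error::Ok;
}

int main() {
  std::printf("%s\n",
              encode_all() == Error::Ok ? "ok" : "error propagated");
  return 0;
}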

src/tiktoken.cpp

Lines changed: 0 additions & 75 deletions
@@ -113,44 +113,6 @@ static Result<TokenMap> _load_token_map(const std::string& path) {
 // ------------------------------Util end------------------------------------
 // -------------------------private method start-------------------------------
 
-template <typename T>
-std::pair<std::optional<std::string>, re2::StringPiece>
-Tiktoken::_split_with_allowed_special_token(
-    re2::StringPiece& input,
-    const T& allowed_special) const {
-  if (!special_token_regex_) {
-    return std::make_pair(std::nullopt, input);
-  }
-
-#if __cplusplus >= 202002L
-  auto start = input.begin();
-#else
-  const char* start = input.data();
-#endif
-  std::string special;
-  while (true) {
-    if (!re2::RE2::FindAndConsume(&input, *special_token_regex_, &special)) {
-      // No special token.
-      break;
-    }
-
-    if (allowed_special.tryGetInteger(special)) {
-      // Found an allowed special token, split the text with it.
-#if __cplusplus >= 202002L
-      return std::make_pair(
-          special,
-          re2::StringPiece(start, input.begin() - start - special.size()));
-#else
-      return std::make_pair(
-          special,
-          re2::StringPiece(start, (input.data() - start) - special.size()));
-#endif
-    } // else try to find the next special token
-  }
-
-  return std::make_pair(std::nullopt, input);
-}
-
 Error Tiktoken::_encode(
     re2::StringPiece& input,
     std::vector<uint64_t>& ret,
@@ -179,43 +141,6 @@ void Tiktoken::_decode(re2::StringPiece input, std::string& ret) const {
 #endif
 }
 
-template <typename T>
-Result<std::pair<std::vector<uint64_t>, uint64_t>>
-Tiktoken::_encode_with_special_token(
-    const std::string& text,
-    const T& allowed_special) const {
-  std::vector<uint64_t> tokens;
-  uint64_t last_piece_token_len = 0;
-  re2::StringPiece input(text);
-  while (true) {
-    auto [special, sub_input] =
-        _split_with_allowed_special_token(input, allowed_special);
-
-    TK_CHECK_OK_OR_RETURN_ERROR(
-        _encode(sub_input, tokens, last_piece_token_len));
-
-    if (special) {
-      const auto result = special_token_map_->tryGetInteger(*special);
-      if (!result) {
-        // Should never go here, since special pattern includes all special
-        // chars.
-        TK_LOG(Error, "unknown special token: %s", special->c_str());
-        return Error::EncodeFailure;
-      }
-
-      tokens.push_back(*result);
-      last_piece_token_len = 0;
-    } else {
-      break;
-    }
-  }
-
-  // last_piece_token_len is how many tokens came from the last regex split.
-  // This is used for determining unstable tokens, since you can't merge
-  // across (stable) regex splits
-  return std::make_pair(tokens, last_piece_token_len);
-}
-
 // -------------------------private method end-------------------------------
 // -------------------------public method start-------------------------------
