
Commit 9fa0a64

Support gemma3 HF tokenizer.json
Differential Revision: D77761574
Pull Request resolved: #96
1 parent cf543d0 · commit 9fa0a64

15 files changed: +821 -53 lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ set(tokenizers_source_files
   ${CMAKE_CURRENT_SOURCE_DIR}/src/bpe_tokenizer_base.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/src/hf_tokenizer.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/src/llama2c_tokenizer.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/src/normalizer.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/src/pre_tokenizer.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/src/re2_regex.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/src/regex.cpp

include/pytorch/tokenizers/bpe_tokenizer_base.h

Lines changed: 13 additions & 4 deletions
@@ -33,7 +33,7 @@ namespace detail {
 using TokenMap = StringIntegerMap<>;
 
 template <typename TToken, typename TRank>
-static Result<TokenMap> buildTokenMap(
+static Result<TokenMap> build_token_map(
     std::vector<std::pair<TToken, TRank>> container) {
   static_assert(
       std::is_same_v<TToken, std::string> ||
@@ -82,7 +82,7 @@ static Result<TokenMap> buildTokenMap(
 };
 
 template <typename TContainer, typename TTokenAccessor, typename TRankAccessor>
-static Result<TokenMap> buildTokenMap(
+static Result<TokenMap> build_token_map(
     const TContainer& container,
     TTokenAccessor token_accessor,
     TRankAccessor rank_accessor) {
@@ -103,7 +103,7 @@ static Result<TokenMap> buildTokenMap(
     pairs.emplace_back(token_accessor(value), rank_accessor(value));
   }
 
-  return buildTokenMap(std::move(pairs));
+  return build_token_map(std::move(pairs));
 }
 
 inline Result<std::unique_ptr<IRegex>> build_special_token_regex(
@@ -152,10 +152,19 @@ class BPETokenizerBase : public Tokenizer {
       const std::string& text,
       const TokenMap& allowed_special) const;
 
-  Result<std::vector<uint64_t>> byte_pair_encode_(
+  virtual Result<std::vector<uint64_t>> byte_pair_encode_(
       const std::string& piece,
       const TokenMap& encoder) const;
 
+  // Virtual method for BPE merging - can be overridden by derived classes.
+  // For the base impl, the `ranks` param is just the regular token map; the
+  // actual ranks are derived implicitly from it. This is the same
+  // implementation as Tiktoken.
+  virtual std::vector<uint64_t> _byte_pair_merge(
+      const std::string& piece,
+      const TokenMap& ranks,
+      std::function<uint64_t(uint64_t, uint64_t)> func) const;
+
   // Protected members that can be overloaded by other BPE tokenizers
   std::unique_ptr<IRegex> special_token_regex_;
   std::optional<TokenMap> token_map_;
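
For intuition, here is a minimal, self-contained sketch of the implicit-rank merge loop that the base-class comment describes: the token map itself supplies the ranks, so a token's id doubles as its merge priority. A plain std::unordered_map stands in for TokenMap and the signature is simplified (no func callback); this is an illustration, not the library's implementation.

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Toy stand-in for TokenMap: token bytes -> id. In the implicit scheme the
// id doubles as the merge rank, so lower ids merge first.
using ToyTokenMap = std::unordered_map<std::string, uint64_t>;

std::vector<std::string> byte_pair_merge(const std::string& piece,
                                         const ToyTokenMap& ranks) {
  // Start from single bytes and repeatedly fuse the adjacent pair whose
  // concatenation has the lowest rank in the map.
  std::vector<std::string> parts;
  for (char c : piece) parts.emplace_back(1, c);
  while (parts.size() > 1) {
    size_t best = parts.size();  // sentinel: no mergeable pair found
    uint64_t best_rank = UINT64_MAX;
    for (size_t i = 0; i + 1 < parts.size(); ++i) {
      auto it = ranks.find(parts[i] + parts[i + 1]);
      if (it != ranks.end() && it->second < best_rank) {
        best = i;
        best_rank = it->second;
      }
    }
    if (best == parts.size()) break;  // nothing left to merge
    parts[best] += parts[best + 1];
    parts.erase(parts.begin() + best + 1);
  }
  return parts;
}

int main() {
  ToyTokenMap ranks{{"he", 10}, {"ll", 11}, {"llo", 12}};
  for (const auto& p : byte_pair_merge("hello", ranks))
    std::cout << p << ' ';  // prints: he llo
}

HFTokenizer overrides this hook (see the next file) because an HF tokenizer.json carries an explicit, ordered merges list whose positions, not token ids, define the ranks.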

include/pytorch/tokenizers/hf_tokenizer.h

Lines changed: 161 additions & 0 deletions
@@ -18,11 +18,155 @@
 // Local
 #include <pytorch/tokenizers/bpe_tokenizer_base.h>
 #include <pytorch/tokenizers/error.h>
+#include <pytorch/tokenizers/normalizer.h>
 #include <pytorch/tokenizers/pre_tokenizer.h>
 #include <pytorch/tokenizers/result.h>
 #include <pytorch/tokenizers/token_decoder.h>
 
 namespace tokenizers {
+namespace detail {
+
+// Hash function for std::pair<uint64_t, uint64_t>
+struct PairHash {
+  std::size_t operator()(const std::pair<uint64_t, uint64_t>& p) const {
+    return std::hash<uint64_t>{}(p.first) ^
+        (std::hash<uint64_t>{}(p.second) << 1);
+  }
+};
+
+// Type alias for BPE merge map: (token_id_1, token_id_2) -> (rank,
+// merged_token_id)
+using MergeMap = std::unordered_map<
+    std::pair<uint64_t, uint64_t>,
+    std::pair<uint64_t, uint64_t>,
+    PairHash>;
+
+// Utility function to build merge ranks map from merge rules
+template <typename TMergeMap>
+inline Result<TokenMap> build_merge_ranks_map(
+    const TMergeMap& merge_map,
+    const TokenMap& token_map) {
+  // Static assertions to verify TMergeMap has the expected key and value types
+  using KeyType = typename TMergeMap::key_type;
+  using ValueType = typename TMergeMap::mapped_type;
+
+  static_assert(
+      std::is_same_v<KeyType, std::pair<uint64_t, uint64_t>>,
+      "TMergeMap key type must be std::pair<uint64_t, uint64_t>");
+
+  static_assert(
+      std::is_same_v<ValueType, std::pair<uint64_t, uint64_t>>,
+      "TMergeMap value type must be std::pair<uint64_t, uint64_t>");
+
+  // Use a map to handle duplicates - keep the lowest rank (highest priority)
+  std::unordered_map<std::string, uint64_t> unique_merge_ranks;
+
+  for (const auto& [pair, rank_and_id] : merge_map) {
+    uint64_t first_id = pair.first;
+    uint64_t second_id = pair.second;
+    uint64_t rank = rank_and_id.first;
+
+    // Get the token strings for the pair
+    auto first_token = token_map.tryGetString(first_id);
+    auto second_token = token_map.tryGetString(second_id);
+
+    if (first_token && second_token) {
+      std::string merged_token =
+          std::string(*first_token) + std::string(*second_token);
+
+      // Keep the entry with the lowest rank (highest priority in BPE)
+      auto it = unique_merge_ranks.find(merged_token);
+      if (it == unique_merge_ranks.end() || rank < it->second) {
+        unique_merge_ranks[merged_token] = rank;
+      }
+    }
+  }
+
+  // Convert to vector for build_token_map
+  std::vector<std::pair<std::string, uint64_t>> merge_rank_pairs;
+  merge_rank_pairs.reserve(unique_merge_ranks.size());
+
+  for (const auto& [token, rank] : unique_merge_ranks) {
+    merge_rank_pairs.emplace_back(token, rank);
+  }
+
+  return build_token_map(std::move(merge_rank_pairs));
+}
+
+} // namespace detail
+
+// Simple Word structure to mimic Rust's Word behavior
+struct HFWord {
+  std::vector<uint64_t> tokens;
+  std::vector<size_t> byte_lengths;
+
+  void add(uint64_t token_id, size_t byte_len) {
+    tokens.push_back(token_id);
+    byte_lengths.push_back(byte_len);
+  }
+
+  size_t size() const {
+    return tokens.size();
+  }
+
+  // Apply all possible merges using the merge ranks
+  void merge_all(
+      const detail::TokenMap& merge_ranks,
+      const detail::TokenMap& token_map) {
+    while (tokens.size() > 1) {
+      std::optional<std::pair<size_t, uint32_t>> best_merge;
+
+      // Find the best merge (lowest rank) among adjacent token pairs
+      for (size_t i = 0; i < tokens.size() - 1; ++i) {
+        // Create the merged token string to look up its rank
+        auto first_token = token_map.tryGetString(tokens[i]);
+        auto second_token = token_map.tryGetString(tokens[i + 1]);
+
+        if (first_token && second_token) {
+          std::string merged_token =
+              std::string(*first_token) + std::string(*second_token);
+          auto rank = merge_ranks.tryGetInteger(merged_token);
+
+          if (rank && (!best_merge || *rank < best_merge->second)) {
+            best_merge = std::make_pair(i, static_cast<uint32_t>(*rank));
+          }
+        }
+      }
+
+      if (!best_merge) {
+        break; // No more merges possible
+      }
+
+      // Apply the best merge
+      size_t merge_idx = best_merge->first;
+
+      // Get the merged token ID
+      auto first_token = token_map.tryGetString(tokens[merge_idx]);
+      auto second_token = token_map.tryGetString(tokens[merge_idx + 1]);
+
+      if (first_token && second_token) {
+        std::string merged_token =
+            std::string(*first_token) + std::string(*second_token);
+        auto merged_id = token_map.tryGetInteger(merged_token);
+
+        if (merged_id) {
+          // Replace the two tokens with the merged token
+          tokens[merge_idx] = *merged_id;
+          byte_lengths[merge_idx] += byte_lengths[merge_idx + 1];
+
+          // Remove the second token
+          tokens.erase(tokens.begin() + merge_idx + 1);
+          byte_lengths.erase(byte_lengths.begin() + merge_idx + 1);
+        } else {
+          break; // Merged token not found in vocabulary
+        }
+      } else {
+        break; // Original tokens not found in vocabulary
+      }
+    }
+  }
+};
+
 class HFTokenizer : public detail::BPETokenizerBase {
  public:
   /*-- Public Interface --*/
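
A note on the hunk above: detail::MergeMap keys on (id, id) pairs, which is why PairHash exists at all — std::unordered_map has no default hash for std::pair. To make the rank derivation in build_merge_ranks_map concrete, here is a hedged sketch of the same idea over toy types: a rule's position in the ordered merges list is its rank, and duplicate merged strings keep the lowest rank. The merges themselves are hypothetical, and a plain map stands in for TokenMap/MergeMap.

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Toy rank table: merged token string -> rank (lower = higher priority).
using ToyRanks = std::unordered_map<std::string, uint64_t>;

ToyRanks build_merge_ranks(
    const std::vector<std::pair<std::string, std::string>>& merges) {
  ToyRanks ranks;
  for (uint64_t rank = 0; rank < merges.size(); ++rank) {
    std::string merged = merges[rank].first + merges[rank].second;
    // Duplicates keep the lowest rank, mirroring the header's dedup logic.
    auto it = ranks.find(merged);
    if (it == ranks.end() || rank < it->second) {
      ranks[merged] = rank;
    }
  }
  return ranks;
}

int main() {
  // Hypothetical rules, e.g. parsed from a tokenizer.json "merges" array.
  ToyRanks ranks = build_merge_ranks({{"h", "e"}, {"l", "l"}, {"ll", "o"}});
  for (const auto& [token, rank] : ranks) {
    std::cout << token << " -> " << rank << '\n';
  }
}

The resulting table is what drives HFWord::merge_all above: the adjacent pair whose concatenation has the lowest rank merges first, with byte_lengths kept in sync so byte offsets can still be recovered afterwards.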
@@ -46,8 +190,25 @@ class HFTokenizer : public detail::BPETokenizerBase {
 
   void _decode(const std::string& input, std::string& ret) const override;
 
+  Result<std::vector<uint64_t>> byte_pair_encode_(
+      const std::string& piece,
+      const detail::TokenMap& encoder) const override;
+
+  // Override the virtual _byte_pair_merge method to use the explicit merges
+  // specified in tokenizer.json. This differs from Tiktoken (another user of
+  // BPETokenizerBase), which doesn't use explicit merge rules.
+  std::vector<uint64_t> _byte_pair_merge(
+      const std::string& piece,
+      const detail::TokenMap& ranks,
+      std::function<uint64_t(uint64_t, uint64_t)> func) const override;
+
+  Normalizer::Ptr _normalizer;
   PreTokenizer::Ptr _pretokenizer;
   TokenDecoder::Ptr _decoder;
+
+  std::unique_ptr<detail::MergeMap> merge_map_;
+  std::optional<detail::TokenMap>
+      merge_ranks_; // Pre-computed merge ranks for BPE
 };
 
 } // namespace tokenizers
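
A minimal sketch of the override mechanics this header relies on, using toy classes in place of BPETokenizerBase and HFTokenizer (ToyBPEBase, ToyHFTokenizer, and the stripped-down signature are all illustrative): the base class owns the encode path, and virtual dispatch routes the merge step to whichever implementation the concrete tokenizer provides.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for BPETokenizerBase: the encode path is fixed, the merge
// step is a virtual hook.
class ToyBPEBase {
 public:
  virtual ~ToyBPEBase() = default;

  std::vector<uint64_t> encode(const std::string& piece) const {
    return byte_pair_merge(piece);  // dispatches to the most-derived impl
  }

 protected:
  // Default: ranks come implicitly from the token map (Tiktoken-style).
  virtual std::vector<uint64_t> byte_pair_merge(
      const std::string& piece) const {
    std::cout << "implicit ranks for \"" << piece << "\"\n";
    return {};
  }
};

// Toy stand-in for HFTokenizer: swaps in the explicit merge ranks.
class ToyHFTokenizer : public ToyBPEBase {
 protected:
  std::vector<uint64_t> byte_pair_merge(
      const std::string& piece) const override {
    std::cout << "explicit merge ranks for \"" << piece << "\"\n";
    return {};
  }
};

int main() {
  ToyHFTokenizer hf;
  hf.encode("hello");  // prints: explicit merge ranks for "hello"
}

Storing merge_ranks_ alongside merge_map_, as the header does, makes the rank table a one-time cost at load rather than something re-derived on every encode call.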

include/pytorch/tokenizers/llama2c_tokenizer.h

Lines changed: 9 additions & 0 deletions
@@ -28,6 +28,15 @@ class Llama2cTokenizer : public Tokenizer {
       const override;
 
  private:
+  inline Error _decode_verify(uint64_t token) const {
+    if (!initialized_) {
+      return Error::Uninitialized;
+    }
+    if (token >= vocab_size_) {
+      return Error::OutOfRange;
+    }
+    return Error::Ok;
+  }
   std::unique_ptr<char*[]> vocab_ = nullptr;
   std::unique_ptr<float[]> vocab_scores_ = nullptr;
   std::unique_ptr<TokenIndex[]> sorted_vocab_ = nullptr;
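
For context, a small sketch of the guard pattern _decode_verify encodes, with simplified stand-ins for the library's Error type and tokenizer state (ToyTokenizer and its members are illustrative): validate initialization and bounds once up front, so the decode body can index the vocab unchecked.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for the library's Error enum.
enum class Error { Ok, Uninitialized, OutOfRange };

struct ToyTokenizer {
  bool initialized_ = false;
  uint64_t vocab_size_ = 0;
  std::vector<std::string> vocab_;

  // Same shape as _decode_verify: check state and bounds before decoding.
  Error decode_verify(uint64_t token) const {
    if (!initialized_) return Error::Uninitialized;
    if (token >= vocab_size_) return Error::OutOfRange;
    return Error::Ok;
  }

  Error decode(uint64_t token, std::string& out) const {
    if (auto err = decode_verify(token); err != Error::Ok) return err;
    out = vocab_[token];  // safe: guard already bounds-checked
    return Error::Ok;
  }
};

int main() {
  ToyTokenizer t{true, 2, {"<s>", "hello"}};
  std::string s;
  if (t.decode(1, s) == Error::Ok) std::cout << s << '\n';       // hello
  std::cout << (t.decode(5, s) == Error::OutOfRange) << '\n';    // 1 (true)
}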
