18
18
// Local
19
19
#include < pytorch/tokenizers/bpe_tokenizer_base.h>
20
20
#include < pytorch/tokenizers/error.h>
21
+ #include < pytorch/tokenizers/normalizer.h>
21
22
#include < pytorch/tokenizers/pre_tokenizer.h>
22
23
#include < pytorch/tokenizers/result.h>
23
24
#include < pytorch/tokenizers/token_decoder.h>
24
25
25
26
namespace tokenizers {
27
+ namespace detail {
28
+
29
// Hash functor for std::pair<uint64_t, uint64_t> keys.
//
// The two element hashes are mixed with a golden-ratio based combiner instead
// of a plain `h1 ^ (h2 << 1)`. When std::hash<uint64_t> behaves like the
// identity function, the plain XOR form maps distinct pairs of small values
// (for example (0, 1) and (2, 0)) to the same hash and confines the result to
// roughly twice the largest element, which causes avoidable collisions in
// hash containers keyed by such pairs.
struct PairHash {
  // Returns a hash of `p` that depends on both elements and on their order.
  std::size_t operator()(const std::pair<uint64_t, uint64_t>& p) const {
    // 2^64 / golden ratio; truncated on platforms with a narrower size_t.
    constexpr std::size_t kGoldenRatio =
        static_cast<std::size_t>(0x9e3779b97f4a7c15ULL);

    std::size_t seed = std::hash<uint64_t>{}(p.first);
    seed ^= std::hash<uint64_t>{}(p.second) + kGoldenRatio + (seed << 6) +
        (seed >> 2);
    return seed;
  }
};
36
+
37
+ // Type alias for BPE merge map: (token_id_1, token_id_2) -> (rank,
38
+ // merged_token_id)
39
+ using MergeMap = std::unordered_map<
40
+ std::pair<uint64_t , uint64_t >,
41
+ std::pair<uint64_t , uint64_t >,
42
+ PairHash>;
43
+
44
+ // Utility function to build merge ranks map from merge rules
45
+ template <typename TMergeMap>
46
+ inline Result<TokenMap> build_merge_ranks_map (
47
+ const TMergeMap& merge_map,
48
+ const TokenMap& token_map) {
49
+ // Static assertions to verify TMergeMap has the expected key and value types
50
+ using KeyType = typename TMergeMap::key_type;
51
+ using ValueType = typename TMergeMap::mapped_type;
52
+
53
+ static_assert (
54
+ std::is_same_v<KeyType, std::pair<uint64_t , uint64_t >>,
55
+ " TMergeMap key type must be std::pair<uint64_t, uint64_t>" );
56
+
57
+ static_assert (
58
+ std::is_same_v<ValueType, std::pair<uint64_t , uint64_t >>,
59
+ " TMergeMap value type must be std::pair<uint64_t, uint64_t>" );
60
+
61
+ // Use a map to handle duplicates - keep the lowest rank (highest priority)
62
+ std::unordered_map<std::string, uint64_t > unique_merge_ranks;
63
+
64
+ for (const auto & [pair, rank_and_id] : merge_map) {
65
+ uint64_t first_id = pair.first ;
66
+ uint64_t second_id = pair.second ;
67
+ uint64_t rank = rank_and_id.first ;
68
+
69
+ // Get the token strings for the pair
70
+ auto first_token = token_map.tryGetString (first_id);
71
+ auto second_token = token_map.tryGetString (second_id);
72
+
73
+ if (first_token && second_token) {
74
+ std::string merged_token =
75
+ std::string (*first_token) + std::string (*second_token);
76
+
77
+ // Keep the entry with the lowest rank (highest priority in BPE)
78
+ auto it = unique_merge_ranks.find (merged_token);
79
+ if (it == unique_merge_ranks.end () || rank < it->second ) {
80
+ unique_merge_ranks[merged_token] = rank;
81
+ }
82
+ }
83
+ }
84
+
85
+ // Convert to vector for buildTokenMap
86
+ std::vector<std::pair<std::string, uint64_t >> merge_rank_pairs;
87
+ merge_rank_pairs.reserve (unique_merge_ranks.size ());
88
+
89
+ for (const auto & [token, rank] : unique_merge_ranks) {
90
+ merge_rank_pairs.emplace_back (token, rank);
91
+ }
92
+
93
+ return build_token_map (std::move (merge_rank_pairs));
94
+ }
95
+
96
+ } // namespace detail
97
+
98
+ // Simple Word structure to mimic Rust's Word behavior
99
+ struct HFWord {
100
+ std::vector<uint64_t > tokens;
101
+ std::vector<size_t > byte_lengths;
102
+
103
+ void add (uint64_t token_id, size_t byte_len) {
104
+ tokens.push_back (token_id);
105
+ byte_lengths.push_back (byte_len);
106
+ }
107
+
108
+ size_t size () const {
109
+ return tokens.size ();
110
+ }
111
+
112
+ // Apply all possible merges using the merge ranks
113
+ void merge_all (
114
+ const detail::TokenMap& merge_ranks,
115
+ const detail::TokenMap& token_map) {
116
+ while (tokens.size () > 1 ) {
117
+ std::optional<std::pair<size_t , uint32_t >> best_merge;
118
+
119
+ // Find the best merge (lowest rank) among adjacent token pairs
120
+ for (size_t i = 0 ; i < tokens.size () - 1 ; ++i) {
121
+ // Create the merged token string to look up its rank
122
+ auto first_token = token_map.tryGetString (tokens[i]);
123
+ auto second_token = token_map.tryGetString (tokens[i + 1 ]);
124
+
125
+ if (first_token && second_token) {
126
+ std::string merged_token =
127
+ std::string (*first_token) + std::string (*second_token);
128
+ auto rank = merge_ranks.tryGetInteger (merged_token);
129
+
130
+ if (rank && (!best_merge || *rank < best_merge->second )) {
131
+ best_merge = std::make_pair (i, static_cast <uint32_t >(*rank));
132
+ }
133
+ }
134
+ }
135
+
136
+ if (!best_merge) {
137
+ break ; // No more merges possible
138
+ }
139
+
140
+ // Apply the best merge
141
+ size_t merge_idx = best_merge->first ;
142
+
143
+ // Get the merged token ID
144
+ auto first_token = token_map.tryGetString (tokens[merge_idx]);
145
+ auto second_token = token_map.tryGetString (tokens[merge_idx + 1 ]);
146
+
147
+ if (first_token && second_token) {
148
+ std::string merged_token =
149
+ std::string (*first_token) + std::string (*second_token);
150
+ auto merged_id = token_map.tryGetInteger (merged_token);
151
+
152
+ if (merged_id) {
153
+ // Replace the two tokens with the merged token
154
+ tokens[merge_idx] = *merged_id;
155
+ byte_lengths[merge_idx] += byte_lengths[merge_idx + 1 ];
156
+
157
+ // Remove the second token
158
+ tokens.erase (tokens.begin () + merge_idx + 1 );
159
+ byte_lengths.erase (byte_lengths.begin () + merge_idx + 1 );
160
+ } else {
161
+ break ; // Merged token not found in vocabulary
162
+ }
163
+ } else {
164
+ break ; // Original tokens not found in vocabulary
165
+ }
166
+ }
167
+ }
168
+ };
169
+
26
170
class HFTokenizer : public detail ::BPETokenizerBase {
27
171
public:
28
172
/* -- Public Interface --*/
@@ -46,8 +190,25 @@ class HFTokenizer : public detail::BPETokenizerBase {
46
190
47
191
void _decode (const std::string& input, std::string& ret) const override ;
48
192
193
+ Result<std::vector<uint64_t >> byte_pair_encode_ (
194
+ const std::string& piece,
195
+ const detail::TokenMap& encoder) const override ;
196
+
197
// Override the virtual _byte_pair_merge method to use the explicit merges
// specified in tokenizer.json. This differs from Tiktoken, another user of
// BPETokenizerBase, which does not use explicit merge rules.
200
+ std::vector<uint64_t > _byte_pair_merge (
201
+ const std::string& piece,
202
+ const detail::TokenMap& ranks,
203
+ std::function<uint64_t (uint64_t , uint64_t )> func) const override ;
204
+
205
+ Normalizer::Ptr _normalizer;
49
206
PreTokenizer::Ptr _pretokenizer;
50
207
TokenDecoder::Ptr _decoder;
208
+
209
+ std::unique_ptr<detail::MergeMap> merge_map_;
210
+ std::optional<detail::TokenMap>
211
+ merge_ranks_; // Pre-computed merge ranks for BPE
51
212
};
52
213
53
214
} // namespace tokenizers
0 commit comments