@@ -1651,12 +1651,13 @@ struct llama_mlock {
 };
 using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
 
-static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+// NOTE: avoid ever using this except for building the token_to_piece caches
+static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
+    const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
+        int check = llama_token_to_piece(model, token, result.data(), result.size(), special);
         GGML_ASSERT(check == -n_tokens);
     }
     else {
@@ -2086,7 +2087,11 @@ struct llama_vocab {
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data>       id_to_token;
 
-    std::unordered_map<token, id> special_tokens_cache;
+    bool has_cache = false;
+
+    std::unordered_map<token, id> cache_special_tokens;
+    std::unordered_map<id, token> cache_token_to_piece;         // llama_token_to_piece(special = false);
+    std::unordered_map<id, token> cache_token_to_piece_special; // llama_token_to_piece(special = true);
 
     std::map<std::pair<std::string, std::string>, int> bpe_ranks;
 
@@ -4789,7 +4794,7 @@ static void llm_load_vocab(
             // And skip the ones which are one character
             if (utf8_str_len > 1) {
                 // At this point what we have left are special tokens only
-                vocab.special_tokens_cache[token] = id;
+                vocab.cache_special_tokens[token] = id;
 
                 // Count manually found special tokens
                 special_tokens_count_from_verification++;
@@ -4816,6 +4821,13 @@ static void llm_load_vocab(
             );
         }
     }
+
+    for (llama_token id = 0; id < (llama_token) n_vocab; ++id) {
+        vocab.cache_token_to_piece[id]         = llama_token_to_piece(&model, id, false);
+        vocab.cache_token_to_piece_special[id] = llama_token_to_piece(&model, id, true);
+    }
+
+    vocab.has_cache = true;
 }
 
 static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
@@ -12898,7 +12910,7 @@ struct fragment_buffer_variant {
 
 static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
     // for each special token
-    for (const auto & st: vocab.special_tokens_cache) {
+    for (const auto & st: vocab.cache_special_tokens) {
         const auto & special_token = st.first;
         const auto & special_id    = st.second;
 
@@ -14058,7 +14070,7 @@ void llama_sample_repetition_penalties(
 
 void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
     GGML_ASSERT(ctx);
-    const int64_t t_start_sample_us = ggml_time_us();
+    int64_t t_start_sample_us = ggml_time_us();
 
     bool allow_eog = false;
     for (const auto & stack : grammar->stacks) {
@@ -14074,8 +14086,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
     candidates_grammar.reserve(candidates->size);
 
     for (size_t i = 0; i < candidates->size; ++i) {
-        const llama_token id    = candidates->data[i].id;
-        const std::string piece = llama_token_to_piece(ctx, id, false);
+        const llama_token id      = candidates->data[i].id;
+        const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id);
 
         if (llama_token_is_eog(&ctx->model, id)) {
             if (!allow_eog) {
@@ -14275,7 +14287,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
         GGML_ASSERT(false);
     }
 
-    const std::string piece = llama_token_to_piece(ctx, token, false);
+    const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token);
 
     // Note terminating 0 in decoded string
     const auto decoded = decode_utf8(piece, grammar->partial_utf8);
@@ -17957,69 +17969,79 @@ static std::string llama_decode_text(const std::string & text) {
 
 // does not write null-terminator to buf
 int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
+    if (model->vocab.has_cache) {
+        const auto & cache = special ? model->vocab.cache_token_to_piece_special : model->vocab.cache_token_to_piece;
+        const auto & res = cache.at(token);
+        if (length < (int) res.size()) {
+            return -(int) res.size();
+        }
+        memcpy(buf, res.c_str(), res.size());
+        return res.size();
+    }
+
     if (0 <= token && token < llama_n_vocab(model)) {
         switch (llama_vocab_get_type(model->vocab)) {
-        case LLAMA_VOCAB_TYPE_WPM:
-        case LLAMA_VOCAB_TYPE_SPM: {
-            // NOTE: we accept all unsupported token types,
-            // suppressing them like CONTROL tokens.
-            if (llama_is_normal_token(model->vocab, token)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                llama_unescape_whitespace(result);
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (
-                (llama_is_user_defined_token(model->vocab, token)) ||
-                (llama_is_control_token     (model->vocab, token) && special)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
-                if (length < 3) {
-                    return -3;
-                }
-                memcpy(buf, "\xe2\x96\x85", 3);
-                return 3;
-            } else if (llama_is_byte_token(model->vocab, token)) {
-                if (length < 1) {
-                    return -1;
+            case LLAMA_VOCAB_TYPE_WPM:
+            case LLAMA_VOCAB_TYPE_SPM: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (llama_is_normal_token(model->vocab, token)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    llama_unescape_whitespace(result);
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (
+                    (llama_is_user_defined_token(model->vocab, token)) ||
+                    (llama_is_control_token     (model->vocab, token) && special)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
+                    if (length < 3) {
+                        return -3;
+                    }
+                    memcpy(buf, "\xe2\x96\x85", 3);
+                    return 3;
+                } else if (llama_is_byte_token(model->vocab, token)) {
+                    if (length < 1) {
+                        return -1;
+                    }
+                    buf[0] = llama_token_to_byte(model->vocab, token);
+                    return 1;
                 }
-                buf[0] = llama_token_to_byte(model->vocab, token);
-                return 1;
+                break;
             }
-            break;
-        }
-        case LLAMA_VOCAB_TYPE_BPE: {
-            // NOTE: we accept all unsupported token types,
-            // suppressing them like CONTROL tokens.
-            if (llama_is_normal_token(model->vocab, token)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                result = llama_decode_text(result);
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (
-                (llama_is_user_defined_token(model->vocab, token)) ||
-                (llama_is_control_token     (model->vocab, token) && special)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
+            case LLAMA_VOCAB_TYPE_BPE: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (llama_is_normal_token(model->vocab, token)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    result = llama_decode_text(result);
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (
+                    (llama_is_user_defined_token(model->vocab, token)) ||
+                    (llama_is_control_token     (model->vocab, token) && special)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
                 }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
+                break;
             }
-            break;
-        }
-        default:
-            GGML_ASSERT(false);
+            default:
+                GGML_ASSERT(false);
         }
     }
     return 0;
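// ---------------------------------------------------------------------------
// Caller-side sketch (not part of this commit): the public llama_token_to_piece
// API keeps its buffer-retry convention after this change, where a negative
// return value is the required size, negated. Once vocab.has_cache is set, each
// call reduces to a lookup in cache_token_to_piece[_special] plus one memcpy.
// Assumes "llama.h" and an already loaded llama_model * model; the helper name
// below is illustrative only.
#include <string>
#include <vector>
#include "llama.h"

static std::string token_to_piece_example(const struct llama_model * model, llama_token token, bool special) {
    std::vector<char> buf(8, 0);
    int32_t n = llama_token_to_piece(model, token, buf.data(), (int32_t) buf.size(), special);
    if (n < 0) {
        // buffer too small: -n is the exact piece length reported by the cache
        buf.resize(-n);
        n = llama_token_to_piece(model, token, buf.data(), (int32_t) buf.size(), special);
    }
    return std::string(buf.data(), n);
}
// ---------------------------------------------------------------------------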