Skip to content

Commit 3e5d281

Browse files
committed
llama : cache llama_token_to_piece
ggml-ci
1 parent 8b99e2a commit 3e5d281

File tree

2 files changed

+91
-69
lines changed

2 files changed

+91
-69
lines changed

llama.cpp

Lines changed: 89 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1651,12 +1651,13 @@ struct llama_mlock {
16511651
};
16521652
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
16531653

1654-
static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
1654+
// NOTE: avoid ever using this except for building the token_to_piece caches
1655+
static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
16551656
std::vector<char> result(8, 0);
1656-
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
1657+
const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special);
16571658
if (n_tokens < 0) {
16581659
result.resize(-n_tokens);
1659-
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
1660+
int check = llama_token_to_piece(model, token, result.data(), result.size(), special);
16601661
GGML_ASSERT(check == -n_tokens);
16611662
}
16621663
else {
@@ -2086,7 +2087,11 @@ struct llama_vocab {
20862087
std::unordered_map<token, id> token_to_id;
20872088
std::vector<token_data> id_to_token;
20882089

2089-
std::unordered_map<token, id> special_tokens_cache;
2090+
bool has_cache = false;
2091+
2092+
std::unordered_map<token, id> cache_special_tokens;
2093+
std::unordered_map<id, token> cache_token_to_piece; // llama_token_to_piece(special = false);
2094+
std::unordered_map<id, token> cache_token_to_piece_special; // llama_token_to_piece(special = true);
20902095

20912096
std::map<std::pair<std::string, std::string>, int> bpe_ranks;
20922097

@@ -4789,7 +4794,7 @@ static void llm_load_vocab(
47894794
// And skip the ones which are one character
47904795
if (utf8_str_len > 1) {
47914796
// At this point what we have left are special tokens only
4792-
vocab.special_tokens_cache[token] = id;
4797+
vocab.cache_special_tokens[token] = id;
47934798

47944799
// Count manually found special tokens
47954800
special_tokens_count_from_verification++;
@@ -4816,6 +4821,13 @@ static void llm_load_vocab(
48164821
);
48174822
}
48184823
}
4824+
4825+
for (llama_token id = 0; id < (llama_token) n_vocab; ++id) {
4826+
vocab.cache_token_to_piece[id] = llama_token_to_piece(&model, id, false);
4827+
vocab.cache_token_to_piece_special[id] = llama_token_to_piece(&model, id, true);
4828+
}
4829+
4830+
vocab.has_cache = true;
48194831
}
48204832

48214833
static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
@@ -12898,7 +12910,7 @@ struct fragment_buffer_variant {
1289812910

1289912911
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
1290012912
// for each special token
12901-
for (const auto & st: vocab.special_tokens_cache) {
12913+
for (const auto & st: vocab.cache_special_tokens) {
1290212914
const auto & special_token = st.first;
1290312915
const auto & special_id = st.second;
1290412916

@@ -14058,7 +14070,7 @@ void llama_sample_repetition_penalties(
1405814070

1405914071
void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
1406014072
GGML_ASSERT(ctx);
14061-
const int64_t t_start_sample_us = ggml_time_us();
14073+
int64_t t_start_sample_us = ggml_time_us();
1406214074

1406314075
bool allow_eog = false;
1406414076
for (const auto & stack : grammar->stacks) {
@@ -14074,8 +14086,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
1407414086
candidates_grammar.reserve(candidates->size);
1407514087

1407614088
for (size_t i = 0; i < candidates->size; ++i) {
14077-
const llama_token id = candidates->data[i].id;
14078-
const std::string piece = llama_token_to_piece(ctx, id, false);
14089+
const llama_token id = candidates->data[i].id;
14090+
const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id);
1407914091

1408014092
if (llama_token_is_eog(&ctx->model, id)) {
1408114093
if (!allow_eog) {
@@ -14275,7 +14287,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
1427514287
GGML_ASSERT(false);
1427614288
}
1427714289

14278-
const std::string piece = llama_token_to_piece(ctx, token, false);
14290+
const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token);
1427914291

1428014292
// Note terminating 0 in decoded string
1428114293
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
@@ -17957,69 +17969,79 @@ static std::string llama_decode_text(const std::string & text) {
1795717969

1795817970
// does not write null-terminator to buf
1795917971
int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
17972+
if (model->vocab.has_cache) {
17973+
const auto & cache = special ? model->vocab.cache_token_to_piece_special : model->vocab.cache_token_to_piece;
17974+
const auto & res = cache.at(token);
17975+
if (length < (int) res.size()) {
17976+
return -(int) res.size();
17977+
}
17978+
memcpy(buf, res.c_str(), res.size());
17979+
return res.size();
17980+
}
17981+
1796017982
if (0 <= token && token < llama_n_vocab(model)) {
1796117983
switch (llama_vocab_get_type(model->vocab)) {
17962-
case LLAMA_VOCAB_TYPE_WPM:
17963-
case LLAMA_VOCAB_TYPE_SPM: {
17964-
// NOTE: we accept all unsupported token types,
17965-
// suppressing them like CONTROL tokens.
17966-
if (llama_is_normal_token(model->vocab, token)) {
17967-
std::string result = model->vocab.id_to_token[token].text;
17968-
llama_unescape_whitespace(result);
17969-
if (length < (int) result.length()) {
17970-
return -(int) result.length();
17971-
}
17972-
memcpy(buf, result.c_str(), result.length());
17973-
return result.length();
17974-
} else if (
17975-
(llama_is_user_defined_token(model->vocab, token)) ||
17976-
(llama_is_control_token (model->vocab, token) && special)) {
17977-
std::string result = model->vocab.id_to_token[token].text;
17978-
if (length < (int) result.length()) {
17979-
return -(int) result.length();
17980-
}
17981-
memcpy(buf, result.c_str(), result.length());
17982-
return result.length();
17983-
} else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
17984-
if (length < 3) {
17985-
return -3;
17986-
}
17987-
memcpy(buf, "\xe2\x96\x85", 3);
17988-
return 3;
17989-
} else if (llama_is_byte_token(model->vocab, token)) {
17990-
if (length < 1) {
17991-
return -1;
17984+
case LLAMA_VOCAB_TYPE_WPM:
17985+
case LLAMA_VOCAB_TYPE_SPM: {
17986+
// NOTE: we accept all unsupported token types,
17987+
// suppressing them like CONTROL tokens.
17988+
if (llama_is_normal_token(model->vocab, token)) {
17989+
std::string result = model->vocab.id_to_token[token].text;
17990+
llama_unescape_whitespace(result);
17991+
if (length < (int) result.length()) {
17992+
return -(int) result.length();
17993+
}
17994+
memcpy(buf, result.c_str(), result.length());
17995+
return result.length();
17996+
} else if (
17997+
(llama_is_user_defined_token(model->vocab, token)) ||
17998+
(llama_is_control_token (model->vocab, token) && special)) {
17999+
std::string result = model->vocab.id_to_token[token].text;
18000+
if (length < (int) result.length()) {
18001+
return -(int) result.length();
18002+
}
18003+
memcpy(buf, result.c_str(), result.length());
18004+
return result.length();
18005+
} else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
18006+
if (length < 3) {
18007+
return -3;
18008+
}
18009+
memcpy(buf, "\xe2\x96\x85", 3);
18010+
return 3;
18011+
} else if (llama_is_byte_token(model->vocab, token)) {
18012+
if (length < 1) {
18013+
return -1;
18014+
}
18015+
buf[0] = llama_token_to_byte(model->vocab, token);
18016+
return 1;
1799218017
}
17993-
buf[0] = llama_token_to_byte(model->vocab, token);
17994-
return 1;
18018+
break;
1799518019
}
17996-
break;
17997-
}
17998-
case LLAMA_VOCAB_TYPE_BPE: {
17999-
// NOTE: we accept all unsupported token types,
18000-
// suppressing them like CONTROL tokens.
18001-
if (llama_is_normal_token(model->vocab, token)) {
18002-
std::string result = model->vocab.id_to_token[token].text;
18003-
result = llama_decode_text(result);
18004-
if (length < (int) result.length()) {
18005-
return -(int) result.length();
18006-
}
18007-
memcpy(buf, result.c_str(), result.length());
18008-
return result.length();
18009-
} else if (
18010-
(llama_is_user_defined_token(model->vocab, token)) ||
18011-
(llama_is_control_token (model->vocab, token) && special)) {
18012-
std::string result = model->vocab.id_to_token[token].text;
18013-
if (length < (int) result.length()) {
18014-
return -(int) result.length();
18020+
case LLAMA_VOCAB_TYPE_BPE: {
18021+
// NOTE: we accept all unsupported token types,
18022+
// suppressing them like CONTROL tokens.
18023+
if (llama_is_normal_token(model->vocab, token)) {
18024+
std::string result = model->vocab.id_to_token[token].text;
18025+
result = llama_decode_text(result);
18026+
if (length < (int) result.length()) {
18027+
return -(int) result.length();
18028+
}
18029+
memcpy(buf, result.c_str(), result.length());
18030+
return result.length();
18031+
} else if (
18032+
(llama_is_user_defined_token(model->vocab, token)) ||
18033+
(llama_is_control_token (model->vocab, token) && special)) {
18034+
std::string result = model->vocab.id_to_token[token].text;
18035+
if (length < (int) result.length()) {
18036+
return -(int) result.length();
18037+
}
18038+
memcpy(buf, result.c_str(), result.length());
18039+
return result.length();
1801518040
}
18016-
memcpy(buf, result.c_str(), result.length());
18017-
return result.length();
18041+
break;
1801818042
}
18019-
break;
18020-
}
18021-
default:
18022-
GGML_ASSERT(false);
18043+
default:
18044+
GGML_ASSERT(false);
1802318045
}
1802418046
}
1802518047
return 0;

llama.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,8 @@ extern "C" {
424424

425425
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
426426

427-
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
428-
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
427+
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
428+
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
429429

430430
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
431431
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);

0 commit comments

Comments
 (0)