Skip to content

Commit 13c0661

Browse files
authored
[Tokenizer] Auto-detect TokenizerInfo from tokenizer.json (#2416)
This PR adds a new `TokenizerInfo` class that holds information about the tokenizer that is needed during generation. The info is auto-detected from tokenizer.json when that file exists; otherwise a warning is raised and the default values are used (byte-fallback tokenizer, with no space prepended when encoding and no space stripped when decoding).
1 parent c62e143 commit 13c0661

File tree

12 files changed

+376
-156
lines changed

12 files changed

+376
-156
lines changed

cpp/serve/data.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.ImageDataGetImage").set_body_typed([](ImageData d
109109
/*! \brief Convert a single token with probability to JSON string. */
110110
inline void TokenToLogProbJSON(const Tokenizer& tokenizer, const TokenProbPair& token_prob,
111111
std::ostringstream* os) {
112-
const std::string& token = tokenizer->TokenTable()[token_prob.first];
112+
const std::string& token = tokenizer->PostProcessedTokenTable()[token_prob.first];
113113

114114
(*os) << "\"token\": \"";
115115
for (char ch : token) {

cpp/serve/engine.cc

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -126,17 +126,8 @@ class EngineImpl : public Engine {
126126
ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()});
127127
}
128128
// - Initialize tokenizer and grammar
129-
n->tokenizer_ = Tokenizer::FromPath(engine_config->model);
130-
std::string token_table_postproc_method;
131-
if (model_configs[0].count("token_table_postproc_method") == 0) {
132-
// Backward compatibility: use "byte_fallback" by default
133-
token_table_postproc_method = "byte_fallback";
134-
} else {
135-
token_table_postproc_method =
136-
model_configs[0].at("token_table_postproc_method").get<std::string>();
137-
}
138-
n->token_table_ =
139-
Tokenizer::PostProcessTokenTable(n->tokenizer_->TokenTable(), token_table_postproc_method);
129+
n->tokenizer_ = Tokenizer::FromPath(engine_config->model, GetTokenizerInfo(model_configs[0]));
130+
n->token_table_ = n->tokenizer_->PostProcessedTokenTable();
140131
n->grammar_init_context_cache_ = GrammarInitContextCache(n->token_table_);
141132
// - Create the logit processor and sampler, and
142133
// the DraftTokenWorkspaceManager for speculative decoding.
@@ -549,6 +540,28 @@ class EngineImpl : public Engine {
549540
}
550541
}
551542

543+
static std::optional<TokenizerInfo> GetTokenizerInfo(const picojson::object& model_config) {
544+
if (model_config.count("tokenizer_info") == 0) {
545+
LOG(WARNING) << "Tokenizer info not found in mlc-chat-config.json. "
546+
<< "Trying to automatically detect the tokenizer info";
547+
return std::nullopt;
548+
}
549+
const picojson::object& tokenizer_info_obj =
550+
model_config.at("tokenizer_info").get<picojson::object>();
551+
auto info = make_object<TokenizerInfoNode>();
552+
if (tokenizer_info_obj.count("token_postproc_method")) {
553+
info->token_postproc_method =
554+
tokenizer_info_obj.at("token_postproc_method").get<std::string>();
555+
}
556+
if (tokenizer_info_obj.count("prepend_space_in_encode")) {
557+
info->prepend_space_in_encode = tokenizer_info_obj.at("prepend_space_in_encode").get<bool>();
558+
}
559+
if (tokenizer_info_obj.count("strip_space_in_decode")) {
560+
info->strip_space_in_decode = tokenizer_info_obj.at("strip_space_in_decode").get<bool>();
561+
}
562+
return TokenizerInfo(info);
563+
}
564+
552565
// Engine state, managing requests and request states.
553566
EngineState estate_;
554567
// Configurations and singletons

cpp/serve/grammar/grammar_state_matcher.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -483,14 +483,12 @@ GrammarStateMatcher::GrammarStateMatcher(std::shared_ptr<GrammarStateInitContext
483483
#ifndef COMPILE_MLC_WASM_RUNTIME
484484
// This creates tokenizer dependency issue in WASM building for web, hence skipped
485485
TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenizer")
486-
.set_body_typed([](BNFGrammar grammar, Optional<Tokenizer> tokenizer, int max_rollback_steps,
487-
String token_table_postproc_method) {
486+
.set_body_typed([](BNFGrammar grammar, Optional<Tokenizer> tokenizer, int max_rollback_steps) {
488487
auto preproc_start = std::chrono::high_resolution_clock::now();
489488
std::shared_ptr<mlc::llm::serve::GrammarStateInitContext> init_ctx;
490489
if (tokenizer) {
491-
auto token_table = Tokenizer::PostProcessTokenTable(tokenizer.value()->TokenTable(),
492-
token_table_postproc_method);
493-
init_ctx = GrammarStateMatcher::CreateInitContext(grammar, token_table);
490+
init_ctx = GrammarStateMatcher::CreateInitContext(
491+
grammar, tokenizer.value()->PostProcessedTokenTable());
494492
} else {
495493
init_ctx = GrammarStateMatcher::CreateInitContext(grammar, {});
496494
}

cpp/serve/grammar/grammar_state_matcher.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ using namespace tvm::runtime;
4040
* \example
4141
* \code
4242
* Tokenizer tokenizer = ...;
43-
* auto init_ctx = GrammarStateMatcher::CreateInitContext(grammar, tokenizer->TokenTable());
43+
* auto init_ctx = GrammarStateMatcher::CreateInitContext(grammar,
44+
* tokenizer->PostProcessedTokenTable());
4445
* GrammarStateMatcher matcher(init_ctx, 10);
4546
* matcher->AcceptToken(67);
4647
*

cpp/serve/grammar/grammar_state_matcher_base.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
#include <vector>
1010

11-
#include "../../tokenizers.h"
1211
#include "grammar.h"
1312
#include "grammar_state_matcher_state.h"
1413

cpp/streamer.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ StopStrHandler::StopStrHandler(Array<String> stop_strs,
263263

264264
TVM_REGISTER_GLOBAL("mlc.StopStrHandler")
265265
.set_body_typed([](Array<String> stop_strs, const Tokenizer& tokenizer) {
266-
return StopStrHandler(std::move(stop_strs), tokenizer->TokenTable());
266+
return StopStrHandler(std::move(stop_strs), tokenizer->PostProcessedTokenTable());
267267
});
268268

269269
TVM_REGISTER_GLOBAL("mlc.StopStrHandlerPut")

0 commit comments

Comments
 (0)