Skip to content

Commit 712f1d5

Browse files
committed
[Grammar] Upgrade xgrammar to latest version
- upgrade xgrammar calls to the latest API
1 parent bd72d21 commit 712f1d5

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

cpp/serve/engine.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -463,9 +463,11 @@ class EngineImpl : public Engine {
463463
ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()});
464464
}
465465
// - Initialize tokenizer and grammar
466+
466467
n->tokenizer_ = Tokenizer::FromPath(engine_config->model, GetTokenizerInfo(model_configs[0]));
467468
n->token_table_ = n->tokenizer_->PostProcessedTokenTable();
468-
n->cached_grammar_compiler_ = xgrammar::CachedGrammarCompiler(n->token_table_);
469+
// TODO: check 'vocab_size' of TokenizerInfo
470+
n->grammar_compiler_ = xgrammar::GrammarCompiler(xgrammar::TokenizerInfo(n->token_table_));
469471
// - Create the logit processor and sampler, and
470472
// the DraftTokenWorkspaceManager for speculative decoding.
471473
int max_num_tokens = engine_config->max_num_sequence;
@@ -975,13 +977,13 @@ class EngineImpl : public Engine {
975977
* is not JSON, return std::nullopt. */
976978
std::optional<xgrammar::CompiledGrammar> GetGrammarFromResponseFormat(
977979
const ResponseFormat& response_format) {
980+
// TODO: add other grammar type
978981
if (response_format.type != "json_object") {
979982
return std::nullopt;
980983
} else if (!response_format.schema) {
981-
return cached_grammar_compiler_.GetCompiledGrammarForJSON();
984+
return grammar_compiler_.CompileBuiltinJSONGrammar();
982985
} else {
983-
return cached_grammar_compiler_.GetCompiledGrammarForJSONSchema(
984-
response_format.schema.value());
986+
return grammar_compiler_.CompileJSONSchema(response_format.schema.value());
985987
}
986988
}
987989

@@ -992,8 +994,8 @@ class EngineImpl : public Engine {
992994
// internal tokenizer
993995
Tokenizer tokenizer_;
994996
std::vector<std::string> token_table_;
995-
// Cached grammar compiler for grammar matching.
996-
xgrammar::CachedGrammarCompiler cached_grammar_compiler_;
997+
// Grammar compiler for grammar matching.
998+
xgrammar::GrammarCompiler grammar_compiler_;
997999
// Models
9981000
Array<Model> models_;
9991001
// Device that the models run on.

cpp/serve/request_state.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ RequestModelState::RequestModelState(
2424
if (compiled_grammar.has_value()) {
2525
// TODO(yixin): set rollback limit to a configurable value.
2626
n->grammar_matcher =
27-
xgrammar::GrammarMatcher(compiled_grammar.value(), std::nullopt, false, std::nullopt, 10);
27+
xgrammar::GrammarMatcher(compiled_grammar.value(), std::nullopt, false, 10);
2828
}
2929

3030
n->request = std::move(request);
@@ -44,7 +44,7 @@ bool RequestModelStateNode::RequireNextTokenBitmask() { return grammar_matcher.h
4444
void RequestModelStateNode::GetNextTokenBitmask(DLTensor* bitmask) {
4545
ICHECK(grammar_matcher.has_value());
4646

47-
grammar_matcher->GetNextTokenBitmask(bitmask);
47+
grammar_matcher->FillNextTokenBitmask(bitmask);
4848
}
4949

5050
void RequestModelStateNode::CommitToken(SampleResult sampled_token) {

0 commit comments

Comments (0)