Commit 42f146d

[Serving][Grammar] Jump-forward decoding (#2551)
This PR supports jump-forward decoding as described in <https://lmsys.org/blog/2024-02-05-compressed-fsm/>. Jump-forward decoding uses the grammar constraint to predict the next output string and tokenizes that string into tokens, which speeds up decoding.

This PR implements the following optimizations to preserve output quality:

- Retokenization in jump-forward: tokenize the last k tokens, as a string, with the predicted string appended. If the tokenization result differs from the old tokens, roll back those tokens and accept the new ones.
- Retokenization in decoding: tokenize the last k tokens, as a string, with the decoded token appended. This happens in the decode stage when jump-forward decoding occurred in the previous round. If the result differs, the old tokens are rolled back.
- Skip prefix tokens in jump-forward: we call a token that is a prefix of another token a prefix token. If the last token produced by jump-forward is a prefix token, it is likely to be rolled back in the next decode stage, because it may be merged with the decoded token. It also affects the output distribution, since such a pattern is rare in training data. Therefore, we skip the last prefix token in jump-forward decoding.

This PR also includes the following changes:

- Add several metrics for requests and the engine, especially about jump-forward decoding.
- Fix a bug in `_async_query_engine_metrics` to avoid throwing CancelledError from an early return.

Performance and benchmark:

Schema (Pydantic):

```
class Product(BaseModel):
    product_id: int
    is_available: bool
    price: float
    is_featured: Literal[True]
    category: Literal["Electronics", "Clothing", "Food"]
    tags: List[str]
    stock: Dict[str, int]
```

Platform: AMD Ryzen 9 5900X, NVIDIA 3080 10G

Results:

```
Jump forward: False, Batch: 1
Engine metrics: {
  "engine_decode_time_sum": 0.4988938220000001,
  "engine_jump_forward_time_sum": 0,
  "completion_tokens_sum": 66,
  "decode_tokens_sum": 66,
  "jump_forward_tokens_sum": 0,
  "decode_tokens_per_s": 132.2926785010378,
}

Jump forward: True, Batch: 1
Engine metrics: {
  "engine_decode_time_sum": 0.37242740600000007,
  "engine_jump_forward_time_sum": 0.027989265000000006,
  "completion_tokens_sum": 68,
  "decode_tokens_sum": 68,
  "jump_forward_tokens_sum": 28,
  "decode_tokens_per_s": 182.58591850246378,
}

Jump forward: False, Batch: 4
Engine metrics: {
  "engine_decode_time_sum": 0.9106805410000002,
  "engine_jump_forward_time_sum": 0,
  "completion_tokens_sum": 261,
  "decode_tokens_sum": 261,
  "jump_forward_tokens_sum": 0,
  "decode_tokens_per_s": 286.5988546470984,
}

Jump forward: True, Batch: 4
Engine metrics: {
  "engine_decode_time_sum": 0.6843025599999999,
  "engine_jump_forward_time_sum": 0.028089531999999997,
  "completion_tokens_sum": 266,
  "decode_tokens_sum": 266,
  "jump_forward_tokens_sum": 112,
  "decode_tokens_per_s": 388.71694415405966,
}

Jump forward: False, Batch: 8
Engine metrics: {
  "engine_decode_time_sum": 1.62462493,
  "engine_jump_forward_time_sum": 0,
  "completion_tokens_sum": 538,
  "decode_tokens_sum": 538,
  "jump_forward_tokens_sum": 0,
  "decode_tokens_per_s": 331.1533573475325,
}

Jump forward: True, Batch: 8
Engine metrics: {
  "engine_decode_time_sum": 1.0509048310000002,
  "engine_jump_forward_time_sum": 0.027971332000000022,
  "completion_tokens_sum": 525,
  "decode_tokens_sum": 525,
  "jump_forward_tokens_sum": 224,
  "decode_tokens_per_s": 499.5694990767436,
}

Jump forward: False, Batch: 16
Engine metrics: {
  "engine_decode_time_sum": 2.317279175,
  "engine_jump_forward_time_sum": 0,
  "completion_tokens_sum": 1068,
  "decode_tokens_sum": 1068,
  "jump_forward_tokens_sum": 0,
  "decode_tokens_per_s": 460.8853398080531,
}

Jump forward: True, Batch: 16
Engine metrics: {
  "engine_decode_time_sum": 1.3962938819999997,
  "engine_jump_forward_time_sum": 0.030129287999999994,
  "completion_tokens_sum": 1059,
  "decode_tokens_sum": 1059,
  "jump_forward_tokens_sum": 448,
  "decode_tokens_per_s": 758.4363246533227,
}
```
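To make the retokenization rule above concrete, here is a minimal Python sketch of the check it describes. The tokenizer interface (HuggingFace-style `encode`/`decode`), the value of k, and all names are illustrative assumptions, not the engine's actual code.

```python
# Sketch of the "retokenization in jump-forward" rule described above.
# `tokenizer` is assumed to be a HuggingFace-style tokenizer; names and k are hypothetical.
def retokenize_for_jump_forward(tokenizer, token_ids, jump_forward_str, k=5):
    """Return (num_tokens_to_roll_back, replacement_token_ids) after appending
    the jump-forward string predicted by the grammar matcher."""
    last_k = list(token_ids[-k:])
    prefix_str = tokenizer.decode(last_k)
    # Re-tokenize the last k tokens together with the predicted string.
    new_ids = tokenizer.encode(prefix_str + jump_forward_str, add_special_tokens=False)
    if new_ids[:len(last_k)] == last_k:
        # The old tokens survive retokenization: keep them, append only the tail.
        return 0, new_ids[len(last_k):]
    # Tokenization changed: roll back the last k tokens and accept the new ones.
    return len(last_k), new_ids
```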
1 parent 4234262 commit 42f146d

32 files changed (+1310 / -312 lines)

cpp/grammar/grammar_state_matcher.cc

Lines changed: 109 additions & 5 deletions
@@ -8,6 +8,7 @@
 #include <chrono>
 #include <queue>

+#include "../support/dynamic_bitset.h"
 #include "../tokenizers/tokenizers.h"
 #include "grammar.h"
 #include "grammar_serializer.h"
@@ -134,10 +135,12 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm
         max_rollback_steps_(max_rollback_steps),
         tmp_accepted_bitset_(init_ctx_->vocab_size) {}

-  bool AcceptToken(int32_t token_id) final;
+  bool AcceptToken(int32_t token_id, bool verbose = false) final;

   void FindNextTokenBitmask(DLTensor* next_token_bitmask) final;

+  std::string FindJumpForwardString() final;
+
   void Rollback(int num_tokens) final;

   int MaxRollbackSteps() const final { return max_rollback_steps_; }
@@ -193,7 +196,7 @@ bool GrammarStateMatcherNodeImpl::AcceptStopToken() {
   return true;
 }

-bool GrammarStateMatcherNodeImpl::AcceptToken(int32_t token_id) {
+bool GrammarStateMatcherNodeImpl::AcceptToken(int32_t token_id, bool verbose) {
   CHECK(!IsTerminated())
       << "GrammarStateMatcher has terminated after accepting the stop token, but is trying to "
          "accept another token id "
@@ -202,10 +205,20 @@ bool GrammarStateMatcherNodeImpl::AcceptToken(int32_t token_id) {
   CHECK(token_id >= 0 && token_id < init_ctx_->vocab_size)
       << "Invalid token id " << token_id << " for GrammarStateMatcher";

+  if (verbose) {
+    LOG(INFO) << "Accepting token id " << token_id << ", string: \""
+              << PrintAsEscaped(init_ctx_->token_table[token_id]) << "\", state state:\n"
+              << PrintStackState();
+  }
+
   // Handle the stop token
   if (std::find(init_ctx_->stop_token_ids.begin(), init_ctx_->stop_token_ids.end(), token_id) !=
       init_ctx_->stop_token_ids.end()) {
-    return AcceptStopToken();
+    bool accepted = AcceptStopToken();
+    if (verbose) {
+      LOG(INFO) << "The token is an end token. Is accepted: " << accepted;
+    }
+    return accepted;
   }

   if (init_ctx_->special_token_ids.count(token_id) > 0) {
@@ -215,16 +228,25 @@ bool GrammarStateMatcherNodeImpl::AcceptToken(int32_t token_id) {
   }

   const auto& token = init_ctx_->token_table[token_id];
+  int pos = 0;
   for (auto char_value : token) {
     if (!AcceptChar(char_value, false)) {
+      if (verbose) {
+        LOG(INFO) << "The token is rejected at position " << pos << ", character "
+                  << PrintAsEscaped(char_value);
+      }
       return false;
     }
+    ++pos;
   }
   token_length_history.push_back(token.size());
   if (token_length_history.size() > max_rollback_steps_) {
     DiscardEarliestChars(token_length_history.front());
     token_length_history.pop_front();
   }
+  if (verbose) {
+    LOG(INFO) << "The token is accepted. State after accepting:\n" << PrintStackState();
+  }
   return true;
 }

@@ -342,6 +364,85 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm
   SetTokenBitmask(next_token_bitmask, tmp_accepted_bitset_, tmp_rejected_indices_, can_reach_end);
 }

+std::string GrammarStateMatcherNodeImpl::FindJumpForwardString() {
+  CHECK(!IsTerminated())
+      << "GrammarStateMatcher has terminated after accepting the stop token, but is trying to "
+         "get the jump forward string";
+
+  std::string result;
+  int num_accepted_chars = 0;
+  bool can_find_next_char = true;
+
+  while (can_find_next_char) {
+    const auto& stack_tops = stack_tops_history_.GetLatest();
+
+    // 1. Check that for every stack top, the next possible char is unique and the same
+    // -1 means not found yet; 0~255 means the next char
+    int next_char = -1;
+    for (auto stack_top : stack_tops) {
+      auto rule_position = tree_[stack_top];
+      auto cur_sequence = grammar_->GetRuleExpr(rule_position.sequence_id);
+      if (rule_position.parent_id == RulePosition::kNoParent &&
+          rule_position.element_id == cur_sequence.size()) {
+        can_find_next_char = false;
+        break;
+      }
+
+      auto cur_element = grammar_->GetRuleExpr(cur_sequence[rule_position.element_id]);
+
+      if (cur_element.type == RuleExprType::kByteString) {
+        DCHECK(rule_position.element_in_string < cur_element.size());
+        if (next_char == -1) {
+          next_char = cur_element[rule_position.element_in_string];
+        } else if (next_char != cur_element[rule_position.element_in_string]) {
+          can_find_next_char = false;
+          break;
+        }
+      } else {
+        DCHECK(cur_element.type == RuleExprType::kCharacterClass ||
+               cur_element.type == RuleExprType::kCharacterClassStar);
+        if (rule_position.left_utf8_bytes > 0 || cur_element.size() != 3 || cur_element[0] != 0 ||
+            cur_element[1] != cur_element[2]) {
+          can_find_next_char = false;
+          break;
+        } else if (next_char == -1) {
+          next_char = cur_element[1];
+        } else if (next_char != cur_element[1]) {
+          can_find_next_char = false;
+          break;
+        }
+      }
+    }
+
+    if (next_char == -1) {
+      can_find_next_char = false;
+    }
+
+    // 2. If found, accept the char and iterate to the next position
+    if (can_find_next_char) {
+      result += static_cast<uint8_t>(next_char);
+
+      tmp_new_stack_tops_.clear();
+      for (auto stack_top : stack_tops) {
+        auto cur_rule_position = tree_[stack_top];
+        auto new_rule_position = UpdatePositionWithChar(cur_rule_position, next_char);
+
+        if (new_rule_position == cur_rule_position) {
+          ExpandRulePosition(new_rule_position, &tmp_new_stack_tops_, true, stack_top);
+        } else {
+          ExpandRulePosition(new_rule_position, &tmp_new_stack_tops_, true);
+        }
+      }
+      stack_tops_history_.PushHistory(tmp_new_stack_tops_);
+      ++num_accepted_chars;
+    }
+  }
+
+  // Rollback all chars accepted
+  RollbackChars(num_accepted_chars);
+  return result;
+}
+
 void GrammarStateMatcherNodeImpl::Rollback(int num_tokens) {
   CHECK(num_tokens <= token_length_history.size())
       << "Intended to rollback " << num_tokens << " tokens, but only the last "
@@ -477,10 +578,13 @@ TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherDebugAcceptChar")
     });

 TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherAcceptToken")
-    .set_body_typed([](GrammarStateMatcher matcher, int32_t token_id) {
-      return matcher->AcceptToken(token_id);
+    .set_body_typed([](GrammarStateMatcher matcher, int32_t token_id, bool verbose) {
+      return matcher->AcceptToken(token_id, verbose);
     });

+TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherFindJumpForwardString")
+    .set_body_typed([](GrammarStateMatcher matcher) { return matcher->FindJumpForwardString(); });
+
 TVM_REGISTER_GLOBAL("mlc.grammar.GrammarStateMatcherRollback")
     .set_body_typed([](GrammarStateMatcher matcher, int num_tokens) {
      matcher->Rollback(num_tokens);
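The `TVM_REGISTER_GLOBAL` entries above expose the matcher through TVM packed functions. A rough Python sketch of how they could be fetched and called, assuming the MLC/TVM runtime libraries are loaded and a `matcher` object has been constructed through the matcher's own factory path (not shown in this diff); illustrative only, not a complete working script:

```python
import tvm

# Packed functions registered in this file; fetching them requires the MLC runtime
# to be loaded, and `matcher` must come from construction code outside this diff.
accept_token = tvm.get_global_func("mlc.grammar.GrammarStateMatcherAcceptToken")
find_jump_forward = tvm.get_global_func("mlc.grammar.GrammarStateMatcherFindJumpForwardString")
rollback = tvm.get_global_func("mlc.grammar.GrammarStateMatcherRollback")

def debug_step(matcher, token_id):
    # Accept one token with verbose logging, then ask how far the output is already
    # determined by the grammar. Undo with rollback(matcher, n) if needed.
    assert accept_token(matcher, token_id, True)
    return find_jump_forward(matcher)
```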

cpp/grammar/grammar_state_matcher.h

Lines changed: 8 additions & 1 deletion
@@ -65,7 +65,7 @@ class GrammarStateMatcherNode : public Object {
   * FindNextTokenMask operations can be performed. The termination state can be canceled
   * using Rollback().
   */
-  virtual bool AcceptToken(int32_t token_id) = 0;
+  virtual bool AcceptToken(int32_t token_id, bool verbose = false) = 0;

  /*!
   * \brief Find the set of tokens that are acceptable for the next step and store them in a
@@ -75,6 +75,13 @@
   */
  virtual void FindNextTokenBitmask(DLTensor* next_token_bitmask) = 0;

+  /*!
+   * \brief Find the jump-forward string for jump-forward decoding. This is the longest string that
+   will be valid according to the current syntax.
+   * \note This method does not change the grammar state.
+   */
+  virtual std::string FindJumpForwardString() = 0;
+
  /*!
   * \brief Rollback the matcher to a previous state.
   * \param num_tokens The number of tokens to rollback. It cannot exceed the current number of
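As the new doc comment states, the jump-forward string is the longest continuation that the grammar fully determines from the current state. A toy Python sketch of the same idea, phrased over an explicit set of allowed continuations rather than the matcher's stack tops (purely illustrative, not the engine's API):

```python
def jump_forward_string(allowed_continuations):
    """Longest string that every allowed continuation starts with, i.e. the part of
    the output that the grammar fully determines. Mirrors the unique-next-char loop
    of FindJumpForwardString, but over explicit strings instead of grammar stacks."""
    result = []
    pos = 0
    while True:
        next_chars = {s[pos] for s in allowed_continuations if pos < len(s)}
        if len(next_chars) != 1 or any(pos >= len(s) for s in allowed_continuations):
            return "".join(result)
        result.append(next_chars.pop())
        pos += 1

# For a JSON object whose next key must be "product_id":
print(jump_forward_string(['"product_id": 1}', '"product_id": 23}']))  # -> '"product_id": '
```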

cpp/grammar/grammar_state_matcher_preproc.h

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@

 #include <vector>

+#include "../support/dynamic_bitset.h"
 #include "../support/encoding.h"
 #include "../support/utils.h"
 #include "grammar.h"

cpp/grammar/support.h

Lines changed: 0 additions & 90 deletions
@@ -17,96 +17,6 @@ namespace mlc {
 namespace llm {
 namespace serve {

-/*! \brief A bitset with runtime specified length. It manages memory internally or the memory
- * provided externally with enough size. */
-class DynamicBitset {
- public:
-  static int CalculateBufferSize(int element_size) { return (element_size + 31) / 32; }
-
-  DynamicBitset() : size_(0), buffer_size_(0), data_(nullptr), is_internal_(true) {}
-
-  DynamicBitset(int size, uint32_t* data = nullptr)
-      : size_(size), buffer_size_(CalculateBufferSize(size)) {
-    if (data == nullptr) {
-      internal_buffer_.resize(buffer_size_, 0);
-      data_ = internal_buffer_.data();
-      is_internal_ = true;
-    } else {
-      data_ = data;
-      is_internal_ = false;
-    }
-  }
-
-  DynamicBitset& operator=(const DynamicBitset& other) {
-    DCHECK(is_internal_ || size_ >= other.size_) << "Expanding bitset size is not allowed when the "
-                                                    "memory of the bitset is externally managed";
-    size_ = other.size_;
-    buffer_size_ = other.buffer_size_;
-    if (is_internal_) {
-      internal_buffer_.reserve(buffer_size_);
-      data_ = internal_buffer_.data();
-    }
-    if (data_ != other.data_) {
-      std::memcpy(data_, other.data_, buffer_size_ * sizeof(uint32_t));
-    }
-    return *this;
-  }
-
-  DynamicBitset& operator=(DynamicBitset&& other) {
-    size_ = other.size_;
-    buffer_size_ = other.buffer_size_;
-    is_internal_ = other.is_internal_;
-    if (is_internal_) {
-      internal_buffer_ = std::move(other.internal_buffer_);
-      data_ = internal_buffer_.data();
-    } else {
-      data_ = other.data_;
-    }
-    return *this;
-  }
-
-  bool operator[](int index) const {
-    DCHECK(data_ && index >= 0 && index < size_);
-    return (data_[index / 32] >> (index % 32)) & 1;
-  }
-
-  int Size() const { return size_; }
-
-  void Set(int index, bool value) {
-    DCHECK(data_ && index >= 0 && index < size_);
-    if (value) {
-      data_[index / 32] |= 1 << (index % 32);
-    } else {
-      data_[index / 32] &= ~(1 << (index % 32));
-    }
-  }
-
-  void Set() {
-    DCHECK(data_);
-    std::memset(data_, 0xFF, buffer_size_ * sizeof(uint32_t));
-  }
-
-  void Reset() {
-    DCHECK(data_);
-    std::memset(data_, 0, buffer_size_ * sizeof(uint32_t));
-  }
-
-  DynamicBitset& operator|=(const DynamicBitset& other) {
-    DCHECK(buffer_size_ <= other.buffer_size_);
-    for (int i = 0; i < buffer_size_; ++i) {
-      data_[i] |= other.data_[i];
-    }
-    return *this;
-  }
-
- private:
-  int size_;
-  int buffer_size_;
-  uint32_t* data_;
-  std::vector<uint32_t> internal_buffer_;
-  bool is_internal_;
-};
-
 /*!
  * \brief Let lhs be the union of lhs and rhs. Suppose that both sets are sorted.
  * \note No additional vectors are allocated, and the time complexity is O(n)
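For reference, the `DynamicBitset` removed here (and now included from `../support/dynamic_bitset.h`) packs bits into 32-bit words. A small Python mirror of its index arithmetic, purely as a worked example:

```python
# Python mirror of DynamicBitset's word/bit arithmetic (illustrative only).
def bitset_set(words, index, value):
    if value:
        words[index // 32] |= 1 << (index % 32)
    else:
        words[index // 32] &= ~(1 << (index % 32))

def bitset_get(words, index):
    return (words[index // 32] >> (index % 32)) & 1

words = [0] * ((100 + 31) // 32)  # CalculateBufferSize(100) == 4 words for 100 bits
bitset_set(words, 70, True)       # touches word 70 // 32 == 2, bit 70 % 32 == 6
assert bitset_get(words, 70) == 1 and bitset_get(words, 71) == 0
```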

cpp/serve/config.cc

Lines changed: 19 additions & 0 deletions
@@ -77,6 +77,15 @@ Result<DebugConfig> DebugConfig::FromJSON(const picojson::object& config) {
       return TResult::Error("Uknown special request " + special_request);
     }
   }
+  std::string grammar_execution_mode =
+      json::LookupOrDefault<std::string>(config, "grammar_execution_mode", "jump_forward");
+  if (grammar_execution_mode == "jump_forward") {
+    res.grammar_execution_mode = GrammarExecutionMode::kJumpForward;
+  } else if (grammar_execution_mode == "constraint") {
+    res.grammar_execution_mode = GrammarExecutionMode::kConstraint;
+  } else {
+    return TResult::Error("Uknown grammar execution mode " + grammar_execution_mode);
+  }
   return TResult::Ok(res);
 }

@@ -95,6 +104,16 @@ picojson::object DebugConfig::AsJSON() const {
     case SpecialRequestKind::kNone:
       break;
   }
+  switch (grammar_execution_mode) {
+    case GrammarExecutionMode::kJumpForward: {
+      config["grammar_execution_mode"] = picojson::value("jump_forward");
+      break;
+    }
+    case GrammarExecutionMode::kConstraint: {
+      config["grammar_execution_mode"] = picojson::value("constraint");
+      break;
+    }
+  }
   return config;
 }

cpp/serve/config.h

Lines changed: 11 additions & 0 deletions
@@ -46,12 +46,23 @@ enum class SpecialRequestKind : int {
   kQueryEngineMetrics = 1,
 };

+/*! \brief Controls the behavior of inference with grammar constraint. */
+enum class GrammarExecutionMode : int {
+  /*! \brief If grammar is provided for a request, use the grammar to constrain the output token. */
+  kConstraint = 0,
+  /*! \brief If grammar is provided for a request, not only constrain the output, but also use the
+   * jump-forward decoding to predict the next tokens. This is the default option. */
+  kJumpForward = 1,
+};
+
 /*! \brief The debug configuration of a request. */
 class DebugConfig {
  public:
   bool ignore_eos = false;
   bool pinned_system_prompt = false;
   SpecialRequestKind special_request = SpecialRequestKind::kNone;
+  /*! \brief The grammar execution mode. */
+  GrammarExecutionMode grammar_execution_mode = GrammarExecutionMode::kJumpForward;

  /*!
   * \brief Create debug config from JSON.
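Putting the two config changes together: `DebugConfig::FromJSON` accepts a `grammar_execution_mode` key with values `"jump_forward"` (the default) or `"constraint"`. A hypothetical payload sketch, written as a Python dict to be JSON-serialized; the other keys are assumed to mirror the `DebugConfig` field names above, and how the debug config is attached to a request is outside this diff:

```python
import json

# Hypothetical debug-config payload; "grammar_execution_mode" matches FromJSON above,
# the other keys are assumed to mirror DebugConfig's field names.
debug_config = {
    "ignore_eos": False,
    "pinned_system_prompt": False,
    # "constraint": only mask invalid tokens; "jump_forward" (default): also predict
    # deterministic continuations with jump-forward decoding.
    "grammar_execution_mode": "constraint",
}
print(json.dumps(debug_config))
```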

cpp/serve/data.cc

Lines changed: 8 additions & 4 deletions
@@ -173,12 +173,13 @@ TVM_REGISTER_OBJECT_TYPE(RequestStreamOutputObj);
 RequestStreamOutput::RequestStreamOutput(
     String request_id, Array<IntTuple> group_delta_token_ids,
     Optional<Array<Array<String>>> group_delta_logprob_json_strs,
-    Array<Optional<String>> group_finish_reason) {
+    Array<Optional<String>> group_finish_reason, Array<String> group_extra_prefix_string) {
   ObjectPtr<RequestStreamOutputObj> n = make_object<RequestStreamOutputObj>();
   n->request_id = std::move(request_id);
   n->group_delta_token_ids = std::move(group_delta_token_ids);
   n->group_delta_logprob_json_strs = std::move(group_delta_logprob_json_strs);
   n->group_finish_reason = std::move(group_finish_reason);
+  n->group_extra_prefix_string = std::move(group_extra_prefix_string);
   data_ = std::move(n);
 }

@@ -192,9 +193,12 @@ RequestStreamOutput RequestStreamOutput::Usage(String request_id,

 TVM_REGISTER_GLOBAL("mlc.serve.RequestStreamOutputUnpack")
     .set_body_typed([](RequestStreamOutput output) {
-      return Array<ObjectRef>{output->request_id, output->group_delta_token_ids,
-                              output->group_delta_logprob_json_strs, output->group_finish_reason,
-                              output->request_final_usage_json_str};
+      return Array<ObjectRef>{output->request_id,
+                              output->group_delta_token_ids,
+                              output->group_delta_logprob_json_strs,
+                              output->group_finish_reason,
+                              output->request_final_usage_json_str,
+                              output->group_extra_prefix_string};
     });

 } // namespace serve
