[Serve][Grammar] Jump-forward decoding
This PR supports jump-forward decoding as described in
<https://lmsys.org/blog/2024-02-05-compressed-fsm/>. Jump-forward decoding
uses the grammar constraint to predict the next output string and tokenizes
that string into tokens, which speeds up decoding.
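To make the idea concrete, here is a toy sketch (not this PR's implementation) in which the grammar is reduced to a fixed JSON template; `FREE`, `template`, and `jump_forward` are illustrative names only. The engine appends each forced span in a single step instead of sampling it token by token:

```python
# Toy illustration of jump-forward decoding: the "grammar state" is reduced to
# a JSON template in which FREE marks positions the model must fill and every
# other fragment is forced by the grammar.
FREE = object()

# Template implied by a schema such as {"product_id": int, "is_available": bool}.
template = ['{"product_id": ', FREE, ', "is_available": ', FREE, "}"]


def jump_forward(template, pos):
    """Return the longest string forced by the grammar starting at position `pos`."""
    forced = []
    while pos < len(template) and template[pos] is not FREE:
        forced.append(template[pos])
        pos += 1
    return "".join(forced), pos


forced, pos = jump_forward(template, 0)
print(repr(forced))  # '{"product_id": ' is appended at once; the model then fills the value
```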
This PR implements the following optimizations to preserve output quality:
- Retokenization in jump-forward: decode the last k tokens to a string, append the
predicted string, and retokenize. If the result differs from the old tokens, roll
back those tokens and accept the new ones (see the sketch after this list).
- Retokenization in decoding: decode the last k tokens to a string, append the
newly decoded token, and retokenize. This happens in the decode stage when jump-forward
decoding ran in the previous round. If the result differs, the old tokens are rolled back.
- Skip prefix tokens in jump-forward: we call a token that is a prefix of another token
a prefix token. If the last token produced by jump-forward is a prefix token, it is very
likely to be rolled back in the next decode stage, because it may be merged with the
decoded token. It also affects the output distribution, since such patterns are rare in training data.
Therefore, we skip the last prefix token in jump-forward decoding.
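A minimal sketch of the retokenization check described above, assuming a Hugging Face-style tokenizer (`gpt2` is only a stand-in, not the benchmarked model); `retokenize_tail`, `is_prefix_token`, and `k` are illustrative names, not identifiers from this PR:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer for illustration


def retokenize_tail(old_ids, jump_str, k=10):
    """Re-tokenize the last k tokens together with the jump-forward string.

    Returns (rolled_back, accepted): the suffix of the old tail that must be
    rolled back and the new token IDs that replace it.
    """
    tail = old_ids[-k:]
    text = tokenizer.decode(tail) + jump_str
    new_ids = tokenizer.encode(text, add_special_tokens=False)
    # Keep the longest common prefix; everything after it is rolled back.
    i = 0
    while i < len(tail) and i < len(new_ids) and tail[i] == new_ids[i]:
        i += 1
    return tail[i:], new_ids[i:]


def is_prefix_token(token_id):
    """True if the token's string is a strict prefix of some other vocabulary token."""
    s = tokenizer.convert_ids_to_tokens(token_id)
    return any(v != s and v.startswith(s) for v in tokenizer.get_vocab())
```

If the last accepted token from a jump-forward step satisfies `is_prefix_token`, it would be withheld so that the next decode step can produce the merged token instead.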
This PR also includes the following changes:
- Add several request- and engine-level metrics, especially for jump-forward decoding
- Fix a bug in `_async_query_engine_metrics` to avoid throwing a CancelledError from an early return
Performance and benchmark (jump-forward improves decode_tokens_per_s by roughly 1.4x-1.6x across batch sizes 1-16):
Schema (Pydantic):
```python
from typing import Dict, List, Literal

from pydantic import BaseModel


class Product(BaseModel):
    product_id: int
    is_available: bool
    price: float
    is_featured: Literal[True]
    category: Literal["Electronics", "Clothing", "Food"]
    tags: List[str]
    stock: Dict[str, int]
```
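For context, the JSON schema used as the grammar constraint can be exported from the Pydantic model; how it is wired into the engine is engine-specific and not shown here. This is only a reference export, not code from this PR:

```python
import json

# Pydantic v2: export the JSON schema corresponding to the Product model above.
# Fixed keys and Literal values become long deterministic spans that
# jump-forward decoding can emit in a single step.
print(json.dumps(Product.model_json_schema(), indent=2))
```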
Platform: AMD Ryzen 9 5900X, NVIDIA GeForce RTX 3080 10 GB
Results:
```
Jump forward: False, Batch: 1
Engine metrics:
{
"engine_decode_time_sum": 0.4988938220000001,
"engine_jump_forward_time_sum": 0,
"completion_tokens_sum": 66,
"decode_tokens_sum": 66,
"jump_forward_tokens_sum": 0,
"decode_tokens_per_s": 132.2926785010378,
}
Jump forward: True, Batch: 1
Engine metrics:
{
"engine_decode_time_sum": 0.37242740600000007,
"engine_jump_forward_time_sum": 0.027989265000000006,
"completion_tokens_sum": 68,
"decode_tokens_sum": 68,
"jump_forward_tokens_sum": 28,
"decode_tokens_per_s": 182.58591850246378,
}
Jump forward: False, Batch: 4
Engine metrics:
{
"engine_decode_time_sum": 0.9106805410000002,
"engine_jump_forward_time_sum": 0,
"completion_tokens_sum": 261,
"decode_tokens_sum": 261,
"jump_forward_tokens_sum": 0,
"decode_tokens_per_s": 286.5988546470984,
}
Jump forward: True, Batch: 4
Engine metrics:
{
"engine_decode_time_sum": 0.6843025599999999,
"engine_jump_forward_time_sum": 0.028089531999999997,
"completion_tokens_sum": 266,
"decode_tokens_sum": 266,
"jump_forward_tokens_sum": 112,
"decode_tokens_per_s": 388.71694415405966,
}
Jump forward: False, Batch: 8
Engine metrics:
{
"engine_decode_time_sum": 1.62462493,
"engine_jump_forward_time_sum": 0,
"completion_tokens_sum": 538,
"decode_tokens_sum": 538,
"jump_forward_tokens_sum": 0,
"decode_tokens_per_s": 331.1533573475325,
}
Jump forward: True, Batch: 8
Engine metrics:
{
"engine_decode_time_sum": 1.0509048310000002,
"engine_jump_forward_time_sum": 0.027971332000000022,
"completion_tokens_sum": 525,
"decode_tokens_sum": 525,
"jump_forward_tokens_sum": 224,
"decode_tokens_per_s": 499.5694990767436,
}
Jump forward: False, Batch: 16
Engine metrics:
{
"engine_decode_time_sum": 2.317279175,
"engine_jump_forward_time_sum": 0,
"completion_tokens_sum": 1068,
"decode_tokens_sum": 1068,
"jump_forward_tokens_sum": 0,
"decode_tokens_per_s": 460.8853398080531,
}
Jump forward: True, Batch: 16
Engine metrics:
{
"engine_decode_time_sum": 1.3962938819999997,
"engine_jump_forward_time_sum": 0.030129287999999994,
"completion_tokens_sum": 1059,
"decode_tokens_sum": 1059,
"jump_forward_tokens_sum": 448,
"decode_tokens_per_s": 758.4363246533227,
}
```