Skip to content

Commit a08c29f

Browse files
Add accepts, rollback methods on Guide and tests (#212)
Adds two methods on the Guide object: - `accepts_tokens`: given a sequence of tokens, returns a boolean indicating whether that sequence ends in a valid DFA state. - `rollback_state`: rolls back the state of the Guide by `n` tokens (states); raises an error if `n` is greater than the number of states in `self.state_cache`. This is needed for speculative decoding compatibility in vLLM. See: vllm-project/vllm#15975 (comment)
1 parent 1f9df0d commit a08c29f

File tree

2 files changed

+85
-1
lines changed

2 files changed

+85
-1
lines changed

src/python_bindings/mod.rs

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Provides tools and interfaces to integrate the crate's functionality with Python.
22
3+
use std::collections::VecDeque;
34
use std::sync::Arc;
45

56
use bincode::{config, Decode, Encode};
@@ -27,16 +28,19 @@ macro_rules! type_name {
2728
pub struct PyGuide {
2829
state: StateId,
2930
index: PyIndex,
31+
state_cache: VecDeque<StateId>,
3032
}
3133

3234
#[pymethods]
3335
impl PyGuide {
3436
/// Creates a Guide object based on Index.
3537
#[new]
36-
fn __new__(index: PyIndex) -> Self {
38+
#[pyo3(signature = (index, max_rollback=32))]
39+
fn __new__(index: PyIndex, max_rollback: usize) -> Self {
3740
PyGuide {
3841
state: index.get_initial_state(),
3942
index,
43+
state_cache: VecDeque::with_capacity(max_rollback),
4044
}
4145
}
4246

@@ -57,6 +61,11 @@ impl PyGuide {
5761
)))
5862
}
5963

64+
/// Get the number of rollback steps available.
65+
fn get_allowed_rollback(&self) -> usize {
66+
self.state_cache.len()
67+
}
68+
6069
/// Guide moves to the next state provided by the token id and returns a list of allowed tokens, unless return_tokens is False.
6170
#[pyo3(signature = (token_id, return_tokens=None))]
6271
fn advance(
@@ -66,6 +75,11 @@ impl PyGuide {
6675
) -> PyResult<Option<Vec<TokenId>>> {
6776
match self.index.get_next_state(self.state, token_id) {
6877
Some(new_state) => {
78+
// Free up space in state_cache if needed.
79+
if self.state_cache.len() == self.state_cache.capacity() {
80+
self.state_cache.pop_front();
81+
}
82+
self.state_cache.push_back(self.state);
6983
self.state = new_state;
7084
if return_tokens.unwrap_or(true) {
7185
self.get_tokens().map(Some)
@@ -80,6 +94,41 @@ impl PyGuide {
8094
}
8195
}
8296

97+
/// Rollback the Guide state `n` tokens (states).
98+
/// Fails if `n` is greater than stored prior states.
99+
fn rollback_state(&mut self, n: usize) -> PyResult<()> {
100+
if n == 0 {
101+
return Ok(());
102+
}
103+
if n > self.get_allowed_rollback() {
104+
return Err(PyValueError::new_err(format!(
105+
"Cannot roll back {n} step(s): only {available} states stored (max_rollback = {cap}). \
106+
You must advance through at least {n} state(s) before rolling back {n} step(s).",
107+
cap = self.state_cache.capacity(),
108+
available = self.get_allowed_rollback(),
109+
)));
110+
}
111+
let mut new_state: u32 = self.state;
112+
for _ in 0..n {
113+
// unwrap is safe because length is checked above
114+
new_state = self.state_cache.pop_back().unwrap();
115+
}
116+
self.state = new_state;
117+
Ok(())
118+
}
119+
120+
// Returns a boolean indicating if the sequence leads to a valid state in the DFA
121+
fn accepts_tokens(&self, sequence: Vec<u32>) -> bool {
122+
let mut state = self.state;
123+
for t in sequence {
124+
match self.index.get_next_state(state, t) {
125+
Some(s) => state = s,
126+
None => return false,
127+
}
128+
}
129+
true
130+
}
131+
83132
/// Checks if the automaton is in a final state.
84133
fn is_finished(&self) -> bool {
85134
self.index.is_final_state(self.state)

tests/test_guide.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,38 @@ def test_write_mask_into_interface(index):
181181
guide.write_mask_into(0, mask.numel(), mask.element_size())
182182
with pytest.raises(ValueError, match="Invalid data pointer alignment"):
183183
guide.write_mask_into(5, mask.numel(), mask.element_size())
184+
185+
186+
def test_rollback(index):
    """Advancing by one token and rolling back one step restores the start state."""
    guide = Guide(index, max_rollback=3)
    initial_state = guide.get_state()

    # Consume a single token, then undo it.
    guide.advance(1)
    guide.rollback_state(1)

    # The guide is back at its starting state and is not finished.
    assert not guide.is_finished()
    assert guide.get_state() == initial_state
197+
198+
199+
def test_rollback_interface(index):
    """Asking for more rollback steps than are stored raises ``ValueError``."""
    guide = Guide(index, max_rollback=3)

    # Nothing has been consumed yet, so a rollback of 5 exceeds the history.
    with pytest.raises(ValueError, match="Cannot roll back"):
        guide.rollback_state(5)
205+
206+
207+
@pytest.mark.parametrize(
    "seq, expected",
    [
        # One-token sequences using an allowed token are accepted.
        ([1], True),
        ([2], True),
        # Two-token sequences are too long for r"[1-9]" and are rejected.
        ([1, 1], False),
        ([2, 3], False),
    ],
)
def test_accepts_tokens_correctness(index, seq, expected):
    """``accepts_tokens`` returns the expected boolean for each sequence."""
    guide = Guide(index)
    verdict = guide.accepts_tokens(seq)
    assert verdict is expected

0 commit comments

Comments
 (0)