1
1
//! Provides tools and interfaces to integrate the crate's functionality with Python.
2
2
3
+ use std:: collections:: VecDeque ;
3
4
use std:: sync:: Arc ;
4
5
5
6
use bincode:: { config, Decode , Encode } ;
@@ -27,16 +28,19 @@ macro_rules! type_name {
27
28
pub struct PyGuide {
28
29
state : StateId ,
29
30
index : PyIndex ,
31
+ state_cache : VecDeque < StateId > ,
30
32
}
31
33
32
34
#[ pymethods]
33
35
impl PyGuide {
34
36
/// Creates a Guide object based on Index.
35
37
#[ new]
36
- fn __new__ ( index : PyIndex ) -> Self {
38
+ #[ pyo3( signature = ( index, max_rollback=32 ) ) ]
39
+ fn __new__ ( index : PyIndex , max_rollback : usize ) -> Self {
37
40
PyGuide {
38
41
state : index. get_initial_state ( ) ,
39
42
index,
43
+ state_cache : VecDeque :: with_capacity ( max_rollback) ,
40
44
}
41
45
}
42
46
@@ -57,6 +61,11 @@ impl PyGuide {
57
61
) ) )
58
62
}
59
63
64
+ /// Get the number of rollback steps available.
65
+ fn get_allowed_rollback ( & self ) -> usize {
66
+ self . state_cache . len ( )
67
+ }
68
+
60
69
/// Guide moves to the next state provided by the token id and returns a list of allowed tokens, unless return_tokens is False.
61
70
#[ pyo3( signature = ( token_id, return_tokens=None ) ) ]
62
71
fn advance (
@@ -66,6 +75,11 @@ impl PyGuide {
66
75
) -> PyResult < Option < Vec < TokenId > > > {
67
76
match self . index . get_next_state ( self . state , token_id) {
68
77
Some ( new_state) => {
78
+ // Free up space in state_cache if needed.
79
+ if self . state_cache . len ( ) == self . state_cache . capacity ( ) {
80
+ self . state_cache . pop_front ( ) ;
81
+ }
82
+ self . state_cache . push_back ( self . state ) ;
69
83
self . state = new_state;
70
84
if return_tokens. unwrap_or ( true ) {
71
85
self . get_tokens ( ) . map ( Some )
@@ -80,6 +94,41 @@ impl PyGuide {
80
94
}
81
95
}
82
96
97
+ /// Rollback the Guide state `n` tokens (states).
98
+ /// Fails if `n` is greater than stored prior states.
99
+ fn rollback_state ( & mut self , n : usize ) -> PyResult < ( ) > {
100
+ if n == 0 {
101
+ return Ok ( ( ) ) ;
102
+ }
103
+ if n > self . get_allowed_rollback ( ) {
104
+ return Err ( PyValueError :: new_err ( format ! (
105
+ "Cannot roll back {n} step(s): only {available} states stored (max_rollback = {cap}). \
106
+ You must advance through at least {n} state(s) before rolling back {n} step(s).",
107
+ cap = self . state_cache. capacity( ) ,
108
+ available = self . get_allowed_rollback( ) ,
109
+ ) ) ) ;
110
+ }
111
+ let mut new_state: u32 = self . state ;
112
+ for _ in 0 ..n {
113
+ // unwrap is safe because length is checked above
114
+ new_state = self . state_cache . pop_back ( ) . unwrap ( ) ;
115
+ }
116
+ self . state = new_state;
117
+ Ok ( ( ) )
118
+ }
119
+
120
+ // Returns a boolean indicating if the sequence leads to a valid state in the DFA
121
+ fn accepts_tokens ( & self , sequence : Vec < u32 > ) -> bool {
122
+ let mut state = self . state ;
123
+ for t in sequence {
124
+ match self . index . get_next_state ( state, t) {
125
+ Some ( s) => state = s,
126
+ None => return false ,
127
+ }
128
+ }
129
+ true
130
+ }
131
+
83
132
/// Checks if the automaton is in a final state.
84
133
fn is_finished ( & self ) -> bool {
85
134
self . index . is_final_state ( self . state )
0 commit comments