Index Final State Semantics Do Not Guarantee Regex Completion #210

Open
ykhrustalev opened this issue Apr 16, 2025 · 0 comments
Labels
enhancement New feature or request

ykhrustalev commented Apr 16, 2025

What behavior of the library made you think about the improvement?

In the current Index abstraction over the outlines-core regex engine, the concept of a final state is misleading: it does not actually signify that the regex has been fully matched or that a client can safely stop generating tokens.

While final_states includes states that can transition to themselves on the eos_token_id, this transition is not enforced or surfaced in a way that makes it obvious when a valid, complete match has occurred.

This results in a problem for client code: even after reaching a state marked as final, the regex may not be truly complete unless an additional eos_token_id transition is explicitly made. But from the perspective of Index::is_final_state, the client may incorrectly assume that it can exit.

Why This Matters:

Consumers using the Index to validate token sequences—especially in contexts like constrained decoding—need a reliable signal that a match is valid and complete. The current API makes it easy to:

  1. Reach a "final" state,

  2. Exit early,

  3. And unknowingly produce an incomplete match.

This is especially problematic when using Index to validate partial decoding paths during generation.
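To make the gap concrete, here is a minimal, self-contained sketch (a hand-rolled toy DFA, not the real outlines-core `Index` API) for a regex like `[1-9][0-9]{1,2}`: after two digits the state is "final", yet a longer match is still possible, and only the explicit eos transition makes completion unambiguous.

```rust
// Toy DFA for `[1-9][0-9]{1,2}` plus an explicit eos transition.
// States: 0 = start, 1 = one digit, 2 = two digits, 3 = three digits,
// 4 = terminated (eos consumed). Token ids 0..=9 are digits, 10 is eos.
const EOS: u32 = 10;

fn next_state(state: u32, token: u32) -> Option<u32> {
    match (state, token) {
        (0, 1..=9) => Some(1),
        (1, 0..=9) => Some(2),
        (2, 0..=9) => Some(3),
        // eos is only valid once at least two digits were produced
        (2, EOS) | (3, EOS) => Some(4),
        _ => None,
    }
}

// "Final" in the Index sense: the regex *could* end here.
fn is_final_state(state: u32) -> bool {
    matches!(state, 2 | 3)
}

fn main() {
    // Reach state 2 by consuming two digit tokens.
    let mut s = 0;
    for t in [2, 0] {
        s = next_state(s, t).unwrap();
    }
    // The state is marked final, but stopping here leaves the match open:
    // the engine would still accept another digit.
    assert!(is_final_state(s));
    assert!(next_state(s, 3).is_some()); // a longer match is still possible
    // Only the explicit eos transition makes completion unambiguous.
    assert_eq!(next_state(s, EOS), Some(4));
}
```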

How would you like it to behave?

Potential Fixes / Suggestions:

  1. Clarify Semantics in Docs:
    Clearly state that final_states does not guarantee regex completion unless followed by eos_token_id.

  2. Introduce is_accepting_state(&self, state: &StateId) -> bool:
    Add a method that checks whether a state is both in final_states and has a valid eos_token_id transition—i.e., a true "accepting" state.

  3. Strengthen Invariants in next_state:
    Consider enforcing or guiding the use of eos_token_id transitions when appropriate. Or, make these transitions more visible through a specialized method.

  4. Provide Utility to Check for Completion:
    A method like can_finish(&self, state: &StateId) -> bool would be extremely helpful to validate match completion.
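Suggestions 2 and 4 could look roughly like the following sketch. `IndexSketch` and its field names are stand-ins for the relevant parts of `Index` (assumptions, not the real outlines-core layout); the point is only the predicate itself.

```rust
use std::collections::{HashMap, HashSet};

// Hypothetical stand-in for the parts of `Index` the predicate needs.
struct IndexSketch {
    eos_token_id: u32,
    final_states: HashSet<u32>,
    // state -> (token -> next state)
    transitions: HashMap<u32, HashMap<u32, u32>>,
}

impl IndexSketch {
    /// Suggestion 2: a state is "accepting" only if it is final *and*
    /// the eos transition is actually available from it.
    fn is_accepting_state(&self, state: &u32) -> bool {
        self.final_states.contains(state)
            && self
                .transitions
                .get(state)
                .map_or(false, |t| t.contains_key(&self.eos_token_id))
    }

    /// Suggestion 4: the same check, named for what client code wants to ask.
    fn can_finish(&self, state: &u32) -> bool {
        self.is_accepting_state(state)
    }
}

fn main() {
    let idx = IndexSketch {
        eos_token_id: 4,
        final_states: HashSet::from([2]),
        // only state 2 has an eos transition
        transitions: HashMap::from([(2, HashMap::from([(4, 2)]))]),
    };
    assert!(idx.can_finish(&2));
    assert!(!idx.can_finish(&1));
}
```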

A related abstraction that attempts to address this:

use crate::errors::ConstrainedDecodingError;
use outlines_core::index::Index;
use outlines_core::json_schema;
use outlines_core::prelude::StateId;
use outlines_core::primitives::TokenId;
use outlines_core::vocabulary::Vocabulary;
use rustc_hash::FxBuildHasher;
use std::collections::HashMap;

/// A trait to guide the decoding process.
///
/// Based on PyGuide from outlines-core.
pub trait Guide {
    /// Returns true if it is generally possible to advance the guide.
    ///
    /// When it returns false, the client shall stop the generation.
    fn can_advance(&self) -> bool;

    /// Returns true if the given token is allowed for the next transition from the current state.
    fn is_allowed(&self, token_id: &u32) -> bool;

    /// Attempts to transition to the next state using the given token.
    ///
    /// The client can invoke is_allowed to verify the safe transition.
    ///
    /// # Errors
    ///
    /// * `ConstrainedDecodingError::TokenNotAllowed` if the token is not allowed for the current state.
    /// * `ConstrainedDecodingError::EndOfIndex` if the current state is a final state.
    fn advance(&mut self, token_id: &u32) -> Result<(), ConstrainedDecodingError>;

    /// Returns true if the current state is a final state.
    ///
    /// This means the guide has seen a complete sequence, though the client needs to invoke
    /// can_advance to know whether this is a fully terminal state or longer matches are still
    /// possible. An example of such behavior is a greedy pattern that reports completion early
    /// but still allows a longer sequence.
    fn is_finished(&self) -> bool;
}

/// A guide that uses a regex to guide the decoding process.
///
/// Based on PyGuide from outlines-core.
#[derive(Debug, Clone)]
pub struct RegexGuide {
    index: Index,
    // current state; None once the guide can no longer advance
    state: Option<StateId>,
    // optimized structure for the tokens lookup
    allowed_tokens: ArrayBasedLookup,
}

impl RegexGuide {
    /// Creates a new regex guide.
    ///
    /// # Arguments
    ///
    /// * `regex` - The regex to guide the decoding process.
    /// * `eos_token_id` - The token id of the end-of-sequence token.
    /// * `tokens` - A map of tokens to their ids.
    pub fn new(
        regex: &str,
        eos_token_id: u32,
        tokens: &HashMap<&str, u32>,
    ) -> Result<Self, ConstrainedDecodingError> {
        let tokens_map: HashMap<String, Vec<TokenId>, FxBuildHasher> = tokens
            .iter()
            // strip the eos token id from the allowed tokens
            .filter(|(_, v)| **v != eos_token_id)
            .map(|(k, v)| (k.to_string(), vec![*v]))
            .collect();

        let vocabulary = Vocabulary::try_from((eos_token_id, tokens_map))
            .map_err(ConstrainedDecodingError::BackendError)?;

        let index =
            Index::new(regex, &vocabulary).map_err(ConstrainedDecodingError::BackendError)?;

        let initial_state = index.initial_state();

        // Lookups happen against the logits array, whose size may be larger than the vocabulary.
        // Prepare the lookup table based on the max token id.
        // +1 because token ids are zero-based
        let max_token_id: usize = tokens
            .values()
            .chain(std::iter::once(&eos_token_id))
            .max()
            .unwrap_or(&eos_token_id)
            .to_owned() as usize + 1;
        let allowed_tokens = ArrayBasedLookup::new(max_token_id);

        let mut guide = Self {
            index,
            state: None,
            allowed_tokens,
        };

        // set at the beginning
        guide.set_state(Some(initial_state));

        Ok(guide)
    }

    fn set_state(&mut self, new_state: Option<StateId>) {
        match new_state {
            Some(state) => match self.index.transitions().get(&state) {
                Some(transitions) => {
                    self.allowed_tokens.reset_with(transitions);
                    self.state = Some(state);
                }
                None => {
                    self.state = None;
                }
            },
            None => {
                self.state = None;
            }
        }
    }

    /// Creates a new regex guide from a json schema.
    ///
    /// # Arguments
    ///
    /// * `schema` - The json schema to guide the decoding process.
    /// * `eos_token_id` - The token id of the end-of-sequence token.
    /// * `tokens` - A map of tokens to their ids.
    pub fn from_json_schema(
        schema: &str,
        eos_token_id: u32,
        tokens: HashMap<&str, u32>,
    ) -> Result<Self, ConstrainedDecodingError> {
        let regex = json_schema::regex_from_str(schema, None)
            .map_err(ConstrainedDecodingError::JsonSchemaError)?;
        Self::new(&regex, eos_token_id, &tokens)
    }
}

impl Guide for RegexGuide {
    fn can_advance(&self) -> bool {
        match &self.state {
            Some(_) => !self.allowed_tokens.is_empty(),
            None => false,
        }
    }

    fn is_allowed(&self, token_id: &u32) -> bool {
        match &self.state {
            Some(_) => self.allowed_tokens.contains(token_id),
            None => false,
        }
    }

    fn advance(&mut self, token_id: &u32) -> Result<(), ConstrainedDecodingError> {
        match &self.state {
            Some(state) => {
                if !self.is_allowed(token_id) {
                    // the caller did not verify the token via is_allowed
                    return Err(ConstrainedDecodingError::TokenNotAllowed);
                }

                match self.index.next_state(state, token_id) {
                    Some(next_state) => {
                        self.set_state(Some(next_state));
                        Ok(())
                    }
                    None => {
                        // should not be reached, consider it as end of index
                        self.set_state(None);
                        Ok(())
                    }
                }
            }
            None => Err(ConstrainedDecodingError::EndOfIndex),
        }
    }

    fn is_finished(&self) -> bool {
        match &self.state {
            Some(state) => self.index.is_final_state(state),
            None => true,
        }
    }
}

#[derive(Debug, Clone)]
pub struct ArrayBasedLookup {
    tokens: Box<[bool]>,
    is_empty: bool,
}

impl ArrayBasedLookup {
    fn new(max_size: usize) -> Self {
        Self {
            tokens: vec![false; max_size].into_boxed_slice(),
            is_empty: true,
        }
    }

    fn reset_with(&mut self, state_map: &HashMap<TokenId, StateId, FxBuildHasher>) -> &Self {
        self.tokens.iter_mut().for_each(|m| *m = false);
        assert!(state_map.len() <= self.tokens.len());
        for token_id in state_map.keys() {
            self.tokens[*token_id as usize] = true;
        }
        self.is_empty = state_map.is_empty();
        self
    }

    fn contains(&self, token_id: &u32) -> bool {
        // the lookup may happen on a token id outside the vocabulary
        if *token_id >= self.tokens.len() as u32 {
            return false;
        }
        self.tokens[*token_id as usize]
    }

    fn is_empty(&self) -> bool {
        self.is_empty
    }
}

#[cfg(test)]
mod tests {

    use super::*;

    fn create_guide_1() -> (RegexGuide, u32, HashMap<String, u32>) {
        let regex = "0|[1-9][0-9]{1,2}";
        let eos_token_id = 4;
        let tokens = HashMap::from_iter([("blah", 0), ("1a", 1), ("2", 2), ("0", 3)]);
        (
            RegexGuide::new(regex, eos_token_id, &tokens).expect("State tracker failed"),
            eos_token_id,
            tokens
                .into_iter()
                .map(|(k, v)| (k.to_string(), v))
                .collect(),
        )
    }

    #[test]
    fn guide_1_initial() {
        let (mut guide, eos_token_id, tokens) = create_guide_1();

        assert!(!guide.is_finished());
        assert!(guide.can_advance());

        assert!(!guide.is_allowed(&tokens["blah"]));
        assert!(!guide.is_allowed(&tokens["1a"]));
        assert!(guide.is_allowed(&tokens["2"]));
        assert!(guide.is_allowed(&tokens["0"]));
        assert!(!guide.is_allowed(&eos_token_id));

        assert!(matches!(
            guide.clone().advance(&tokens["blah"]),
            Err(ConstrainedDecodingError::TokenNotAllowed)
        ));
        assert!(matches!(
            guide.clone().advance(&tokens["1a"]),
            Err(ConstrainedDecodingError::TokenNotAllowed)
        ));
        // try 2nd branch
        assert!(matches!(guide.clone().advance(&tokens["2"]), Ok(())));

        // go 1st branch
        assert!(matches!(guide.advance(&tokens["0"]), Ok(())));
        assert!(guide.can_advance());
        assert!(!guide.is_allowed(&tokens["blah"]));
        assert!(!guide.is_allowed(&tokens["1a"]));
        assert!(!guide.is_allowed(&tokens["2"]));
        assert!(!guide.is_allowed(&tokens["0"]));
        assert!(guide.is_allowed(&eos_token_id));
    }

    #[test]
    fn guide_1_short_branch() {
        let (mut guide, eos_token_id, tokens) = create_guide_1();

        // first
        assert!(guide.advance(&tokens["0"]).is_ok());

        assert!(guide.is_finished());
        assert!(guide.can_advance());

        assert!(!guide.is_allowed(&tokens["blah"]));
        assert!(!guide.is_allowed(&tokens["1a"]));
        assert!(!guide.is_allowed(&tokens["2"]));
        assert!(!guide.is_allowed(&tokens["0"]));
        assert!(guide.is_allowed(&eos_token_id));

        assert!(matches!(
            guide.advance(&tokens["blah"]),
            Err(ConstrainedDecodingError::TokenNotAllowed)
        ));
        assert!(matches!(
            guide.advance(&tokens["1a"]),
            Err(ConstrainedDecodingError::TokenNotAllowed)
        ));
        assert!(matches!(
            guide.advance(&tokens["2"]),
            Err(ConstrainedDecodingError::TokenNotAllowed)
        ));
        assert!(matches!(
            guide.advance(&tokens["0"]),
            Err(ConstrainedDecodingError::TokenNotAllowed)
        ));

        // end
        assert!(matches!(guide.advance(&eos_token_id), Ok(())));

        assert!(!guide.can_advance());
        assert!(guide.is_finished());

        assert!(matches!(
            guide.advance(&tokens["blah"]),
            Err(ConstrainedDecodingError::EndOfIndex)
        ));
        assert!(matches!(
            guide.advance(&tokens["1a"]),
            Err(ConstrainedDecodingError::EndOfIndex)
        ));
        assert!(matches!(
            guide.advance(&tokens["2"]),
            Err(ConstrainedDecodingError::EndOfIndex)
        ));
        assert!(matches!(
            guide.advance(&tokens["0"]),
            Err(ConstrainedDecodingError::EndOfIndex)
        ));
    }

    #[test]
    fn guide_1_long_branch() {
        let (mut guide, eos_token_id, tokens) = create_guide_1();

        // first
        assert!(guide.advance(&tokens["2"]).is_ok());

        assert!(!guide.is_finished());
        assert!(guide.can_advance());
        assert!(!guide.is_allowed(&tokens["blah"]));
        assert!(!guide.is_allowed(&tokens["1a"]));
        assert!(guide.is_allowed(&tokens["2"]));
        assert!(guide.is_allowed(&tokens["0"]));
        assert!(!guide.is_allowed(&eos_token_id));

        // second
        assert!(guide.advance(&tokens["0"]).is_ok());
        assert!(guide.can_advance());
        assert!(guide.is_finished());
        assert!(!guide.is_allowed(&tokens["blah"]));
        assert!(!guide.is_allowed(&tokens["1a"]));
        assert!(guide.is_allowed(&tokens["2"]));
        assert!(guide.is_allowed(&tokens["0"]));
        assert!(guide.is_allowed(&eos_token_id));

        // third
        assert!(guide.advance(&tokens["2"]).is_ok());
        assert!(guide.can_advance());
        assert!(guide.is_finished());
        assert!(!guide.is_allowed(&tokens["blah"]));
        assert!(!guide.is_allowed(&tokens["1a"]));
        assert!(!guide.is_allowed(&tokens["2"]));
        assert!(!guide.is_allowed(&tokens["0"]));
        assert!(guide.is_allowed(&eos_token_id));

        // eos
        assert!(guide.advance(&eos_token_id).is_ok());
        assert!(!guide.can_advance());
        assert!(guide.is_finished());
        assert!(!guide.is_allowed(&tokens["blah"]));
        assert!(!guide.is_allowed(&tokens["1a"]));
        assert!(!guide.is_allowed(&tokens["2"]));
        assert!(!guide.is_allowed(&tokens["0"]));
        assert!(!guide.is_allowed(&eos_token_id));
    }
}
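For completeness, here is a sketch of how a client might drive the `Guide` trait above during sampling, using `can_advance` (not a bare final-state check) as the stop signal. `MockGuide` is a hypothetical stand-in that accepts one fixed token sequence, not the real `RegexGuide`, and the trait is simplified to stay self-contained.

```rust
// Simplified copy of the Guide trait, with a String error for brevity.
trait Guide {
    fn can_advance(&self) -> bool;
    fn is_allowed(&self, token_id: &u32) -> bool;
    fn advance(&mut self, token_id: &u32) -> Result<(), String>;
    fn is_finished(&self) -> bool;
}

// Stand-in guide that accepts exactly the sequence in `expected`.
struct MockGuide {
    expected: Vec<u32>,
    pos: usize,
}

impl Guide for MockGuide {
    fn can_advance(&self) -> bool {
        self.pos < self.expected.len()
    }
    fn is_allowed(&self, token_id: &u32) -> bool {
        self.can_advance() && self.expected[self.pos] == *token_id
    }
    fn advance(&mut self, token_id: &u32) -> Result<(), String> {
        if !self.is_allowed(token_id) {
            return Err("token not allowed".into());
        }
        self.pos += 1;
        Ok(())
    }
    fn is_finished(&self) -> bool {
        self.pos == self.expected.len()
    }
}

fn main() {
    let mut guide = MockGuide { expected: vec![3, 4], pos: 0 };
    let model_output = [3u32, 4];
    let mut generated = Vec::new();
    for token in model_output {
        // Stop only when the guide can no longer advance, not merely when
        // it reports a (possibly early) final state via is_finished.
        if !guide.can_advance() {
            break;
        }
        if guide.is_allowed(&token) {
            guide.advance(&token).unwrap();
            generated.push(token);
        }
    }
    assert!(guide.is_finished());
    assert_eq!(generated, vec![3, 4]);
}
```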
@ykhrustalev ykhrustalev added the enhancement New feature or request label Apr 16, 2025