+ from concurrent.futures import ThreadPoolExecutor
+ from typing import Dict, Set
+
import numpy as np
import regex
from transformers import LogitsProcessor, PreTrainedModel, PreTrainedTokenizer
@@ -30,11 +33,13 @@ def complete_re(prompt:str, pattern: regex.Pattern, tokenizer: PreTrainedTokeniz
    partial_completion = ""
    prompt_plus_completion = prompt + partial_completion

+     token_validator = TokenValidator(tokenizer)
+
    while gen_tokens < max_new_tokens:
        prompt_token_ids = tokenizer.encode(prompt_plus_completion, return_tensors="pt")
        prompt_length = prompt_token_ids.shape[1]

-         allowed_token_ids = valid_next_tokens(partial_completion, pattern, tokenizer)
+         allowed_token_ids = token_validator.get_valid_next_tokens(partial_completion, pattern)
        custom_mask_processor = CustomLogitsMask(allowed_token_ids)

        output_ids = model.generate(prompt_token_ids,
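As an aside, CustomLogitsMask is defined elsewhere in this file; a minimal sketch of a processor with this interface, assuming it bans every token outside allowed_token_ids by forcing its logit to negative infinity:

    import torch
    from transformers import LogitsProcessor

    class CustomLogitsMask(LogitsProcessor):
        def __init__(self, allowed_token_ids):
            self.allowed_token_ids = list(allowed_token_ids)

        def __call__(self, input_ids, scores):
            # Additive mask: 0 for allowed token ids, -inf everywhere else,
            # so generate() can only sample from the allowed set.
            mask = torch.full_like(scores, float("-inf"))
            mask[:, self.allowed_token_ids] = 0
            return scores + mask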
@@ -55,17 +60,27 @@ def complete_re(prompt:str, pattern: regex.Pattern, tokenizer: PreTrainedTokeniz
        gen_tokens += 1

    return partial_completion
-
- def valid_next_tokens(partial_completion: str,
-                       pattern: regex.Pattern,
-                       tokenizer: PreTrainedTokenizer):
-     """
-     Return a list of valid next tokens for a prompt.
-     """
-     valid_token_ids = set()
-     for _, token_id in tokenizer.get_vocab().items():
-         decoded_token = tokenizer.decode(token_id)
-         if pattern.match(partial_completion + decoded_token, partial=True):
-             valid_token_ids.add(token_id)
-
-     return valid_token_ids
+
+ class TokenValidator:
+     def __init__(self, tokenizer: PreTrainedTokenizer):
+         self.tokenizer = tokenizer
+         self.decoded_tokens_cache = self.build_decoded_tokens_cache(tokenizer)
+
+     @staticmethod
+     def build_decoded_tokens_cache(tokenizer: PreTrainedTokenizer) -> Dict[int, str]:
+         # Decode the whole vocabulary once up front so validation never has
+         # to call tokenizer.decode inside the per-step hot loop.
+         return {token_id: tokenizer.decode(token_id) for token_id in tokenizer.get_vocab().values()}
+
+     def is_valid_token(self, token_id: int, partial_completion: str, pattern: regex.Pattern) -> bool:
+         decoded_token = self.decoded_tokens_cache[token_id]
+         return pattern.match(partial_completion + decoded_token, partial=True) is not None
+
+     def get_valid_next_tokens(self, partial_completion: str, pattern: regex.Pattern) -> Set[int]:
+         # Check every vocabulary token against the pattern on a thread pool;
+         # the regex module can release the GIL while matching, so the checks overlap.
+         token_ids = list(self.decoded_tokens_cache)
+         with ThreadPoolExecutor() as executor:
+             results = executor.map(
+                 lambda token_id: self.is_valid_token(token_id, partial_completion, pattern),
+                 token_ids,
+             )
+             valid_token_ids = {token_id for token_id, valid in zip(token_ids, results) if valid}
+
+         return valid_token_ids
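For context, a minimal sketch of how the new TokenValidator drives one constrained-decoding step; the gpt2 checkpoint and the phone-number pattern here are illustrative assumptions, not part of this commit:

    import regex
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer; gpt2 is an assumption
    pattern = regex.compile(r"\d{3}-\d{4}")            # regex (not re) is needed for partial=True

    validator = TokenValidator(tokenizer)
    allowed = validator.get_valid_next_tokens("555-", pattern)
    # Only tokens whose decoded text keeps "555-..." a viable partial match survive,
    # e.g. digit tokens are allowed while alphabetic tokens are filtered out.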