Skip to content

Commit 47331c7

Browse files
committed
5x faster token validator
- Parallelized lookahead - Cached decoding - Cached get vocab
1 parent 62fa6eb commit 47331c7

File tree

4 files changed

+32
-19
lines changed

4 files changed

+32
-19
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
env
2+
.venv
23
.ruff_cache
34
dist
45
*.egg-info

rellm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from rellm.rellm import complete_re
1+
from rellm.rellm import complete_re

rellm/rellm.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from concurrent.futures import ThreadPoolExecutor
2+
from typing import Dict, Set
3+
14
import numpy as np
25
import regex
36
from transformers import LogitsProcessor, PreTrainedModel, PreTrainedTokenizer
@@ -30,11 +33,13 @@ def complete_re(prompt:str, pattern: regex.Pattern, tokenizer: PreTrainedTokeniz
3033
partial_completion = ""
3134
prompt_plus_completion = prompt + partial_completion
3235

36+
token_validator = TokenValidator(tokenizer)
37+
3338
while gen_tokens < max_new_tokens:
3439
prompt_token_ids = tokenizer.encode(prompt_plus_completion, return_tensors="pt")
3540
prompt_length = prompt_token_ids.shape[1]
3641

37-
allowed_token_ids = valid_next_tokens(partial_completion, pattern, tokenizer)
42+
allowed_token_ids = token_validator.get_valid_next_tokens(partial_completion, pattern)
3843
custom_mask_processor = CustomLogitsMask(allowed_token_ids)
3944

4045
output_ids = model.generate(prompt_token_ids,
@@ -55,17 +60,27 @@ def complete_re(prompt:str, pattern: regex.Pattern, tokenizer: PreTrainedTokeniz
5560
gen_tokens += 1
5661

5762
return partial_completion
58-
59-
def valid_next_tokens(partial_completion: str,
60-
pattern: regex.Pattern,
61-
tokenizer: PreTrainedTokenizer):
62-
"""
63-
Return a list of valid next tokens for a prompt.
64-
"""
65-
valid_token_ids = set()
66-
for _, token_id in tokenizer.get_vocab().items():
67-
decoded_token = tokenizer.decode(token_id)
68-
if pattern.match(partial_completion + decoded_token, partial=True):
69-
valid_token_ids.add(token_id)
70-
71-
return valid_token_ids
63+
64+
class TokenValidator:
65+
def __init__(self, tokenizer: PreTrainedTokenizer):
66+
self.tokenizer = tokenizer
67+
self.decoded_tokens_cache = self.build_decoded_tokens_cache(tokenizer)
68+
69+
@staticmethod
70+
def build_decoded_tokens_cache(tokenizer: PreTrainedTokenizer) -> Dict[int, str]:
71+
return {token_id: tokenizer.decode(token_id) for _, token_id in tokenizer.get_vocab().items()}
72+
73+
def is_valid_token(self, token_id: int, partial_completion: str, pattern: regex.Pattern) -> bool:
74+
decoded_token = self.decoded_tokens_cache[token_id]
75+
return pattern.match(partial_completion + decoded_token, partial=True)
76+
77+
def get_valid_next_tokens(self, partial_completion: str, pattern: regex.Pattern) -> Set[int]:
78+
with ThreadPoolExecutor():
79+
valid_token_ids = set(
80+
filter(
81+
lambda token_id: self.is_valid_token(token_id, partial_completion, pattern),
82+
self.decoded_tokens_cache.keys()
83+
)
84+
)
85+
86+
return valid_token_ids

requirements.txt

Lines changed: 0 additions & 3 deletions
This file was deleted.

0 commit comments

Comments
 (0)