Skip to content

Commit 2a8f6b0

Browse files
committed
IndexV2 + TokensDFA extension
1 parent 33d6deb commit 2a8f6b0

File tree

21 files changed

+3022
-65
lines changed

21 files changed

+3022
-65
lines changed

Cargo.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,18 @@ hf-hub = "=0.3.2"
2121
tokenizers = { version = "=0.20.3", features = ["http"] }
2222
rustc-hash = "2.1.0"
2323
regex-automata = "0.4.9"
24+
smallvec = "1.14.0"
25+
regex-syntax = "0.8.5"
26+
rayon = "1.10.0"
27+
28+
[dev-dependencies]
29+
rand = { version = "0.9.0" }
30+
2431

2532
[features]
2633
python-bindings = ["pyo3", "serde-pyobject"]
34+
run_benchmarks = []
35+
2736

2837
[lib]
2938
name = "outlines_core"

benchmarks/bench_indexes.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
# flake8: noqa
2+
# mypy: ignore-errors
3+
import os
4+
import random
5+
import time
6+
7+
import psutil
8+
from outlines_core import Guide, Index, Vocabulary, create_mask, mask_to_list
9+
from outlines_core.json_schema import build_regex_from_schema
10+
11+
os.environ["RUST_LOG"] = "debug"
12+
13+
14+
# Benchmark regexes: each entry pairs a human-readable name with the raw
# regex pattern that will be compiled into an Index.
regexes = [
    {
        "name": "email",
        "regex": r"(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]{1,63}(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]{1,63}){0,10})@(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\.){1,3}[a-z0-9](?:[a-z0-9-]{0,30}[a-z0-9])?",
    },
    {"name": "simple_phone", "regex": r"\+?[1-9][0-9]{7,14}"},
    {
        "name": "complex_phone",
        "regex": r"\+?\d{1,4}?[-.\s]?\(?\d{1,3}?\)?[-.\s]?\d{1,4}[-.\s]?\d{1,4}[-.\s]?\d{1,9}",
    },
    {"name": "permissive_any", "regex": r".{255}$"},
    {"name": "permissive_words", "regex": r"[a-zA-Z]{100}"},
    {
        # FIX: the pattern was previously written with JSON-style double
        # escaping (e.g. r"\\/") inside a raw Python string, which made the
        # regex match literal backslashes — a normal URL such as
        # "https://example.com" did not match. Single escapes are correct
        # inside a raw string.
        "name": "https",
        "regex": r"(https?://)?([\da-z.-]+)\.([a-z.]{2,6})([/\w .-]*)*/?",
    },
]
28+
# Benchmark JSON Schemas. NOTE: despite the key name, "regex" here holds a
# JSON Schema document; it is converted into a regex with
# build_regex_from_schema() before being benchmarked.
schemas = [
    {
        "name": "schema_simple",
        "regex": r'{"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, "required": ["name", "age"]}',
    },
    {
        "name": "schema_simple_phone",
        # "pattern" values use JSON string escaping (\\d etc.), which is the
        # correct form inside a JSON document.
        "regex": r'{"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}, "complexe_phone": {"type": "string", "pattern": "\\+?\\d{1,4}?[-. ]?\\(\\d{1,3}\\)?[-. ]?\\d{1,4}[-. ]?\\d{1,4}[-. ]?\\d{1,9}"}}, "required": ["name", "age", "complexe_phone"]}',
    },
    {
        "name": "schema_complexe",
        # Recording schema: exercises $ref reuse ("artist" definition used in
        # both "work.composer" and "recording_artists" items).
        "regex": """{
"$schema": "http://json-schema.org/draft-04/schema#",
"title": "Schema for a recording",
"type": "object",
"definitions": {
"artist": {
"type": "object",
"properties": {
"id": {"type": "number"},
"name": {"type": "string"},
"functions": {
"type": "array",
"items": {"type": "string"}
}
},
"required": ["id", "name", "functions"]
}
},
"properties": {
"id": {"type": "number"},
"work": {
"type": "object",
"properties": {
"id": {"type": "number"},
"name": {"type": "string"},
"composer": {"$ref": "#/definitions/artist"}
}
},
"recording_artists": {
"type": "array",
"items": {"$ref": "#/definitions/artist"}
}
},
"required": ["id", "work", "recording_artists"]
}""",
    },
    {
        "name": "schema_curriculum",
        # CV schema: combines string "pattern" constraints (email/phone/url)
        # with an array of $ref'd entries.
        "regex": r'''{
"$schema": "http://json-schema.org/draft-04/schema#",
"title": "Schema for a Curriculum Vitae",
"type": "object",
"definitions": {
"experienceEntry": {
"type": "object",
"properties": {
"date": {
"type": "string",
"format": "date"
},
"position": {
"type": "string"
}
},
"required": ["date", "position"]
}
},
"properties": {
"name": {
"type": "string"
},
"surname": {
"type": "string"
},
"email": {
"type": "string",
"pattern": "[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?"
},
"phone": {
"type": "string",
"pattern": "\\+?\\d{1,4}?[-. ]?\\(\\d{1,3}\\)?[-. ]?\\d{1,4}[-. ]?\\d{1,4}[-. ]?\\d{1,9}"
},
"website": {
"type": "string",
"pattern": "(https?:\\/\\/)?([\\da-z\\.-]+)\\.([a-z\\.]{2,6})([\\/\\w \\.-]*)*\\/?"
},
"resume": {
"type": "array",
"items": {
"$ref": "#/definitions/experienceEntry"
}
}
},
"required": ["name", "surname", "email", "phone", "resume"]
}'''
    }
]
126+
127+
128+
class V2IndexBenchmark:
    # Benchmarks a Guide backed by the V2 (compressed) Index: walks random
    # paths through the automaton, timing every mask computation
    # (get_tokens / advance) until the guide reaches a final state.

    def setup(self, regex):
        """Build the vocabulary, index, guide and mask buffer for `regex`."""
        self.vocab = Vocabulary.from_pretrained("unsloth/Llama-3.1-8B-Instruct")
        self.v2_index = Index(regex, self.vocab)

        self.v2_guide = Guide(self.v2_index)

        # One bit per token id; the +1 presumably reserves a slot for a
        # special/EOS id just past the vocabulary size — TODO confirm
        # against the Rust side.
        self.mask = create_mask(len(self.vocab) + 1)

        self.process = psutil.Process()

        assert (
            not self.v2_guide.is_finished()
        ), f"Compressed Guide should not be finished for {regex}"

    def run_benchmark(self):
        """Drive the guide to completion (capped near 2000 steps), timing
        each mask computation, then print total and per-iteration cost."""
        iterations = 0
        v2_total_time = 0

        self.current_token_id = -1

        # First step: time the initial mask computation for the start state.
        if not self.v2_guide.is_finished():
            iterations += 1

            start_compressed = time.perf_counter()
            self.v2_guide.get_tokens(self.mask)
            end_compressed = time.perf_counter()

            v2_time = end_compressed - start_compressed
            v2_total_time += v2_time

            # Pick a random allowed token to advance with on the next step.
            mask_tokens_list = mask_to_list(self.mask)
            random_idx = random.randrange(len(mask_tokens_list))
            self.current_token_id = mask_tokens_list[random_idx]

        # Subsequent steps: time advance(), which moves the state and
        # refills the mask in one call.
        while not self.v2_guide.is_finished():
            iterations += 1

            start_compressed = time.perf_counter()
            self.v2_guide.advance(self.current_token_id, self.mask)
            end_compressed = time.perf_counter()

            v2_time = end_compressed - start_compressed
            v2_total_time += v2_time

            if not self.v2_guide.is_finished():
                # Safety cap so permissive regexes (e.g. ".{255}$") terminate.
                if iterations > 2000 :
                    break
                mask_tokens_list = mask_to_list(self.mask)
                random_idx = random.randrange(len(mask_tokens_list))

                self.current_token_id = mask_tokens_list[random_idx]

        # Report in microseconds.
        v2_total_time_us = v2_total_time * 1e6

        print(f" Total iterations (Number of tokens): {iterations}")
        print(
            f" Guide with Compressed Index: {v2_total_time_us:.2f} µs ({v2_total_time_us / iterations:.2f} µs per iteration)"
        )
192+
193+
194+
195+
def test_benchmark_v2index():
    """Run the V2 index benchmark over every sample regex, then every
    sample JSON schema (converted to a regex first)."""
    for entry in regexes:
        print(f"> Regex : '{entry['name']}'")
        bench = V2IndexBenchmark()
        bench.setup(entry["regex"])
        bench.run_benchmark()

    for entry in schemas:
        regex = build_regex_from_schema(entry["regex"], None)
        print(f"> Schema : '{entry['name']}'")
        bench = V2IndexBenchmark()
        bench.setup(regex)
        bench.run_benchmark()
213+
214+
215+
# Allow running the benchmark directly as a script, outside pytest.
if __name__ == "__main__":
    print("Running main...")
    test_benchmark_v2index()

python/outlines_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from importlib.metadata import PackageNotFoundError, version
33

44
from .outlines_core_rs import Guide, Index, Vocabulary
5+
from .utils import create_mask, first_token_id_from_mask, mask_to_list
56

67
try:
78
__version__ = version("outlines_core")

python/outlines_core/outlines_core_rs.pyi

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Dict, List, Optional, Set, Tuple, Union
2+
import array
23

34
def build_regex_from_schema(
45
json_schema: str, whitespace_pattern: Optional[str] = None
@@ -26,10 +27,10 @@ class Guide:
2627
def get_state(self) -> int:
2728
"""Retrieves current state id of the Guide."""
2829
...
29-
def get_tokens(self) -> List[int]:
30+
def get_tokens(self, mask:Optional[array.array]) -> List[int]:
3031
"""Gets the list of allowed tokens for the current state."""
3132
...
32-
def advance(self, token_id: int) -> List[int]:
33+
def advance(self, token_id: int, mask: Optional[array.array]) -> List[int]:
3334
"""Guide moves to the next state provided by the token id and returns a list of allowed tokens."""
3435
...
3536
def is_finished(self) -> bool:
@@ -86,7 +87,7 @@ class Index:
8687
def __init__(self, regex: str, vocabulary: "Vocabulary"):
8788
"""Creates an index from a regex and vocabulary."""
8889
...
89-
def get_allowed_tokens(self, state: int) -> Optional[List[int]]:
90+
def get_allowed_tokens(self, state: int, mask: Optional[array.array]) -> Optional[List[int]]:
9091
"""Returns allowed tokens in this state."""
9192
...
9293
def get_next_state(self, state: int, token_id: int) -> Optional[int]:

python/outlines_core/utils.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import array
2+
from typing import List
3+
4+
5+
def mask_to_list(mask_buffer: array.array) -> List[int]:
    """
    Expand a bit-mask buffer into the list of set token IDs.

    Bits are laid out LSB-first within each 64-bit word: token ID t lives
    at bit (t % 64) of word t // 64.

    Args:
        mask_buffer: array.array of unsigned 64-bit words holding the mask.

    Returns:
        List[int]: token IDs whose bit is 1, in ascending order.
    """
    token_ids: List[int] = []
    base = 0
    for word in mask_buffer:
        # Peel off one set bit at a time, lowest first.
        while word:
            low_bit = word & -word  # isolate the lowest set bit
            token_ids.append(base + low_bit.bit_length() - 1)
            word ^= low_bit  # clear it
        base += 64
    return token_ids
24+
25+
26+
def create_mask(size: int) -> array.array:
    """
    Creates a zero-initialized bit-mask buffer able to hold `size` bits.

    Args:
        size (int): The number of bits the mask should represent (e.g., vocab_size).

    Returns:
        array.array: An array of unsigned 64-bit words (typecode "Q"), all
            zero, sized to hold `size` bits: its length is ceil(size / 64).

    Raises:
        ValueError: If size is not positive.
    """
    # (The previous docstring claimed a byte buffer of length ceil(size/8);
    # the storage has always been 64-bit words — ceil(size/64) entries.)
    if size <= 0:
        raise ValueError("Mask size must be positive")
    u64_size = (size + 63) // 64
    return array.array("Q", [0] * u64_size)
44+
45+
46+
def first_token_id_from_mask(mask_buffer: array.array) -> int:
    """
    Return the smallest token ID whose bit is set in the mask, or -1.

    Bits are laid out LSB-first within each 64-bit word, matching
    `create_mask` and `mask_to_list`: token ID t lives at bit (t % 64) of
    word t // 64.

    Args:
        mask_buffer: array.array of unsigned 64-bit words holding the mask.

    Returns:
        int: the lowest set token ID, or -1 when no bit is set.
    """
    # FIX: the previous implementation scanned tobytes() MSB-first per byte
    # (128 >> bit_idx), which contradicts the LSB-first layout used by
    # mask_to_list and therefore returned byte_idx*8 + (7 - bit) instead of
    # the lowest set token ID.
    for word_idx, word in enumerate(mask_buffer):
        if word:
            # Isolate the lowest set bit and convert it to its bit index.
            return word_idx * 64 + ((word & -word).bit_length() - 1)
    return -1

rustfmt.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ group_imports = "StdExternalCrate"
22
imports_granularity = "Module"
33
reorder_impl_items = true
44
reorder_imports = true
5+

0 commit comments

Comments
 (0)