1
+ # flake8: noqa
2
+ # mypy: ignore-errors
3
+ import os
4
+ import random
5
+ import time
6
+
7
+ import psutil
8
+ from outlines_core import Guide , Index , Vocabulary , create_mask , mask_to_list
9
+ from outlines_core .json_schema import build_regex_from_schema
10
+
11
# Set RUST_LOG to "debug" for this process.
# NOTE(review): this assignment runs after the outlines_core imports above;
# confirm the Rust side reads the variable late enough for it to take effect.
os.environ["RUST_LOG"] = "debug"
12
+
13
+
14
# Regular expressions used as benchmark inputs. Each entry pairs a short label
# ("name") with the pattern ("regex") that the benchmark builds an index from.
regexes = [
    {
        "name": "email",
        "regex": r"(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]{1,63}(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]{1,63}){0,10})@(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\.){1,3}[a-z0-9](?:[a-z0-9-]{0,30}[a-z0-9])?",
    },
    {"name": "simple_phone", "regex": r"\+?[1-9][0-9]{7,14}"},
    {
        "name": "complex_phone",
        "regex": r"\+?\d{1,4}?[-.\s]?\(?\d{1,3}?\)?[-.\s]?\d{1,4}[-.\s]?\d{1,4}[-.\s]?\d{1,9}",
    },
    {"name": "permissive_any", "regex": r".{255}$"},
    {"name": "permissive_words", "regex": r"[a-zA-Z]{100}"},
    # Single backslashes on purpose: this is a raw string, so doubling them
    # (as one would inside a JSON document) makes the pattern require literal
    # backslash characters and stops "\d" / "\w" from acting as classes.
    {
        "name": "https",
        "regex": r"(https?:\/\/)?([\da-z\.-]+)\.([a-z\.]{2,6})([\/\w \.-]*)*\/?",
    },
]
28
# JSON Schema documents used as benchmark inputs. They are kept as text
# because the benchmark passes each one to build_regex_from_schema.

# Flat object with a string "name" and an integer "age".
_SCHEMA_SIMPLE = r'{"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, "required": ["name", "age"]}'

# Same as above plus a pattern-constrained phone number string.
_SCHEMA_SIMPLE_PHONE = r'{"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}, "complexe_phone": {"type": "string", "pattern": "\\+?\\d{1,4}?[-. ]?\\(\\d{1,3}\\)?[-. ]?\\d{1,4}[-. ]?\\d{1,4}[-. ]?\\d{1,9}"}}, "required": ["name", "age", "complexe_phone"]}'

# Nested objects and arrays that reference a shared "artist" definition.
_SCHEMA_COMPLEXE = """{
    "$schema": "http://json-schema.org/draft-04/schema#",
    "title": "Schema for a recording",
    "type": "object",
    "definitions": {
        "artist": {
            "type": "object",
            "properties": {
                "id": {"type": "number"},
                "name": {"type": "string"},
                "functions": {
                    "type": "array",
                    "items": {"type": "string"}
                }
            },
            "required": ["id", "name", "functions"]
        }
    },
    "properties": {
        "id": {"type": "number"},
        "work": {
            "type": "object",
            "properties": {
                "id": {"type": "number"},
                "name": {"type": "string"},
                "composer": {"$ref": "#/definitions/artist"}
            }
        },
        "recording_artists": {
            "type": "array",
            "items": {"$ref": "#/definitions/artist"}
        }
    },
    "required": ["id", "work", "recording_artists"]
}"""

# Curriculum vitae with pattern-constrained email, phone and website fields.
# Backslashes are doubled because the patterns live inside JSON strings.
_SCHEMA_CURRICULUM = r'''{
    "$schema": "http://json-schema.org/draft-04/schema#",
    "title": "Schema for a Curriculum Vitae",
    "type": "object",
    "definitions": {
        "experienceEntry": {
            "type": "object",
            "properties": {
                "date": {
                    "type": "string",
                    "format": "date"
                },
                "position": {
                    "type": "string"
                }
            },
            "required": ["date", "position"]
        }
    },
    "properties": {
        "name": {
            "type": "string"
        },
        "surname": {
            "type": "string"
        },
        "email": {
            "type": "string",
            "pattern": "[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?"
        },
        "phone": {
            "type": "string",
            "pattern": "\\+?\\d{1,4}?[-. ]?\\(\\d{1,3}\\)?[-. ]?\\d{1,4}[-. ]?\\d{1,4}[-. ]?\\d{1,9}"
        },
        "website": {
            "type": "string",
            "pattern": "(https?:\\/\\/)?([\\da-z\\.-]+)\\.([a-z\\.]{2,6})([\\/\\w \\.-]*)*\\/?"
        },
        "resume": {
            "type": "array",
            "items": {
                "$ref": "#/definitions/experienceEntry"
            }
        }
    },
    "required": ["name", "surname", "email", "phone", "resume"]
}'''

# NOTE: the "regex" key of each entry holds JSON Schema text, not a pattern.
schemas = [
    {"name": "schema_simple", "regex": _SCHEMA_SIMPLE},
    {"name": "schema_simple_phone", "regex": _SCHEMA_SIMPLE_PHONE},
    {"name": "schema_complexe", "regex": _SCHEMA_COMPLEXE},
    {"name": "schema_curriculum", "regex": _SCHEMA_CURRICULUM},
]
126
+
127
+
128
class V2IndexBenchmark:
    """Time a ``Guide`` as it follows a random token path through an ``Index``.

    Call :meth:`setup` with a regular expression first, then
    :meth:`run_benchmark` to walk the guide and print the measured timings.
    """

    def setup(self, regex):
        """Build the vocabulary, index, guide and token mask for ``regex``.

        Args:
            regex: Regular expression the index is built from.

        Raises:
            AssertionError: If the freshly created guide already reports
                that it is finished.
        """
        self.vocab = Vocabulary.from_pretrained("unsloth/Llama-3.1-8B-Instruct")
        self.v2_index = Index(regex, self.vocab)

        self.v2_guide = Guide(self.v2_index)

        self.mask = create_mask(len(self.vocab) + 1)

        # Handle on the current process; not read anywhere in this class.
        self.process = psutil.Process()

        assert (
            not self.v2_guide.is_finished()
        ), f"Compressed Guide should not be finished for {regex} "

    def _pick_random_allowed_token(self):
        """Return a random token id from ``mask_to_list(self.mask)``.

        Raises:
            ValueError: If the mask yields no token ids to choose from.
        """
        candidates = mask_to_list(self.mask)
        if not candidates:
            raise ValueError("mask_to_list returned no token ids to choose from")
        return candidates[random.randrange(len(candidates))]

    def run_benchmark(self, max_iterations=2000):
        """Walk the guide along a random path and print how long it took.

        Only the ``get_tokens`` / ``advance`` calls are timed; picking the
        next token is excluded from the measurement.

        Args:
            max_iterations: Once the iteration count exceeds this value while
                the guide is still unfinished, the walk stops.

        Raises:
            ValueError: If the mask yields no token ids while the guide is
                not finished.
        """
        iterations = 0
        v2_total_time = 0.0

        self.current_token_id = -1

        # First step: query the initial state instead of advancing.
        if not self.v2_guide.is_finished():
            iterations += 1

            started = time.perf_counter()
            self.v2_guide.get_tokens(self.mask)
            v2_total_time += time.perf_counter() - started

            self.current_token_id = self._pick_random_allowed_token()

        while not self.v2_guide.is_finished():
            iterations += 1

            started = time.perf_counter()
            self.v2_guide.advance(self.current_token_id, self.mask)
            v2_total_time += time.perf_counter() - started

            if not self.v2_guide.is_finished():
                # Cap the walk so patterns that never finish cannot loop forever.
                if iterations > max_iterations:
                    break
                self.current_token_id = self._pick_random_allowed_token()

        v2_total_time_us = v2_total_time * 1e6
        # A guide that was already finished performs zero iterations; report
        # 0.00 instead of dividing by zero.
        per_iteration_us = v2_total_time_us / iterations if iterations else 0.0

        print(f" Total iterations (Number of tokens): {iterations} ")
        print(
            f" Guide with Compressed Index: {v2_total_time_us:.2f} µs ({per_iteration_us:.2f} µs per iteration)"
        )
192
+
193
+
194
+
195
def test_benchmark_v2index():
    """Benchmark every entry of ``regexes``, then every entry of ``schemas``.

    Schema entries are first converted to a regular expression with
    ``build_regex_from_schema``; each case then gets a fresh
    ``V2IndexBenchmark``.
    """

    def _bench(header, pattern):
        # Announce the case, then set up and run a new benchmark for it.
        print(header)
        runner = V2IndexBenchmark()
        runner.setup(pattern)
        runner.run_benchmark()

    for entry in regexes:
        label, pattern = entry["name"], entry["regex"]
        _bench(f"> Regex : '{label} '", pattern)

    for entry in schemas:
        # The "regex" key of a schema entry holds the JSON Schema text.
        label, schema_text = entry["name"], entry["regex"]
        # Convert before printing the header, as the conversion may log.
        pattern = build_regex_from_schema(schema_text, None)
        _bench(f"> Schema : '{label} '", pattern)
213
+
214
+
215
# Run the whole benchmark when this file is executed as a script.
if __name__ == "__main__":
    print("Running main...")
    test_benchmark_v2index()
0 commit comments