7
7
8
8
import numpy as np
9
9
from datasets import load_dataset # pylint: disable=import-error
10
- from transformers import AutoTokenizer # pylint: disable=import-error
11
-
12
10
from mlc_llm .bench .request_record import Metrics , RequestRecord
13
11
from mlc_llm .protocol .openai_api_protocol import (
14
12
ChatCompletionMessage ,
15
13
ChatCompletionRequest ,
16
14
DebugConfig ,
17
15
)
16
+ from transformers import AutoTokenizer # pylint: disable=import-error
18
17
19
18
20
19
class Dataset : # pylint: disable=too-few-public-methods
@@ -243,10 +242,11 @@ class JSONModeEvalDataset(Dataset): # pylint: disable=too-few-public-methods
243
242
"""The dataset class for JSON dataset."""
244
243
245
244
def __init__ (self , tokenizer : AutoTokenizer ) -> None :
246
- raw_dataset = load_dataset ("NousResearch/json-mode-eval" )
245
+ raw_dataset = load_dataset ("NousResearch/json-mode-eval" , split = "train" )
247
246
self .tokenizer = tokenizer
248
247
self .dataset = []
249
- for data in raw_dataset ["train" ]:
248
+ for data in raw_dataset :
249
+ data = self ._process_data (data )
250
250
messages = data ["prompt" ]
251
251
schema = {
252
252
"type" : "json_object" ,
@@ -259,6 +259,40 @@ def __init__(self, tokenizer: AutoTokenizer) -> None:
259
259
)
260
260
self .dataset .append ((messages , schema , num_tokens ))
261
261
262
+ def _process_data (self , data ):
263
+ data ["prompt" ][0 ]["content" ] = data ["prompt" ][0 ]["content" ].replace (
264
+ ", 'format': 'email'" , ""
265
+ )
266
+ data ["schema" ] = data ["schema" ].replace (', "format": "email"' , "" )
267
+
268
+ data ["prompt" ][0 ]["content" ] = data ["prompt" ][0 ]["content" ].replace (
269
+ ", 'pattern': '\\ \\ d{5}'" , ""
270
+ )
271
+ data ["schema" ] = data ["schema" ].replace (', "pattern": "\\ \\ d{5}"' , "" )
272
+
273
+ schema_str = data ["schema" ]
274
+ schema = json .loads (schema_str )
275
+ new_schema = None
276
+ if "type" not in schema :
277
+ if len (schema .keys ()) == 1 :
278
+ key = list (schema .keys ())[0 ]
279
+ new_schema = {"title" : key , ** schema [key ]}
280
+ else :
281
+ new_schema = {"type" : "object" , ** schema }
282
+ if new_schema is None :
283
+ return data
284
+ return {
285
+ "prompt" : [
286
+ {
287
+ "content" : f"You are a helpful assistant that answers in JSON. Here's the json schema you must adhere to:\n <schema>\n { new_schema } \n </schema>\n " ,
288
+ "role" : "system" ,
289
+ },
290
+ data ["prompt" ][1 ],
291
+ ],
292
+ "completion" : data ["completion" ],
293
+ "schema" : json .dumps (new_schema ),
294
+ }
295
+
262
296
def generate_request_records (
263
297
self ,
264
298
input_len : Optional [int ],
@@ -288,6 +322,10 @@ def generate_request_records(
288
322
model = "" ,
289
323
max_tokens = output_length ,
290
324
response_format = schema ,
325
+ debug_config = DebugConfig (
326
+ grammar_execution_mode = "constraint" ,
327
+ compact_json_output = True ,
328
+ ),
291
329
),
292
330
metrics = Metrics (
293
331
success = False ,
0 commit comments