
Commit c1fce50

update
1 parent 614f631 commit c1fce50

File tree: 2 files changed, +50 -7 lines

python/mlc_llm/bench/api_endpoint.py

Lines changed: 8 additions & 3 deletions
@@ -7,10 +7,9 @@
 import traceback
 from typing import Optional
 
-from typing_extensions import Self
-
 from mlc_llm.bench.request_record import Metrics, RequestRecord, ServerMetrics
 from mlc_llm.support import logging
+from typing_extensions import Self
 
 logging.enable_logging()
 logger = logging.getLogger(__name__)
@@ -67,7 +66,7 @@ async def __aexit__(self, exc_type, exc_value, tb) -> None:
     async def __call__(  # pylint: disable=too-many-branches,too-many-statements,too-many-locals
         self, request_record: RequestRecord
     ) -> RequestRecord:
-        payload = request_record.chat_cmpl.model_dump()
+        payload = request_record.chat_cmpl.model_dump(exclude_unset=True, exclude_none=True)
         if self.timeout is not None and "timeout" not in payload:
             payload["timeout"] = self.timeout
         if self.include_server_metrics:
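
The model_dump change is what trims the request body: pydantic v2's exclude_unset/exclude_none drop fields the caller never set or that are still None. A minimal sketch of that semantics, using a hypothetical stand-in model rather than the real ChatCompletionRequest:

from typing import Optional

from pydantic import BaseModel


class FakeChatRequest(BaseModel):
    # Hypothetical stand-in for the pydantic-based ChatCompletionRequest.
    model: str = ""
    max_tokens: Optional[int] = None
    temperature: float = 1.0


req = FakeChatRequest(model="llama")
print(req.model_dump())
# -> {'model': 'llama', 'max_tokens': None, 'temperature': 1.0}
print(req.model_dump(exclude_unset=True, exclude_none=True))
# -> {'model': 'llama'}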
@@ -81,6 +80,12 @@ async def __call__( # pylint: disable=too-many-branches,too-many-statements,too
8180
):
8281
payload["ignore_eos"] = True
8382

83+
print(payload)
84+
85+
if "response_format" in payload and "json_schema" in payload["response_format"]:
86+
payload["response_format"]["schema"] = payload["response_format"]["json_schema"]
87+
payload["response_format"].pop("json_schema")
88+
8489
generated_text = ""
8590
first_chunk_output_str = ""
8691
time_to_first_token_s = None
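
The new response_format branch adapts OpenAI-style payloads for the target server: it is assumed here that this endpoint expects the schema under a "schema" key rather than "json_schema". The rewrite, applied to a toy payload:

# Toy payload; the schema string is only illustrative.
payload = {
    "response_format": {"type": "json_object", "json_schema": '{"type": "object"}'},
}

if "response_format" in payload and "json_schema" in payload["response_format"]:
    payload["response_format"]["schema"] = payload["response_format"]["json_schema"]
    payload["response_format"].pop("json_schema")

print(payload["response_format"])
# -> {'type': 'json_object', 'schema': '{"type": "object"}'}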

python/mlc_llm/bench/dataset.py

Lines changed: 42 additions & 4 deletions
@@ -7,14 +7,13 @@
 
 import numpy as np
 from datasets import load_dataset  # pylint: disable=import-error
-from transformers import AutoTokenizer  # pylint: disable=import-error
-
 from mlc_llm.bench.request_record import Metrics, RequestRecord
 from mlc_llm.protocol.openai_api_protocol import (
     ChatCompletionMessage,
     ChatCompletionRequest,
     DebugConfig,
 )
+from transformers import AutoTokenizer  # pylint: disable=import-error
 
 
 class Dataset:  # pylint: disable=too-few-public-methods
@@ -243,10 +242,11 @@ class JSONModeEvalDataset(Dataset):  # pylint: disable=too-few-public-methods
     """The dataset class for JSON dataset."""
 
     def __init__(self, tokenizer: AutoTokenizer) -> None:
-        raw_dataset = load_dataset("NousResearch/json-mode-eval")
+        raw_dataset = load_dataset("NousResearch/json-mode-eval", split="train")
         self.tokenizer = tokenizer
         self.dataset = []
-        for data in raw_dataset["train"]:
+        for data in raw_dataset:
+            data = self._process_data(data)
             messages = data["prompt"]
             schema = {
                 "type": "json_object",
@@ -259,6 +259,40 @@ def __init__(self, tokenizer: AutoTokenizer) -> None:
             )
             self.dataset.append((messages, schema, num_tokens))
 
+    def _process_data(self, data):
+        data["prompt"][0]["content"] = data["prompt"][0]["content"].replace(
+            ", 'format': 'email'", ""
+        )
+        data["schema"] = data["schema"].replace(', "format": "email"', "")
+
+        data["prompt"][0]["content"] = data["prompt"][0]["content"].replace(
+            ", 'pattern': '\\\\d{5}'", ""
+        )
+        data["schema"] = data["schema"].replace(', "pattern": "\\\\d{5}"', "")
+
+        schema_str = data["schema"]
+        schema = json.loads(schema_str)
+        new_schema = None
+        if "type" not in schema:
+            if len(schema.keys()) == 1:
+                key = list(schema.keys())[0]
+                new_schema = {"title": key, **schema[key]}
+            else:
+                new_schema = {"type": "object", **schema}
+        if new_schema is None:
+            return data
+        return {
+            "prompt": [
+                {
+                    "content": f"You are a helpful assistant that answers in JSON. Here's the json schema you must adhere to:\n<schema>\n{new_schema}\n</schema>\n",
+                    "role": "system",
+                },
+                data["prompt"][1],
+            ],
+            "completion": data["completion"],
+            "schema": json.dumps(new_schema),
+        }
+
     def generate_request_records(
         self,
         input_len: Optional[int],
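
To see what _process_data accomplishes, here is its schema normalization applied to a made-up row in the json-mode-eval shape (schema nested under a single top-level key, no "type" field): the wrapper key becomes a "title" and the inner schema is flattened.

import json

# Made-up schema string in the shape _process_data handles.
row_schema = json.dumps(
    {"person": {"type": "object", "properties": {"name": {"type": "string"}}}}
)

schema = json.loads(row_schema)
if "type" not in schema and len(schema) == 1:
    key = next(iter(schema))
    new_schema = {"title": key, **schema[key]}
    print(new_schema)
    # -> {'title': 'person', 'type': 'object',
    #     'properties': {'name': {'type': 'string'}}}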
@@ -288,6 +322,10 @@ def generate_request_records(
                         model="",
                         max_tokens=output_length,
                         response_format=schema,
+                        debug_config=DebugConfig(
+                            grammar_execution_mode="constraint",
+                            compact_json_output=True,
+                        ),
                     ),
                     metrics=Metrics(
                         success=False,
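
For reference, the request shape each record now carries, sketched with a toy schema string. The field names come from the diff; it is assumed that ChatCompletionMessage takes OpenAI-style role/content fields and that response_format coerces from a plain dict as pydantic models usually allow.

from mlc_llm.protocol.openai_api_protocol import (
    ChatCompletionMessage,
    ChatCompletionRequest,
    DebugConfig,
)

request = ChatCompletionRequest(
    messages=[ChatCompletionMessage(role="user", content="Answer in JSON.")],
    model="",
    max_tokens=256,
    response_format={"type": "json_object", "json_schema": '{"type": "object"}'},
    debug_config=DebugConfig(
        grammar_execution_mode="constraint",
        compact_json_output=True,
    ),
)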
