
Commit 07c92b0

[Bench] Json mode bench (#2552)
* [Bench] Json mode bench
  This PR refactors mlc bench to enable json mode in dataset.
* upd
* fix lint
1 parent dcece51 commit 07c92b0

2 files changed: +53 -26 lines changed


python/mlc_llm/bench/prompts.py

Lines changed: 44 additions & 16 deletions
@@ -1,6 +1,8 @@
 """MLC LLM bench prompts generator"""
+
 import json
 import random
+from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List, Optional
 

@@ -18,6 +20,7 @@ class PromptsGenerator:  # pylint: disable=too-few-public-methods
     def __init__(
         self,
         prompts_path: Optional[str] = None,
+        json_prompts_path: Optional[str] = None,
         tokenizer: Optional[Any] = None,
         seed: Optional[int] = 11111,
     ) -> None:
@@ -32,6 +35,11 @@ def __init__(
             or a .jsonl file where each line is a JSON object formatted as
             {"prompt": "prompt text", "prompt_tokens": 10}.
 
+        json_prompts_path : Optional[str]
+            The path to the file containing the source json prompts. This file is a
+            .jsonl file where each line is a JSON object formatted as
+            {"messages": List[Dict[str, Any]], "response_format": Dict[str, Any]}.
+
         tokenizer : Optional[Any]
             The tokenizer object to use for tokenizing the prompts.
 
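
For concreteness, a line in such a json prompts file could be written as below. This is an illustrative sketch only: the file name, schema, and message content are invented here and are not part of the commit, and the schema is encoded as a JSON string as one plausible choice.

import json

# Hypothetical entry for the json prompts dataset described above; the file
# name, schema, and messages are made up for illustration.
entry = {
    "messages": [
        {"role": "user", "content": "Extract the name and age from: Alice is 30 years old."}
    ],
    "response_format": {
        "type": "json_object",
        "schema": json.dumps(
            {
                "type": "object",
                "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
                "required": ["name", "age"],
            }
        ),
    },
}

# Append one JSON object per line, matching the documented .jsonl layout.
with open("json_prompts.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(entry) + "\n")
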
@@ -66,6 +74,22 @@ def __init__(
                 prompt_line = file.readline()
                 prompt_tokens = self._count_tokens(prompt_line)
                 self.prompts.append({"prompt": prompt_line, "prompt_tokens": prompt_tokens})
+        if json_prompts_path:
+            self.json_prompts = defaultdict(list)
+            with open(json_prompts_path, "r", encoding="utf-8") as file:
+                for line in file:
+                    json_line = json.loads(line)
+                    assert (
+                        "messages" in json_line
+                    ), "The messages field is required in the JSONL file."
+                    assert (
+                        "response_format" in json_line
+                    ), "The response_format field is required in the JSONL file."
+                    self.json_prompts[json.dumps(json_line["response_format"]["schema"])].append(
+                        json_line["messages"]
+                    )
+        else:
+            self.json_prompts = None
 
     def _count_tokens(self, text: str) -> int:
         """Get the number of tokens.
@@ -82,40 +106,44 @@ def _count_tokens(self, text: str) -> int:
         """
         return len(self.tokenizer.encode(text))
 
-    def generate_prompt(self, tokens_mean: int, tokens_stddev: Optional[int] = 0) -> str:
+    def generate_prompt(self, params: Dict[str, Any]) -> Dict[str, Any]:
         """
-        Generates a prompt that closely matches the desired token count.
+        Generates a prompt based on the params, e.g. prompt_tokens, response_format.
 
         Parameters
         ----------
-        token_mean : int
+        params : Dict[str, Any]
             The desired mean number of tokens in the prompt.
 
-        token_stddev : Optional[int]
-            The desired standard deviation of tokens in the prompt.
-
         Returns
         -------
-        out: str
-            A prompt string with the specified number of tokens.
+        override_params: Dict[str, Any]
+            The params to override the original request, e.g. messages, response_format.
         """
+        if "response_format" in params:
+            response_format = params["response_format"]
+            if response_format.get("type") == "json_object":
+                if response_format.get("schema") in self.json_prompts:
+                    assert len(self.json_prompts[response_format["schema"]]) > 0
+                    return {"messages": random.choice(self.json_prompts[response_format["schema"]])}
+                schema, prompts = random.choice(list(self.json_prompts.items()))
+                response_format["schema"] = schema
+                return {"messages": random.choice(prompts), "response_format": response_format}
+        tokens_mean = params.get("prompt_tokens", 128)
         assert tokens_mean > 0, "The mean number of tokens must be greater than 0."
-        out_prompt_tokens = (
-            int(random.gauss(tokens_mean, tokens_stddev)) if tokens_stddev else tokens_mean
-        )
-        if out_prompt_tokens <= 0:
-            out_prompt_tokens = tokens_mean
-        remaining_prompt_tokens = out_prompt_tokens
+        remaining_prompt_tokens = tokens_mean
         result_prompt = ""
+        override_params = None
         while remaining_prompt_tokens > 0:
             prompt_dict = random.choice(self.prompts)
             cur_prompt_tokens = prompt_dict["prompt_tokens"]
             cur_prompt = prompt_dict["prompt"]
+            if override_params is None:
+                override_params = prompt_dict["override_params"]
             if remaining_prompt_tokens - cur_prompt_tokens < 0:
                 result_prompt += cur_prompt[:remaining_prompt_tokens]
                 remaining_prompt_tokens = 0
                 break
             result_prompt += cur_prompt
             remaining_prompt_tokens -= cur_prompt_tokens
-        self._count_tokens(result_prompt)
-        return result_prompt
+        return {"messages": [{"role": "system", "content": result_prompt}]}

python/mlc_llm/bench/request.py

Lines changed: 9 additions & 10 deletions
@@ -1,4 +1,5 @@
 """MLC LLM Bench Request"""
+
 import json
 import os
 import time
@@ -45,6 +46,8 @@ class OpenAIRequestSender:  # pylint: disable=too-many-instance-attributes
         The client to use for sending requests.
     include_server_metrics : Optional[bool]
         Specifies if server metrics should be included, default is False.
+    prompt_generator : Optional[PromptsGenerator]
+        The prompt generator for missing messages fields.
 
     Attributes
     ----------
@@ -60,6 +63,7 @@ def __init__(  # pylint: disable=too-many-arguments
         timeout: Optional[float] = None,
         client: Optional[Any] = None,
         include_server_metrics: Optional[bool] = False,
+        prompt_generator: Optional[PromptsGenerator] = None,
     ) -> None:
         import aiohttp  # pylint: disable=import-outside-toplevel,import-error
         from transformers import (  # pylint: disable=import-outside-toplevel,import-error
@@ -69,7 +73,7 @@ def __init__(  # pylint: disable=too-many-arguments
         self.stream = stream
         self.timeout = timeout
         self.tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
-        self.prompt_generator = PromptsGenerator()
+        self.prompt_generator = PromptsGenerator() if prompt_generator is None else prompt_generator
         self.request_records: List[RequestRecords] = []
         self.client = client if client else aiohttp.ClientSession()
         self.include_server_metrics = include_server_metrics
@@ -88,15 +92,10 @@ async def __call__(  # pylint: disable=too-many-locals, too-many-branches, too-m
         self, params: Dict[str, Any] = None
     ) -> None:
         if "messages" not in params:
-            prompt_tokens = 128
-            if "prompt_tokens" in params:
-                prompt_tokens = params["prompt_tokens"]
-            else:
-                logger.warning("A random prompt with %d tokens will be generated.", prompt_tokens)
-            prompt = self.prompt_generator.generate_prompt(prompt_tokens)
-            params["messages"] = [{"role": "system", "content": prompt}]
-        else:
-            prompt = params["messages"][-1]["content"]
+            override_params = self.prompt_generator.generate_prompt(params)
+            assert "messages" in override_params, "override params must contain messages field"
+            params.update(override_params)
+        prompt = params["messages"][-1]["content"]
         chat_params = self._get_chat_completion_params(params)
         if "stream" not in chat_params:
             chat_params["stream"] = self.stream
