@@ -1,6 +1,8 @@
 """MLC LLM bench prompts generator"""
+
 import json
 import random
+from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List, Optional
 
@@ -18,6 +20,7 @@ class PromptsGenerator: # pylint: disable=too-few-public-methods
     def __init__(
         self,
         prompts_path: Optional[str] = None,
+        json_prompts_path: Optional[str] = None,
         tokenizer: Optional[Any] = None,
         seed: Optional[int] = 11111,
     ) -> None:
@@ -32,6 +35,11 @@ def __init__(
             or a .jsonl file where each line is a JSON object formatted as
             {"prompt": "prompt text", "prompt_tokens": 10}.
 
+        json_prompts_path : Optional[str]
+            The path to the file containing the source JSON prompts. This file is a
+            .jsonl file where each line is a JSON object formatted as
+            {"messages": List[Dict[str, Any]], "response_format": Dict[str, Any]}.
+
         tokenizer : Optional[Any]
             The tokenizer object to use for tokenizing the prompts.
 
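For context, a `json_prompts_path` file is plain JSONL, one record per line. Below is a minimal sketch of producing such a record; the file name and the name/age schema are made-up examples, not part of this change. The schema is kept as a dict so that `__init__` (next hunk) can serialize it with `json.dumps` for use as a lookup key.

```python
import json

# Hypothetical record for json_prompts_path; "messages" and "response_format"
# are both required by the asserts added in __init__ below.
record = {
    "messages": [
        {"role": "user", "content": "Give me the name and age of a person."}
    ],
    "response_format": {
        "type": "json_object",
        "schema": {
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
            "required": ["name", "age"],
        },
    },
}

with open("json_prompts.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(record) + "\n")
```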
@@ -66,6 +74,22 @@ def __init__(
                 prompt_line = file.readline()
                 prompt_tokens = self._count_tokens(prompt_line)
                 self.prompts.append({"prompt": prompt_line, "prompt_tokens": prompt_tokens})
+        if json_prompts_path:
+            self.json_prompts = defaultdict(list)
+            with open(json_prompts_path, "r", encoding="utf-8") as file:
+                for line in file:
+                    json_line = json.loads(line)
+                    assert (
+                        "messages" in json_line
+                    ), "The messages field is required in the JSONL file."
+                    assert (
+                        "response_format" in json_line
+                    ), "The response_format field is required in the JSONL file."
+                    self.json_prompts[json.dumps(json_line["response_format"]["schema"])].append(
+                        json_line["messages"]
+                    )
+        else:
+            self.json_prompts = None
 
     def _count_tokens(self, text: str) -> int:
         """Get the number of tokens.
@@ -82,40 +106,44 @@ def _count_tokens(self, text: str) -> int:
         """
         return len(self.tokenizer.encode(text))
 
-    def generate_prompt(self, tokens_mean: int, tokens_stddev: Optional[int] = 0) -> str:
+    def generate_prompt(self, params: Dict[str, Any]) -> Dict[str, Any]:
         """
-        Generates a prompt that closely matches the desired token count.
+        Generates a prompt based on the given params, e.g. prompt_tokens, response_format.
 
         Parameters
         ----------
-        token_mean : int
-            The desired mean number of tokens in the prompt.
+        params : Dict[str, Any]
+            The parameters for prompt generation, e.g. prompt_tokens, response_format.
 
-        token_stddev : Optional[int]
-            The desired standard deviation of tokens in the prompt.
-
         Returns
         -------
-        out: str
-            A prompt string with the specified number of tokens.
+        override_params : Dict[str, Any]
+            The params to override the original request, e.g. messages, response_format.
         """
+        if "response_format" in params:
+            response_format = params["response_format"]
+            if response_format.get("type") == "json_object":
+                if response_format.get("schema") in self.json_prompts:
+                    assert len(self.json_prompts[response_format["schema"]]) > 0
+                    return {"messages": random.choice(self.json_prompts[response_format["schema"]])}
+                schema, prompts = random.choice(list(self.json_prompts.items()))
+                response_format["schema"] = schema
+                return {"messages": random.choice(prompts), "response_format": response_format}
+        tokens_mean = params.get("prompt_tokens", 128)
         assert tokens_mean > 0, "The mean number of tokens must be greater than 0."
-        out_prompt_tokens = (
-            int(random.gauss(tokens_mean, tokens_stddev)) if tokens_stddev else tokens_mean
-        )
-        if out_prompt_tokens <= 0:
-            out_prompt_tokens = tokens_mean
-        remaining_prompt_tokens = out_prompt_tokens
+        remaining_prompt_tokens = tokens_mean
         result_prompt = ""
+        override_params = None
         while remaining_prompt_tokens > 0:
             prompt_dict = random.choice(self.prompts)
             cur_prompt_tokens = prompt_dict["prompt_tokens"]
             cur_prompt = prompt_dict["prompt"]
+            if override_params is None:
+                override_params = prompt_dict.get("override_params")
             if remaining_prompt_tokens - cur_prompt_tokens < 0:
                 result_prompt += cur_prompt[:remaining_prompt_tokens]
                 remaining_prompt_tokens = 0
                 break
             result_prompt += cur_prompt
             remaining_prompt_tokens -= cur_prompt_tokens
-        self._count_tokens(result_prompt)
-        return result_prompt
+        return {"messages": [{"role": "system", "content": result_prompt}]}