
Commit 14ef326

[Bench] Add support for multiple backends
This PR adds support for the vLLM and llama.cpp backends, especially for JSON generation.
1 parent e283cd0 commit 14ef326
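
For orientation, the sketch below shows how one of the new endpoint kinds would be constructed through create_api_endpoint (python/mlc_llm/bench/api_endpoint.py). It is a minimal, hypothetical usage sketch rather than part of the commit: the import path, server address, and argument values are assumptions; only the argparse field names (api_endpoint, host, port, timeout, include_server_metrics) come from the diff below.

import argparse
import asyncio

# Assumed import path for python/mlc_llm/bench/api_endpoint.py.
from mlc_llm.bench.api_endpoint import create_api_endpoint

# The attribute names mirror the argparse fields read by create_api_endpoint below;
# the concrete values (address, port, timeout) are placeholders.
args = argparse.Namespace(
    api_endpoint="vllm-chat",  # other kinds handled below: "llama.cpp-chat", "vllm", "llama.cpp"
    host="127.0.0.1",
    port=8000,
    timeout=180.0,
    include_server_metrics=False,
)

endpoint = create_api_endpoint(args)  # an OpenAIChatEndPoint with backend="vllm"


async def main() -> None:
    # The endpoint is an async context manager that opens an aiohttp session;
    # the bench pipeline sends RequestRecord objects through it.
    async with endpoint:
        pass


asyncio.run(main())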

File tree

python/mlc_llm/bench/api_endpoint.py
python/mlc_llm/bench/dataset.py
python/mlc_llm/bench/request_processor.py

3 files changed: +91 -7 lines

python/mlc_llm/bench/api_endpoint.py

+42 -3
@@ -41,19 +41,23 @@ def __init__( # pylint: disable=too-many-arguments
         self,
         host: str,
         port: int,
+        backend: str,
         timeout: Optional[float] = None,
         include_server_metrics: bool = False,
+        no_debug_config: bool = False,
     ) -> None:
         super().__init__(include_server_metrics=include_server_metrics)

         import aiohttp  # pylint: disable=import-outside-toplevel,import-error

+        self.backend = backend
         self.timeout = timeout
         self.client: aiohttp.ClientSession = None
         self.url = f"http://{host}:{port}/v1/chat/completions"
         self.headers = {"Content-Type": "application/json"}
         if os.getenv("MLC_LLM_API_KEY"):
             self.headers["Authorization"] = f"Bearer {os.getenv('MLC_LLM_API_KEY')}"
+        self.no_debug_config = no_debug_config

     async def __aenter__(self) -> Self:
         import aiohttp  # pylint: disable=import-outside-toplevel,import-error
@@ -67,7 +71,7 @@ async def __aexit__(self, exc_type, exc_value, tb) -> None:
     async def __call__(  # pylint: disable=too-many-branches,too-many-statements,too-many-locals
         self, request_record: RequestRecord
     ) -> RequestRecord:
-        payload = request_record.chat_cmpl.model_dump()
+        payload = request_record.chat_cmpl.model_dump(exclude_unset=True, exclude_none=True)
         if self.timeout is not None and "timeout" not in payload:
             payload["timeout"] = self.timeout
         if self.include_server_metrics:
@@ -80,7 +84,28 @@ async def __call__( # pylint: disable=too-many-branches,too-many-statements,too
             and request_record.chat_cmpl.debug_config.ignore_eos
         ):
             payload["ignore_eos"] = True
+            if not self.no_debug_config:
+                payload["debug_config"] = {"ignore_eos": True}

+        if self.backend == "vllm":
+            if payload["debug_config"] and "ignore_eos" in payload["debug_config"]:
+                payload["ignore_eos"] = payload["debug_config"]["ignore_eos"]
+            payload.pop("debug_config")
+            if "response_format" in payload:
+                if "json_schema" in payload["response_format"]:
+                    payload["guided_json"] = json.loads(payload["response_format"]["json_schema"])
+                    payload["guided_decoding_backend"] = "outlines"
+                payload.pop("response_format")
+        elif self.backend == "llama.cpp":
+            if "response_format" in payload and "schema" in payload["response_format"]:
+                payload["response_format"]["schema"] = json.loads(
+                    payload["response_format"]["json_schema"]
+                )
+                payload["response_format"].pop("json_schema")
+        else:
+            if "response_format" in payload and "json_schema" in payload["response_format"]:
+                payload["response_format"]["schema"] = payload["response_format"]["json_schema"]
+                payload["response_format"].pop("json_schema")
         generated_text = ""
         first_chunk_output_str = ""
         time_to_first_token_s = None
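
The branching added to __call__ above is the core of the multi-backend support: the same OpenAI-style chat payload is rewritten into whatever each server expects for constrained JSON generation. The helper below is a condensed, hypothetical restatement of that translation, not the committed code; it covers only the response_format handling (it skips the debug_config / ignore_eos logic and simplifies the llama.cpp condition) and is meant to show what each backend ends up receiving.

import copy
import json


def adapt_payload(payload: dict, backend: str) -> dict:
    # Hypothetical restatement of the response_format translation added above.
    if backend == "vllm":
        # vLLM's OpenAI-compatible server takes the schema via guided-decoding
        # parameters instead of response_format.
        fmt = payload.pop("response_format", None)
        if fmt and "json_schema" in fmt:
            payload["guided_json"] = json.loads(fmt["json_schema"])
            payload["guided_decoding_backend"] = "outlines"
    elif backend == "llama.cpp":
        # llama.cpp keeps response_format but wants the parsed schema under "schema".
        fmt = payload.get("response_format")
        if fmt and "json_schema" in fmt:
            fmt["schema"] = json.loads(fmt.pop("json_schema"))
    else:
        # Default (MLC) path: rename json_schema back to schema, keeping it a string.
        fmt = payload.get("response_format")
        if fmt and "json_schema" in fmt:
            fmt["schema"] = fmt.pop("json_schema")
    return payload


schema = json.dumps({"type": "object", "properties": {"name": {"type": "string"}}})
base = {
    "messages": [{"role": "user", "content": "Answer in JSON."}],
    "response_format": {"type": "json_object", "json_schema": schema},
}

print(adapt_payload(copy.deepcopy(base), "vllm"))       # guided_json + guided_decoding_backend
print(adapt_payload(copy.deepcopy(base), "llama.cpp"))  # response_format["schema"] as a parsed dict
print(adapt_payload(copy.deepcopy(base), "mlc"))        # response_format["schema"] as a JSON string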
@@ -441,19 +466,33 @@ async def __call__( # pylint: disable=too-many-branches,too-many-locals,too-man
     "sglang",
     "tensorrt-llm",
     "vllm",
+    "vllm-chat",
+    "llama.cpp-chat",
 ]


 def create_api_endpoint(args: argparse.Namespace) -> APIEndPoint:
     """Create an API endpoint instance with regard to the specified endpoint kind."""
     if args.api_endpoint in ["openai", "mlc", "sglang"]:
         return OpenAIEndPoint(args.host, args.port, args.timeout, args.include_server_metrics)
-    if args.api_endpoint == "vllm":
+    if args.api_endpoint in ["vllm", "llama.cpp"]:
         return OpenAIEndPoint(
             args.host, args.port, args.timeout, include_server_metrics=False, no_debug_config=True
         )
     if args.api_endpoint == "openai-chat":
-        return OpenAIChatEndPoint(args.host, args.port, args.timeout, args.include_server_metrics)
+        return OpenAIChatEndPoint(
+            args.host, args.port, args.timeout, args.api_endpoint, args.include_server_metrics
+        )
+    if args.api_endpoint in ["vllm-chat", "llama.cpp-chat"]:
+        return OpenAIChatEndPoint(
+            args.host,
+            args.port,
+            args.api_endpoint[:-5],
+            args.timeout,
+            include_server_metrics=False,
+            no_debug_config=True,
+        )
+
     if args.api_endpoint == "tensorrt-llm":
         return TensorRTLLMEndPoint(args.host, args.port, args.timeout)
     raise ValueError(f'Unrecognized endpoint "{args.api_endpoint}"')
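
One small detail in the new vllm-chat / llama.cpp-chat branch: args.api_endpoint[:-5] strips the five-character "-chat" suffix so that the backend name handed to OpenAIChatEndPoint matches the values checked in __call__ ("vllm", "llama.cpp"). A quick check:

# The [:-5] slice removes the literal "-chat" suffix (five characters).
for kind in ["vllm-chat", "llama.cpp-chat"]:
    print(kind, "->", kind[:-5])
# vllm-chat -> vllm
# llama.cpp-chat -> llama.cpp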

python/mlc_llm/bench/dataset.py

+48 -3
@@ -243,11 +243,17 @@ class JSONModeEvalDataset(Dataset): # pylint: disable=too-few-public-methods
     """The dataset class for JSON dataset."""

     def __init__(self, tokenizer: AutoTokenizer) -> None:
-        raw_dataset = load_dataset("NousResearch/json-mode-eval")
+        raw_dataset = load_dataset("NousResearch/json-mode-eval", split="train")
         self.tokenizer = tokenizer
         self.dataset = []
-        for data in raw_dataset["train"]:
-            messages = data["prompt"]
+        for data in raw_dataset:
+            data = self._process_data(data)
+            messages = [
+                {
+                    "content": data["prompt"][0]["content"] + " " + data["prompt"][1]["content"],
+                    "role": data["prompt"][1]["role"],
+                },
+            ]
             schema = {
                 "type": "json_object",
                 "schema": data["schema"],
@@ -259,6 +265,42 @@ def __init__(self, tokenizer: AutoTokenizer) -> None:
             )
             self.dataset.append((messages, schema, num_tokens))

+    def _process_data(self, data):
+        data["prompt"][0]["content"] = data["prompt"][0]["content"].replace(
+            ", 'format': 'email'", ""
+        )
+        data["schema"] = data["schema"].replace(', "format": "email"', "")
+
+        data["prompt"][0]["content"] = data["prompt"][0]["content"].replace(
+            ", 'pattern': '\\\\d{5}'", ""
+        )
+        data["schema"] = data["schema"].replace(', "pattern": "\\\\d{5}"', "")
+
+        schema_str = data["schema"]
+        schema = json.loads(schema_str)
+        new_schema = None
+        if "type" not in schema:
+            if len(schema.keys()) == 1:
+                key = list(schema.keys())[0]
+                new_schema = {"title": key, **schema[key]}
+            else:
+                new_schema = {"type": "object", **schema}
+        if new_schema is None:
+            return data
+        return {
+            "prompt": [
+                {
+                    "content": "You are a helpful assistant that answers in JSON. "
+                    "Here's the json schema you must adhere to:"
+                    f"\n<schema>\n{new_schema}\n</schema>\n",
+                    "role": "system",
+                },
+                data["prompt"][1],
+            ],
+            "completion": data["completion"],
+            "schema": json.dumps(new_schema),
+        }
+
     def generate_request_records(
         self,
         input_len: Optional[int],
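
The new _process_data helper does two things to each json-mode-eval record: it strips JSON Schema keywords that the constrained-generation backends may not support (the "format": "email" and "pattern" constraints), and it normalizes schemas that lack a top-level "type" before rebuilding the system prompt around the cleaned schema. The snippet below replays the two normalization rules on made-up schemas; the real records come from NousResearch/json-mode-eval.

import json

# Hypothetical schemas illustrating the two normalization rules in _process_data above.

# Rule 1: a single-key wrapper without a top-level "type" is unwrapped, keeping the key as "title".
wrapped = {"Person": {"type": "object", "properties": {"name": {"type": "string"}}}}
key = list(wrapped)[0]
print(json.dumps({"title": key, **wrapped[key]}))
# {"title": "Person", "type": "object", "properties": {"name": {"type": "string"}}}

# Rule 2: a multi-key schema without a top-level "type" is assumed to describe an object.
bare = {"properties": {"name": {"type": "string"}}, "required": ["name"]}
print(json.dumps({"type": "object", **bare}))
# {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}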
@@ -288,6 +330,9 @@ def generate_request_records(
                     model="",
                     max_tokens=output_length,
                     response_format=schema,
+                    debug_config=DebugConfig(
+                        grammar_execution_mode="constraint",
+                    ),
                 ),
                 metrics=Metrics(
                     success=False,

python/mlc_llm/bench/request_processor.py

+1 -1
@@ -131,7 +131,7 @@ def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
             request_record.chat_cmpl.top_p = self.top_p
             request_record.chat_cmpl.frequency_penalty = 0.0
             request_record.chat_cmpl.presence_penalty = 0.0
-            request_record.chat_cmpl.tool_choice = "none"
+            request_record.chat_cmpl.tool_choice = None
             if self.ignore_eos:
                 request_record.chat_cmpl.debug_config = DebugConfig(ignore_eos=True)
         return request_records
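
This tool_choice change works together with the switch to model_dump(exclude_unset=True, exclude_none=True) in api_endpoint.py: with the value None the field is dropped from the serialized payload entirely, whereas the string "none" would still be sent to backends that may reject it. The toy pydantic v2 model below (assuming chat_cmpl is a pydantic model, as model_dump suggests; this is not MLC's actual ChatCompletionRequest) illustrates the difference.

from typing import Optional

from pydantic import BaseModel


class ToyChatCompletionRequest(BaseModel):
    # A toy stand-in, not MLC's ChatCompletionRequest.
    messages: list
    tool_choice: Optional[str] = "auto"


req = ToyChatCompletionRequest(messages=[], tool_choice="none")
print(req.model_dump(exclude_unset=True, exclude_none=True))
# {'messages': [], 'tool_choice': 'none'}  -- the string "none" is still sent

req.tool_choice = None
print(req.model_dump(exclude_unset=True, exclude_none=True))
# {'messages': []}  -- None is dropped from the payload entirely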
