
Commit 2c1001b

[Bench] Add AzureLLMInference dataset (#3104)
This PR introduces the [AzureLLMInference dataset](https://github.com/Azure/AzurePublicDataset). The dataset contains a timestamp for each entry, so this PR also introduces a dataset replay mode for the mlc-llm benchmark, which reuses the provided timestamps when issuing requests.
1 parent 03509ce commit 2c1001b
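
In replay mode, the benchmark keeps each request's trace timestamp and issues the request at that offset (optionally scaled) from the start of the run, instead of sampling requests at a fixed concurrency or request rate. The sketch below only illustrates the idea and is not the executor added in this PR; the `send` coroutine and the (offset, payload) record shape are placeholders.

import asyncio
import time
from typing import Dict, List, Tuple


async def send(payload: Dict) -> None:
    """Placeholder for the real API call issued by the benchmark executor."""
    print(f"sent {payload}")


async def replay(requests: List[Tuple[float, Dict]], timestamp_scale: float = 1.0) -> None:
    """Issue each request at its (scaled) offset from the start of the trace."""
    start = time.monotonic()
    for offset, payload in requests:
        delay = offset * timestamp_scale - (time.monotonic() - start)
        if delay > 0:
            await asyncio.sleep(delay)
        asyncio.create_task(send(payload))
    await asyncio.sleep(0.1)  # let in-flight sends finish in this toy example


if __name__ == "__main__":
    trace = [(0.0, {"id": 0}), (0.2, {"id": 1}), (0.5, {"id": 2})]
    asyncio.run(replay(trace, timestamp_scale=1.0))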

File tree

python/mlc_llm/bench/__main__.py
python/mlc_llm/bench/dataset.py
python/mlc_llm/bench/request_processor.py

3 files changed: +222 -18 lines changed

python/mlc_llm/bench/__main__.py

Lines changed: 8 additions & 0 deletions
@@ -247,6 +247,14 @@ def _main():
         "When specified, the benchmark sends these many new requests each second. "
         'If it is "inf", all requests will be sent together at once.',
     )
+    parser.add_argument(
+        "--replay-timestamp-scale",
+        type=float,
+        help="The timestamp scale when replaying the timestamps in a dataset. "
+        'The dataset replay mode is enabled when neither "--num-concurrent-requests" and '
+        '"--request-rate" is specified. '
+        "The scale is 1 by default in the replay mode.",
+    )
     parser.add_argument(
         "--input-len",
         type=int,

python/mlc_llm/bench/dataset.py

Lines changed: 126 additions & 2 deletions
@@ -3,9 +3,11 @@
 import argparse
 import json
 import random
+from datetime import datetime
 from typing import Dict, List, Optional, Tuple
 
 import numpy as np
+import pandas as pd  # pylint: disable=import-error
 from datasets import load_dataset  # pylint: disable=import-error
 from transformers import AutoTokenizer  # pylint: disable=import-error
 
@@ -25,6 +27,10 @@ class Dataset:  # pylint: disable=too-few-public-methods
     # For some that datasets (e.g., dataset that has shared common prefix),
     # we need fake warmup requests to avoid prefilling common prefixes to the engine.
     require_fake_warmup: bool = False
+    # Whether the dataset contains timestamps already.
+    # If the dataset comes with timestamps, the benchmark can just replay
+    # the requests according to their timestamps.
+    timestamp_available: bool = False
 
     def generate_request_records(
         self,
@@ -702,19 +708,111 @@ def generate_request_records(  # pylint: disable=too-many-locals
         return request_records
 
 
+class AzureLLMInferenceDataset(Dataset):  # pylint: disable=too-few-public-methods
+    """The dataset class for AzureLLMInference dataset.
+    Reference: https://github.com/Azure/AzurePublicDataset
+    """
+
+    timestamp_available: bool = True
+
+    def __init__(self, dataset_path: str, tokenizer: AutoTokenizer) -> None:
+        df = pd.read_csv(dataset_path)
+        self.tokenizer = tokenizer
+
+        # Filter out the conversations with less than 2 turns.
+        self.dataset = [
+            (
+                entry["TIMESTAMP"],
+                min(entry["ContextTokens"], tokenizer.model_max_length, self.truncate_length),
+                min(entry["GeneratedTokens"], tokenizer.model_max_length, self.truncate_length),
+            )
+            for _, entry in df.iterrows()
+            if entry["ContextTokens"] >= 4 and entry["GeneratedTokens"] >= 4
+        ]
+
+    def generate_request_records(  # pylint: disable=too-many-locals
+        self,
+        input_len: Optional[int],
+        output_len: Optional[int],
+        input_len_std: float = 0.0,
+        output_len_std: float = 0.0,
+    ) -> List[RequestRecord]:
+        time_fmt = "%Y-%m-%d %H:%M:%S.%f"
+        start_time = datetime.strptime(self.dataset[0][0][:-1], time_fmt)
+        request_records = []
+        for timestamp, input_length, output_length in self.dataset:
+            # If the request does not have enough length, discard it.
+            if input_len is not None and input_length < input_len + 4 * input_len_std:
+                continue
+
+            if input_len is not None:
+                input_length = round(
+                    float(np.random.normal(loc=input_len, scale=input_len_std, size=1)[0])
+                )
+            if output_len is not None:
+                output_length = round(
+                    float(np.random.normal(loc=output_len, scale=output_len_std, size=1)[0])
+                )
+            elif output_length <= 1:
+                continue
+
+            prompt_token_ids = [
+                random.randint(0, self.tokenizer.vocab_size - 1) for _ in range(input_length)
+            ]
+            while True:
+                # Adjust the token ids until the retokenization on the decoded string
+                # matches the required input length.
+                prompt = self.tokenizer.decode(prompt_token_ids)
+                retokenized_token_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
+                if len(retokenized_token_ids) < input_length:
+                    prompt_token_ids = retokenized_token_ids + [
+                        random.randint(0, self.tokenizer.vocab_size - 1)
+                        for _ in range(input_length - len(retokenized_token_ids))
+                    ]
+                elif len(retokenized_token_ids) > input_length:
+                    prompt_token_ids = retokenized_token_ids[:input_length]
+                else:
+                    break
+
+            time_diff = (datetime.strptime(timestamp[:-1], time_fmt) - start_time).total_seconds()
+            request_records.append(
+                RequestRecord(
+                    chat_cmpl=ChatCompletionRequest(
+                        messages=[{"role": "user", "content": prompt}],
+                        model="",
+                        max_tokens=output_length,
+                    ),
+                    timestamp=time_diff,
+                    metrics=Metrics(
+                        success=False,
+                        start_time=0,
+                        finish_time=0,
+                        end_to_end_latency_s=0,
+                        input_tokens=input_length,
+                    ),
+                )
+            )
+        return request_records
+
+
 SUPPORTED_DATASET = [
     "sharegpt",
     "llmperf",
     "json-mode-eval",
     "loogle",
     "react",
     "wildchat",
+    "azure-llm-inference",
 ]
 
 
-def create_dataset(args: argparse.Namespace, tokenizer: AutoTokenizer) -> "Dataset":
+def create_dataset(  # pylint: disable=too-many-return-statements,too-many-branches
+    args: argparse.Namespace, tokenizer: AutoTokenizer
+) -> Dataset:
     """Create a dataset instance with regard to the specified dataset kind and file path."""
-    if args.dataset is None:
+    if args.dataset_path is not None and not isinstance(args.dataset_path, str):
+        raise TypeError(f"Invalid dataset path {args.dataset_path}. Please use a string.")
+    if args.dataset is None and args.dataset_path is not None:
         # Auto-detect the dataset kind by looking into the dataset path.
         if "sharegpt" in args.dataset_path.lower():
             args.dataset = "sharegpt"
@@ -724,8 +822,16 @@ def create_dataset(args: argparse.Namespace, tokenizer: AutoTokenizer) -> "Datas
             'Please specify the dataset kind via "--dataset".'
         )
     if args.dataset == "sharegpt":
+        if args.dataset_path is None:
+            raise ValueError(
+                'ShareGPT dataset requires dataset path. Please specify it with "--dataset-path".'
+            )
         return ShareGPTDataset(args.dataset_path, tokenizer, args.apply_chat_template)
     if args.dataset == "llmperf":
+        if args.dataset_path is None:
+            raise ValueError(
+                'LLMPerf dataset requires dataset path. Please specify it with "--dataset-path".'
+            )
         assert (
             args.apply_chat_template is False
         ), "LLMPerf dataset does not support applying chat template"
@@ -738,15 +844,33 @@ def create_dataset(args: argparse.Namespace, tokenizer: AutoTokenizer) -> "Datas
         ), "JSON mode evaluation does not support applying chat template"
         return JSONModeEvalDataset(tokenizer)
     if args.dataset == "loogle":
+        if args.dataset_path is None:
+            raise ValueError(
+                'Loogle dataset requires a testset name. Please specify it with "--dataset-path".'
+            )
         assert (
             args.apply_chat_template is False
         ), "Loogle dataset does not support applying chat template"
         return LoogleDataset(tokenizer, testset_name=args.dataset_path)
     if args.dataset == "react":
+        if args.dataset_path is None:
+            raise ValueError(
+                'ReAct dataset requires dataset path. Please specify it with "--dataset-path".'
+            )
         assert (
             args.apply_chat_template is False
         ), "ReAct dataset does not support applying chat template"
         return ReActDataset(args.dataset_path, tokenizer)
     if args.dataset == "wildchat":
         return WildChatDataset(tokenizer, args.apply_chat_template)
+    if args.dataset == "azure-llm-inference":
+        if args.dataset_path is None:
+            raise ValueError(
+                "AzureLLMInference dataset requires dataset path. "
+                'Please specify it with "--dataset-path".'
+            )
+        assert (
+            args.apply_chat_template is False
+        ), "AzureLLMInference dataset does not support applying chat template"
+        return AzureLLMInferenceDataset(args.dataset_path, tokenizer)
     raise ValueError(f"Unrecognized dataset {args.dataset}")

python/mlc_llm/bench/request_processor.py

Lines changed: 88 additions & 16 deletions
@@ -51,8 +51,11 @@ def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
 class SampleRequests(RequestProcessor):  # pylint: disable=too-few-public-methods
     """The processor that samples requests out from the given request list."""
 
-    def __init__(self, num_requests: int) -> None:
+    def __init__(self, num_requests: int, take_first_x_requests: bool = False) -> None:
         self.num_requests = num_requests
+        # If `take_first_x_requests` is True, the first `num_requests` requests
+        # are returned and sampling will not happen.
+        self.take_first_x_requests = take_first_x_requests
 
     def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
         assert len(request_records) > 0, "Empty input request record."
@@ -69,12 +72,20 @@ def _sample_from_plain_request_records(
         self, request_records: List[RequestRecord]
     ) -> List[RequestRecord]:
         samples: List[RequestRecord] = []
-        while len(samples) < self.num_requests:
-            # Create a new list so that the in-place shuffle does not mutate the input list.
-            records = list(request_records)
-            random.shuffle(records)
-            samples += copy.deepcopy(records)
-        samples = samples[: self.num_requests]
+        if self.take_first_x_requests:
+            if len(request_records) < self.num_requests:
+                raise ValueError(
+                    f"Insufficient requests. Requiring {self.num_requests} requests "
+                    f"but only {len(request_records)} are available."
+                )
+            samples = copy.deepcopy(list(request_records[: self.num_requests]))
+        else:
+            while len(samples) < self.num_requests:
+                # Create a new list so that the in-place shuffle does not mutate the input list.
+                records = list(request_records)
+                random.shuffle(records)
+                samples += copy.deepcopy(records)
+            samples = samples[: self.num_requests]
         for i, record in enumerate(samples):
             record.request_id = i
         return samples
@@ -95,7 +106,8 @@ def _sample_from_grouped_request_records(
 
         # Create a new list so that the in-place shuffle does not mutate the input list.
         records = list(grouped_request_records)
-        random.shuffle(records)
+        if not self.take_first_x_requests:
+            random.shuffle(records)
         remaining = self.num_requests
         samples: List[RequestRecord] = []
         for grouped_request_record in grouped_request_records:
@@ -183,6 +195,22 @@ def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
         return request_records
 
 
+class ScaleTimestamp(RequestProcessor):  # pylint: disable=too-few-public-methods
+    """Scale the timestamp of requests by the given scale factor."""
+
+    def __init__(self, timestamp_scale: float):
+        self.timestamp_scale = timestamp_scale
+
+    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
+        for request_record in request_records:
+            if request_record.timestamp is None:
+                raise ValueError(
+                    f"The timestamp of request {request_record} has not been initialized."
+                )
+            request_record.timestamp *= self.timestamp_scale
+        return request_records
+
+
 class MetricAnalyzer(RequestProcessor):  # pylint: disable=too-few-public-methods
     """The processor that analyzes the raw benchmark results and computes more detailed metrics."""
 
@@ -463,7 +491,6 @@ def __init__(  # pylint: disable=too-many-arguments
         disable_tqdm: bool,
         max_schedule_gap: float,
         num_requests: int,
-        request_rate: Optional[np.float32] = None,
     ) -> None:
         if num_processes is None:
             # We assign each process at most 32 requests to send
@@ -472,7 +499,6 @@
         super().__init__(f_create_api_endpoint, num_processes, disable_tqdm)
         self.max_schedule_gap = max_schedule_gap
         self.num_requests = num_requests
-        self.request_rate = request_rate
 
     def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
         assert len(request_records) > 0
@@ -574,7 +600,7 @@ async def _task(request_record: RequestRecord) -> None:
         )
 
 
-def create_pipelines(
+def create_pipelines(  # pylint: disable=too-many-branches
    args: argparse.Namespace, f_create_api_endpoint: Callable[[], APIEndPoint], dataset: Dataset
 ) -> List[RequestProcessor]:
     """Creating request processing pipelines with regard to the specified args."""
@@ -586,6 +612,10 @@ def create_pipelines(
                 'Both "num_concurrent_requests" and "request_rate" are specified. '
                 "Please specify only one of them."
             )
+        if args.replay_timestamp_scale is not None:
+            raise ValueError(
+                "Dataset replay is unsupported when fixing number of concurrent requests."
+            )
         for num_concurrent_requests in args.num_concurrent_requests:
             num_warmup_requests = (
                 args.num_warmup_requests
@@ -622,6 +652,8 @@ def create_pipelines(
                 "Please specify the number of warmup requests via "
                 '"--num-warmup-requests" when fixing request rate.'
            )
+        if args.replay_timestamp_scale is not None:
+            raise ValueError("Dataset replay is unsupported when fixing request rates.")
         num_total_requests = int(
             args.num_requests if not args.per_gpu_workload else args.num_requests * args.num_gpus
         )
@@ -649,15 +681,55 @@
                         args.disable_tqdm,
                         args.max_schedule_gap,
                         args.num_requests,
-                        request_rate,
                     ),
                     cuda_profile_url=cuda_profile_url,
                     fake_warmup=dataset.require_fake_warmup,
                 ),
            )
            for request_rate in args.request_rate
        ]
-    raise ValueError(
-        'Unable to create executor. Please specify one of "num_concurrent_requests" '
-        'and "request_rate".'
-    )
+
+    # Default: dataset replay mode
+    # The dataset must come with timestamps.
+    if not dataset.timestamp_available:
+        raise ValueError(
+            "The dataset does not have timestamps, so dataset replay is unsupported. "
+            'Please specify one of "num_concurrent_requests" '
+            'and "request_rate".'
+        )
+    if args.per_gpu_workload:
+        raise ValueError("Fixing per-GPU workload is not compatible with dataset replay.")
+    if args.num_warmup_requests is None:
+        raise ValueError(
+            "Please specify the number of warmup requests via "
+            '"--num-warmup-requests" for dataset replay.'
+        )
+    timestamp_scale = args.replay_timestamp_scale or 1.0
+    if dataset.require_fake_warmup:
+        num_samples = args.num_requests
+    else:
+        num_samples = args.num_requests + args.num_warmup_requests
+    return [
+        SequentialProcessor(
+            LogMessage(f"Dataset replay with time scaling of {timestamp_scale}"),
+            SampleRequests(num_samples, take_first_x_requests=True),
+            AttachModelName(args.tokenizer),
+            ScaleTimestamp(timestamp_scale),
+            AttachStreamFlag(args.stream),
+            AttachSamplingOptions(args.temperature, args.top_p, args.ignore_eos),
+            AttachExecutionFeature({"timestamp_scale": timestamp_scale}),
+            WarmupAndRun(
+                num_warmup_requests=args.num_warmup_requests,
+                num_benchmark_requests=args.num_requests,
+                pipeline=FixTimestampExecutor(
+                    f_create_api_endpoint,
+                    args.num_process_workers,
+                    args.disable_tqdm,
+                    args.max_schedule_gap,
+                    args.num_requests,
+                ),
+                cuda_profile_url=cuda_profile_url,
+                fake_warmup=dataset.require_fake_warmup,
+            ),
+        )
+    ]
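
The ScaleTimestamp processor above multiplies every record's relative timestamp by the scale factor, so a scale below 1 compresses the trace (faster replay) and a scale above 1 stretches it out; a scale of 1, the default, replays the trace at its original pace. A minimal standalone illustration with a stand-in record type (the real RequestRecord carries much more state):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeRecord:
    """Stand-in for RequestRecord; only the timestamp field matters here."""

    timestamp: Optional[float] = None


def scale_timestamps(records: List[FakeRecord], timestamp_scale: float) -> List[FakeRecord]:
    """Mirrors the behavior of ScaleTimestamp.__call__ on the stand-in type."""
    for record in records:
        if record.timestamp is None:
            raise ValueError("The timestamp has not been initialized.")
        record.timestamp *= timestamp_scale
    return records


records = [FakeRecord(0.0), FakeRecord(1.5), FakeRecord(4.0)]
scale_timestamps(records, timestamp_scale=0.5)
print([r.timestamp for r in records])  # [0.0, 0.75, 2.0]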
