
Commit 8bb43b9

Add benchmark dataset for mlperf llama tasks (#20338)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent 5597562 commit 8bb43b9

File tree

1 file changed: +82 −0 lines changed


vllm/benchmarks/datasets.py

Lines changed: 82 additions & 0 deletions
@@ -654,6 +654,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
         elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
             dataset_class = ASRDataset
             args.hf_split = "train"
+        elif args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS:
+            dataset_class = MLPerfDataset
+            args.hf_split = "train"
         else:
             supported_datasets = set([
                 dataset_name for cls in HuggingFaceDataset.__subclasses__()
@@ -1447,3 +1450,82 @@ def sample(
             )
         self.maybe_oversample_requests(sampled_requests, num_requests)
         return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# MLPerf Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class MLPerfDataset(HuggingFaceDataset):
+    """
+    MLPerf Inference Dataset.
+
+    Dataset on HF:
+    https://huggingface.co/datasets/mgoin/mlperf-inference-llama2-data
+    https://huggingface.co/datasets/mgoin/mlperf-inference-llama3.1-data
+
+    Each record contains:
+    - "system_prompt": system role instruction.
+    - "question": user question.
+    - "output": reference answer.
+
+    We combine the system prompt and question into a chat-formatted prompt
+    (using the tokenizer's chat template) and set the expected output length to
+    the tokenized length of the provided reference answer.
+    """
+
+    SUPPORTED_DATASET_PATHS = {
+        "mgoin/mlperf-inference-llama2-data",
+        "mgoin/mlperf-inference-llama3.1-data",
+    }
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        output_len: Optional[int] = None,
+        **kwargs,
+    ) -> list[SampleRequest]:
+        # Force dynamic output length based on reference completion.
+        dynamic_output = output_len is None
+        sampled_requests: list[SampleRequest] = []
+
+        for item in self.data:
+            if len(sampled_requests) >= num_requests:
+                break
+
+            system_prompt = item["system_prompt"]
+            question = item["question"]
+            reference_answer = item["output"]
+
+            # Build chat-style prompt using tokenizer template, if available.
+            messages = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": question},
+            ]
+            prompt_formatted = tokenizer.apply_chat_template(
+                messages, add_generation_prompt=True, tokenize=False
+            )
+            prompt_len = len(tokenizer(prompt_formatted).input_ids)
+
+            # Determine output length from reference answer tokens.
+            ref_out_len = len(
+                tokenizer(reference_answer, add_special_tokens=False).input_ids
+            )
+            expected_output_len = ref_out_len if dynamic_output else output_len
+
+            # Validate sequence lengths.
+            if not is_valid_sequence(prompt_len, expected_output_len):
+                continue
+
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt_formatted,
+                    prompt_len=prompt_len,
+                    expected_output_len=expected_output_len,
+                )
+            )
+
+        self.maybe_oversample_requests(sampled_requests, num_requests)
+        return sampled_requests
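
For reference, below is a minimal standalone sketch of the same sampling logic outside vLLM, using the Hugging Face datasets and transformers libraries directly. The dataset path, split ("train", per the get_samples change above), and field names come from the diff; the tokenizer checkpoint and the request cap of 8 are placeholder assumptions, and the vLLM-specific is_valid_sequence filtering and oversampling steps are omitted.

# Standalone sketch (assumptions as noted above); mirrors MLPerfDataset.sample
# without the vLLM wrapper classes.
from datasets import load_dataset
from transformers import AutoTokenizer

# Placeholder checkpoint; any chat-template-capable tokenizer works here.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
data = load_dataset("mgoin/mlperf-inference-llama2-data", split="train")

requests = []
for item in data:
    if len(requests) >= 8:  # placeholder request count
        break
    # Chat-format the system prompt and question with the tokenizer's template.
    messages = [
        {"role": "system", "content": item["system_prompt"]},
        {"role": "user", "content": item["question"]},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    prompt_len = len(tokenizer(prompt).input_ids)
    # Expected output length = token count of the reference answer.
    output_len = len(tokenizer(item["output"], add_special_tokens=False).input_ids)
    requests.append((prompt, prompt_len, output_len))

Judging from the get_samples dispatch in the first hunk, passing one of these dataset paths through the benchmark's Hugging Face dataset option should select MLPerfDataset automatically, with the split forced to "train".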

0 commit comments
