@@ -654,6 +654,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
     elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
         dataset_class = ASRDataset
         args.hf_split = "train"
+    elif args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS:
+        dataset_class = MLPerfDataset
+        args.hf_split = "train"
     else:
         supported_datasets = set([
             dataset_name for cls in HuggingFaceDataset.__subclasses__()
@@ -1447,3 +1450,82 @@ def sample(
         )
         self.maybe_oversample_requests(sampled_requests, num_requests)
         return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# MLPerf Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class MLPerfDataset(HuggingFaceDataset):
+    """
+    MLPerf Inference Dataset.
+
+    Dataset on HF:
+    https://huggingface.co/datasets/mgoin/mlperf-inference-llama2-data
+    https://huggingface.co/datasets/mgoin/mlperf-inference-llama3.1-data
+
+    Each record contains:
+        - "system_prompt": system role instruction.
+        - "question": user question.
+        - "output": reference answer.
+
+    We combine the system prompt and question into a chat-formatted prompt
+    (using the tokenizer's chat template) and set the expected output length
+    to the tokenized length of the provided reference answer.
+    """
+
+    SUPPORTED_DATASET_PATHS = {
+        "mgoin/mlperf-inference-llama2-data",
+        "mgoin/mlperf-inference-llama3.1-data",
+    }
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        output_len: Optional[int] = None,
+        **kwargs,
+    ) -> list[SampleRequest]:
+        # Force dynamic output length based on reference completion.
+        dynamic_output = output_len is None
+        sampled_requests: list[SampleRequest] = []
+
+        for item in self.data:
+            if len(sampled_requests) >= num_requests:
+                break
+
+            system_prompt = item["system_prompt"]
+            question = item["question"]
+            reference_answer = item["output"]
+
+            # Build chat-style prompt using tokenizer template, if available.
+            messages = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": question},
+            ]
+            prompt_formatted = tokenizer.apply_chat_template(
+                messages, add_generation_prompt=True, tokenize=False
+            )
+            prompt_len = len(tokenizer(prompt_formatted).input_ids)
+
+            # Determine output length from reference answer tokens.
+            ref_out_len = len(
+                tokenizer(reference_answer, add_special_tokens=False).input_ids
+            )
+            expected_output_len = ref_out_len if dynamic_output else output_len
+
+            # Validate sequence lengths.
+            if not is_valid_sequence(prompt_len, expected_output_len):
+                continue
+
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt_formatted,
+                    prompt_len=prompt_len,
+                    expected_output_len=expected_output_len,
+                )
+            )
+
+        self.maybe_oversample_requests(sampled_requests, num_requests)
+        return sampled_requests
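For reference, the prompt construction and length accounting added above can be reproduced outside the benchmark harness with only the public `datasets` and `transformers` APIs. This is a minimal sketch under the assumptions stated in the docstring (a `train` split with `system_prompt`, `question`, and `output` fields); the tokenizer name below is a placeholder for any chat-template-capable tokenizer and is not something this PR depends on.

```python
# Minimal standalone sketch of MLPerfDataset's prompt/length logic.
# Assumes the dataset's "train" split and record fields from the docstring;
# the tokenizer is a placeholder chat model chosen only for illustration.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder
data = load_dataset("mgoin/mlperf-inference-llama2-data", split="train")

item = data[0]
messages = [
    {"role": "system", "content": item["system_prompt"]},
    {"role": "user", "content": item["question"]},
]

# Chat-formatted prompt, matching what MLPerfDataset.sample() sends.
prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)
prompt_len = len(tokenizer(prompt).input_ids)

# Expected output length = token count of the reference answer.
expected_output_len = len(
    tokenizer(item["output"], add_special_tokens=False).input_ids
)
print(prompt_len, expected_output_len)
```

When `output_len` is not supplied, the class uses these per-record reference lengths; an explicit `output_len` overrides them uniformly, consistent with the other HuggingFace-backed datasets.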