From 25fa31c2b8d1102f5fb99a1b6ad8d709b0718bc7 Mon Sep 17 00:00:00 2001 From: Seiji Eicher Date: Tue, 8 Jul 2025 11:15:11 -0700 Subject: [PATCH 1/2] PrefixRepetitionRandomDataset Signed-off-by: Seiji Eicher --- benchmarks/benchmark_dataset.py | 69 +++++++++++++++++++++++++++++++++ benchmarks/benchmark_serving.py | 57 ++++++++++++++++++++++++++- 2 files changed, 125 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 55c0cf85126..79c9d085c07 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -1166,3 +1166,72 @@ def sample( ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests + + +# ----------------------------------------------------------------------------- +# Prefix Repetition Dataset Implementation +# ----------------------------------------------------------------------------- + + +class PrefixRepetitionRandomDataset(BenchmarkDataset): + # Default values copied from benchmark_serving.py for the repeated prefix dataset. + DEFAULT_PROMPTS_PER_PREFIX = 200 + DEFAULT_PREFIX_LEN = 256 + DEFAULT_SUFFIX_LEN = 256 + DEFAULT_NUM_PREFIXES = 10 + DEFAULT_OUTPUT_LEN = 128 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + prompts_per_prefix: int = DEFAULT_PROMPTS_PER_PREFIX, + prefix_len: int = DEFAULT_PREFIX_LEN, + suffix_len: int = DEFAULT_SUFFIX_LEN, + num_prefixes: int = DEFAULT_NUM_PREFIXES, + output_len: int = DEFAULT_OUTPUT_LEN, + **kwargs, + ) -> list[SampleRequest]: + vocab_size = tokenizer.vocab_size + + requests = [] + for _ in range(num_prefixes): + prefix_token_ids = ( + np.random.randint(0, vocab_size, size=prefix_len).tolist() + if prefix_len > 0 + else [] + ) + decoded_prefix = tokenizer.decode(prefix_token_ids) + re_encoded_prefix = tokenizer.encode( + decoded_prefix, add_special_tokens=False + )[:prefix_len] + decoded_prefix = tokenizer.decode(re_encoded_prefix) + + for _ in range(prompts_per_prefix): + suffix_token_ids = ( + np.random.randint(0, vocab_size, size=suffix_len).tolist() + if suffix_len > 0 + else [] + ) + decoded_suffix = tokenizer.decode(suffix_token_ids) + re_encoded_suffix = tokenizer.encode( + decoded_suffix, add_special_tokens=False + )[:suffix_len] + decoded_suffix = tokenizer.decode(re_encoded_suffix) + + prompt = decoded_prefix + decoded_suffix + prompt_len = len(re_encoded_prefix) + len(re_encoded_suffix) + requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + ) + ) + + return requests diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 9b235266dff..592e1294a86 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -66,6 +66,7 @@ InstructCoderDataset, MTBenchDataset, NextEditPredictionDataset, + PrefixRepetitionRandomDataset, RandomDataset, SampleRequest, ShareGPTDataset, @@ -852,6 +853,16 @@ def main(args: argparse.Namespace): output_len=args.random_output_len, range_ratio=args.random_range_ratio, ), + "prefix_repetition": lambda: PrefixRepetitionRandomDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + prompts_per_prefix=args.repeated_prefix_prompts_per_prefix, + prefix_len=args.repeated_prefix_prefix_len, + suffix_len=args.repeated_prefix_suffix_len, + num_prefixes=args.repeated_prefix_num_prefixes, + output_len=args.repeated_prefix_output_len, + ), } try: @@ -1023,7 +1034,15 @@ def create_argument_parser(): "--dataset-name", type=str, default="sharegpt", - choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], + choices=[ + "sharegpt", + "burstgpt", + "sonnet", + "random", + "hf", + "custom", + "prefix_repetition", + ], help="Name of the dataset to benchmark on.", ) parser.add_argument( @@ -1271,6 +1290,42 @@ def create_argument_parser(): ), ) + repeated_prefix_group = parser.add_argument_group("repeated prefix dataset options") + repeated_prefix_group.add_argument( + "--repeated-prefix-prompts-per-prefix", + type=int, + default=200, + help="Number of prompts per prefix, used only for repeated prefix dataset.", + ) + repeated_prefix_group.add_argument( + "--repeated-prefix-prefix-len", + type=int, + default=256, + help="Number of prefix tokens per request, used only for repeated " + "prefix dataset.", + ) + repeated_prefix_group.add_argument( + "--repeated-prefix-suffix-len", + type=int, + default=256, + help="Number of suffix tokens per request, used only for repeated " + "prefix dataset. Total input length is prefix_len + suffix_len.", + ) + repeated_prefix_group.add_argument( + "--repeated-prefix-num-prefixes", + type=int, + default=10, + help="Number of prefixes to generate, used only for repeated prefix " + "dataset. Total number of requests is prompts_per_prefix * num_prefixes.", + ) + repeated_prefix_group.add_argument( + "--repeated-prefix-output-len", + type=int, + default=128, + help="Number of output tokens per request, used only for repeated " + "prefix dataset.", + ) + hf_group = parser.add_argument_group("hf dataset options") hf_group.add_argument( "--hf-subset", type=str, default=None, help="Subset of the HF dataset." From 5353a7f6dfa4271160f58bd89c589b879bb8753e Mon Sep 17 00:00:00 2001 From: Seiji Eicher Date: Tue, 8 Jul 2025 11:31:09 -0700 Subject: [PATCH 2/2] Refactor for num_requests Signed-off-by: Seiji Eicher --- benchmarks/benchmark_dataset.py | 37 ++++++++++++++------------------- benchmarks/benchmark_serving.py | 10 ++------- 2 files changed, 18 insertions(+), 29 deletions(-) diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 79c9d085c07..dddd6650eda 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -1175,7 +1175,6 @@ def sample( class PrefixRepetitionRandomDataset(BenchmarkDataset): # Default values copied from benchmark_serving.py for the repeated prefix dataset. - DEFAULT_PROMPTS_PER_PREFIX = 200 DEFAULT_PREFIX_LEN = 256 DEFAULT_SUFFIX_LEN = 256 DEFAULT_NUM_PREFIXES = 10 @@ -1190,7 +1189,7 @@ def __init__( def sample( self, tokenizer: PreTrainedTokenizerBase, - prompts_per_prefix: int = DEFAULT_PROMPTS_PER_PREFIX, + num_requests: int, prefix_len: int = DEFAULT_PREFIX_LEN, suffix_len: int = DEFAULT_SUFFIX_LEN, num_prefixes: int = DEFAULT_NUM_PREFIXES, @@ -1198,31 +1197,27 @@ def sample( **kwargs, ) -> list[SampleRequest]: vocab_size = tokenizer.vocab_size + prompts_per_prefix = num_requests // num_prefixes + + def _generate_random_text_part(length: int) -> tuple[str, list[int]]: + token_ids = np.random.randint(0, vocab_size, size=length).tolist() + decoded_text = tokenizer.decode(token_ids) + # Re-encoding and decoding is necessary to ensure the final + # token count is correct. + re_encoded_ids = tokenizer.encode(decoded_text, add_special_tokens=False)[ + :length + ] + final_text = tokenizer.decode(re_encoded_ids) + return final_text, re_encoded_ids requests = [] for _ in range(num_prefixes): - prefix_token_ids = ( - np.random.randint(0, vocab_size, size=prefix_len).tolist() - if prefix_len > 0 - else [] - ) - decoded_prefix = tokenizer.decode(prefix_token_ids) - re_encoded_prefix = tokenizer.encode( - decoded_prefix, add_special_tokens=False - )[:prefix_len] - decoded_prefix = tokenizer.decode(re_encoded_prefix) + decoded_prefix, re_encoded_prefix = _generate_random_text_part(prefix_len) for _ in range(prompts_per_prefix): - suffix_token_ids = ( - np.random.randint(0, vocab_size, size=suffix_len).tolist() - if suffix_len > 0 - else [] + decoded_suffix, re_encoded_suffix = _generate_random_text_part( + suffix_len ) - decoded_suffix = tokenizer.decode(suffix_token_ids) - re_encoded_suffix = tokenizer.encode( - decoded_suffix, add_special_tokens=False - )[:suffix_len] - decoded_suffix = tokenizer.decode(re_encoded_suffix) prompt = decoded_prefix + decoded_suffix prompt_len = len(re_encoded_prefix) + len(re_encoded_suffix) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 592e1294a86..29e09f3fb10 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -857,7 +857,7 @@ def main(args: argparse.Namespace): random_seed=args.seed, dataset_path=args.dataset_path ).sample( tokenizer=tokenizer, - prompts_per_prefix=args.repeated_prefix_prompts_per_prefix, + num_requests=args.num_prompts, prefix_len=args.repeated_prefix_prefix_len, suffix_len=args.repeated_prefix_suffix_len, num_prefixes=args.repeated_prefix_num_prefixes, @@ -1291,12 +1291,6 @@ def create_argument_parser(): ) repeated_prefix_group = parser.add_argument_group("repeated prefix dataset options") - repeated_prefix_group.add_argument( - "--repeated-prefix-prompts-per-prefix", - type=int, - default=200, - help="Number of prompts per prefix, used only for repeated prefix dataset.", - ) repeated_prefix_group.add_argument( "--repeated-prefix-prefix-len", type=int, @@ -1316,7 +1310,7 @@ def create_argument_parser(): type=int, default=10, help="Number of prefixes to generate, used only for repeated prefix " - "dataset. Total number of requests is prompts_per_prefix * num_prefixes.", + "dataset. Prompts per prefix is num_requests // num_prefixes.", ) repeated_prefix_group.add_argument( "--repeated-prefix-output-len",