class PrefixRepetitionRandomDataset(BenchmarkDataset):
    """Synthetic dataset whose prompts share a small set of random prefixes.

    ``sample`` generates ``num_prefixes`` random prefixes and, for each of
    them, ``num_requests // num_prefixes`` prompts made of that prefix
    followed by a freshly generated random suffix.
    """

    # Default values copied from benchmark_serving.py for the repeated prefix dataset.
    DEFAULT_PREFIX_LEN = 256
    DEFAULT_SUFFIX_LEN = 256
    DEFAULT_NUM_PREFIXES = 10
    DEFAULT_OUTPUT_LEN = 128

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Forward every keyword argument to the base dataset constructor."""
        super().__init__(**kwargs)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        suffix_len: int = DEFAULT_SUFFIX_LEN,
        num_prefixes: int = DEFAULT_NUM_PREFIXES,
        output_len: int = DEFAULT_OUTPUT_LEN,
        **kwargs,
    ) -> list[SampleRequest]:
        """Build random prompts grouped under shared random prefixes.

        Args:
            tokenizer: Tokenizer used to decode random token ids into text
                and to re-encode that text; its ``vocab_size`` bounds the
                random token ids.
            num_requests: Requested number of prompts. The number actually
                returned is ``(num_requests // num_prefixes) * num_prefixes``,
                so any remainder is dropped.
            prefix_len: Number of random tokens drawn for each prefix.
            suffix_len: Number of random tokens drawn for each suffix.
            num_prefixes: Number of distinct prefixes to generate.
            output_len: Value stored as ``expected_output_len`` on every
                request.
            **kwargs: Accepted and ignored.

        Returns:
            The generated requests, ordered prefix by prefix.

        Raises:
            ValueError: If ``num_prefixes`` is not positive, or if
                ``num_requests`` is smaller than ``num_prefixes`` (which
                would otherwise silently produce no requests).
        """
        if num_prefixes <= 0:
            raise ValueError(
                f"num_prefixes must be a positive integer, got {num_prefixes}"
            )
        if num_requests < num_prefixes:
            raise ValueError(
                f"num_requests ({num_requests}) must be greater than or equal "
                f"to num_prefixes ({num_prefixes})"
            )

        vocab_size = tokenizer.vocab_size
        prompts_per_prefix = num_requests // num_prefixes

        # NOTE(review): token ids come from NumPy's global RNG, which this
        # class does not seed itself -- confirm that the base class or the
        # caller seeds it if runs are expected to be reproducible.
        def _generate_random_text_part(length: int) -> tuple[str, list[int]]:
            """Return random text and the token ids it was decoded from."""
            token_ids = np.random.randint(0, vocab_size, size=length).tolist()
            decoded_text = tokenizer.decode(token_ids)
            # Re-encoding and decoding is necessary to ensure the final
            # token count is correct: the round trip may yield a different
            # number of tokens, so it is capped at `length`.
            re_encoded_ids = tokenizer.encode(
                decoded_text, add_special_tokens=False
            )[:length]
            final_text = tokenizer.decode(re_encoded_ids)
            return final_text, re_encoded_ids

        requests = []
        for _ in range(num_prefixes):
            # One prefix is generated per group and reused by every prompt
            # in that group.
            decoded_prefix, re_encoded_prefix = _generate_random_text_part(
                prefix_len
            )

            for _ in range(prompts_per_prefix):
                decoded_suffix, re_encoded_suffix = _generate_random_text_part(
                    suffix_len
                )

                # The reported length is the sum of the two parts' token
                # counts; the concatenated text is not tokenized again.
                requests.append(
                    SampleRequest(
                        prompt=decoded_prefix + decoded_suffix,
                        prompt_len=len(re_encoded_prefix) + len(re_encoded_suffix),
                        expected_output_len=output_len,
                    )
                )

        return requests
parser.add_argument_group("repeated prefix dataset options") + repeated_prefix_group.add_argument( + "--repeated-prefix-prefix-len", + type=int, + default=256, + help="Number of prefix tokens per request, used only for repeated " + "prefix dataset.", + ) + repeated_prefix_group.add_argument( + "--repeated-prefix-suffix-len", + type=int, + default=256, + help="Number of suffix tokens per request, used only for repeated " + "prefix dataset. Total input length is prefix_len + suffix_len.", + ) + repeated_prefix_group.add_argument( + "--repeated-prefix-num-prefixes", + type=int, + default=10, + help="Number of prefixes to generate, used only for repeated prefix " + "dataset. Prompts per prefix is num_requests // num_prefixes.", + ) + repeated_prefix_group.add_argument( + "--repeated-prefix-output-len", + type=int, + default=128, + help="Number of output tokens per request, used only for repeated " + "prefix dataset.", + ) + hf_group = parser.add_argument_group("hf dataset options") hf_group.add_argument( "--hf-subset", type=str, default=None, help="Subset of the HF dataset."