Skip to content

Commit 6d3489a

Browse files
Allow the load test script to use a csv of inputs (#440)
* prepare allowing a csv input * randomly select input * pass some args through * log output token count percentiles * debug + ignore first line in file * lazy try except lmao * oops * oops x2 * oops x3 * . * oops prompt sample is reused * revert the changes to main, I'm gonna just have it take in a distribution of output token counts * output token count distribution * rename var to be more clear
1 parent ecd49fc commit 6d3489a

File tree

1 file changed

+51
-6
lines changed

1 file changed

+51
-6
lines changed

scripts/throughput_benchmarks.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,34 @@ def generate_output_token_counts(mean, std, num, input_token_count):
221221
return output
222222

223223

224+
def generate_output_token_counts_from_existing(
    distribution: List[int], num: int, input_token_count: int
):
    """Produce ``num`` output token counts drawn from an empirical distribution.

    The distribution is consumed in whole shuffled passes (sampling without
    replacement) so the result has less variance than independent draws; a
    final partial shuffled pass fills the remainder. Each count is clamped so
    that input + output never exceeds the model context window.

    Args:
        distribution: Non-empty list of observed output token counts.
        num: Number of counts to produce.
        input_token_count: Prompt token count, used for context-window clamping.

    Returns:
        A list of exactly ``num`` clamped token counts.
    """
    assert len(distribution) > 0, "Can't have a distribution with 0 tokens"
    # Shuffle a copy: the original code shuffled the caller's list in place,
    # mutating the argument as a side effect.
    pool = list(distribution)
    output: List[int] = []
    # Sample without replacement so that we don't have as much variance
    for _ in range(num // len(pool)):
        random.shuffle(pool)
        output.extend(pool)
    random.shuffle(pool)
    output.extend(pool[: num % len(pool)])
    assert len(output) == num

    # Clamp so the response cannot push the total past the context window.
    return [min(count, MAX_CONTEXT_WINDOW - input_token_count) for count in output]
240+
241+
242+
def read_distribution_from_file(fpath: str):
    """Load an output-token-count distribution from a JSON file.

    The file is expected to hold a JSON array (a list of token counts).
    If the path does not exist, prints a message and re-raises the
    FileNotFoundError so the caller aborts.
    """
    # Assumes the distribution is some json-formatted string that represents a list
    try:
        fin = open(fpath, "r")
    except FileNotFoundError:
        print("File not found. Exiting.")
        raise
    with fin:
        return json.load(fin)
250+
251+
224252
def run_benchmark(
225253
model: str,
226254
framework: InferenceFramework,
@@ -231,17 +259,23 @@ def run_benchmark(
231259
concurrency: int,
232260
verbose: bool,
233261
local_port: int,
262+
response_token_count_distribution: Optional[List] = None,
234263
):
235264
prompt = generate_prompt(config.input_token_count, hf_model)
236265

237266
prompt_num_tokens = config.input_token_count
238267

239-
output_token_counts = generate_output_token_counts(
240-
config.output_token_count_mean,
241-
config.output_token_count_std,
242-
num_trials,
243-
config.input_token_count,
244-
)
268+
if response_token_count_distribution is not None:
269+
output_token_counts = generate_output_token_counts_from_existing(
270+
response_token_count_distribution, num_trials, config.input_token_count
271+
)
272+
else:
273+
output_token_counts = generate_output_token_counts(
274+
config.output_token_count_mean,
275+
config.output_token_count_std,
276+
num_trials,
277+
config.input_token_count,
278+
)
245279

246280
start = time.time()
247281
results = send_requests(
@@ -352,10 +386,18 @@ def run_benchmarks(
352386
verbose: bool = False,
353387
hf_model: Optional[str] = None,
354388
local_port: int = 5005,
389+
response_token_count_distribution_file: Optional[str] = None,
355390
):
356391
"""Run benchmarks."""
357392
all_statistics = []
358393
config = BenchmarkConfig(input_token_count, output_token_count_mean)
394+
395+
response_token_count_distribution = None
396+
if response_token_count_distribution_file is not None:
397+
response_token_count_distribution = read_distribution_from_file(
398+
response_token_count_distribution_file
399+
)
400+
359401
try:
360402
if verbose:
361403
print(f"Running benchmark for config {config}")
@@ -375,6 +417,7 @@ def run_benchmarks(
375417
concurrency,
376418
verbose,
377419
local_port,
420+
response_token_count_distribution,
378421
)
379422
all_statistics.append(statistics)
380423
except Exception:
@@ -404,6 +447,7 @@ def run_benchmarks_concurrency_range(
404447
verbose: bool = False,
405448
hf_model: Optional[str] = None,
406449
local_port: int = 5005,
450+
response_token_count_distribution_file: Optional[str] = None,
407451
):
408452
if output_file is not None:
409453
# Create empty file
@@ -422,6 +466,7 @@ def run_benchmarks_concurrency_range(
422466
verbose,
423467
hf_model,
424468
local_port,
469+
response_token_count_distribution_file,
425470
)
426471

427472

0 commit comments

Comments
 (0)