@@ -221,6 +221,34 @@ def generate_output_token_counts(mean, std, num, input_token_count):
     return output


+def generate_output_token_counts_from_existing(
+    distribution: List[int], num: int, input_token_count: int
+):
+    assert len(distribution) > 0, "Can't have a distribution with 0 tokens"
+    output = []
+    # Sample without replacement so that we don't have as much variance
+    for _ in range(num // len(distribution)):
+        random.shuffle(distribution)
+        output.extend(distribution)
+    random.shuffle(distribution)
+    output.extend(distribution[: num % len(distribution)])
+    assert len(output) == num
+
+    for i in range(len(output)):
+        output[i] = min(output[i], MAX_CONTEXT_WINDOW - input_token_count)
+    return output
+
+
+def read_distribution_from_file(fpath: str):
+    # Assumes the distribution is some json-formatted string that represents a list
+    try:
+        with open(fpath, "r") as fin:
+            return json.load(fin)
+    except FileNotFoundError:
+        print("File not found. Exiting.")
+        raise
+
+
 def run_benchmark(
     model: str,
     framework: InferenceFramework,
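For context, a minimal usage sketch of the two new helpers, assuming both are in scope; the file name, token counts, and input_token_count value below are made up for illustration and are not part of the PR:

import json

# The distribution file is just a JSON list of previously observed output token counts.
with open("output_token_dist.json", "w") as fout:  # hypothetical file name
    json.dump([128, 256, 512, 731, 900], fout)

distribution = read_distribution_from_file("output_token_dist.json")

# With num=12 and a 5-element distribution, the helper cycles through two full
# shuffles (10 counts) plus 2 more from a final shuffle, clipping every count to
# MAX_CONTEXT_WINDOW - input_token_count.
output_token_counts = generate_output_token_counts_from_existing(
    distribution, num=12, input_token_count=1000
)
assert len(output_token_counts) == 12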
@@ -231,17 +259,23 @@ def run_benchmark(
     concurrency: int,
     verbose: bool,
     local_port: int,
+    response_token_count_distribution: Optional[List] = None,
 ):
     prompt = generate_prompt(config.input_token_count, hf_model)

     prompt_num_tokens = config.input_token_count

-    output_token_counts = generate_output_token_counts(
-        config.output_token_count_mean,
-        config.output_token_count_std,
-        num_trials,
-        config.input_token_count,
-    )
+    if response_token_count_distribution is not None:
+        output_token_counts = generate_output_token_counts_from_existing(
+            response_token_count_distribution, num_trials, config.input_token_count
+        )
+    else:
+        output_token_counts = generate_output_token_counts(
+            config.output_token_count_mean,
+            config.output_token_count_std,
+            num_trials,
+            config.input_token_count,
+        )

     start = time.time()
     results = send_requests(
@@ -352,10 +386,18 @@ def run_benchmarks(
     verbose: bool = False,
     hf_model: Optional[str] = None,
     local_port: int = 5005,
+    response_token_count_distribution_file: Optional[str] = None,
 ):
     """Run benchmarks."""
     all_statistics = []
     config = BenchmarkConfig(input_token_count, output_token_count_mean)
+
+    response_token_count_distribution = None
+    if response_token_count_distribution_file is not None:
+        response_token_count_distribution = read_distribution_from_file(
+            response_token_count_distribution_file
+        )
+
     try:
         if verbose:
             print(f"Running benchmark for config {config}")
@@ -375,6 +417,7 @@ def run_benchmarks(
             concurrency,
             verbose,
             local_port,
+            response_token_count_distribution,
         )
         all_statistics.append(statistics)
     except Exception:
@@ -404,6 +447,7 @@ def run_benchmarks_concurrency_range(
     verbose: bool = False,
     hf_model: Optional[str] = None,
     local_port: int = 5005,
+    response_token_count_distribution_file: Optional[str] = None,
 ):
     if output_file is not None:
         # Create empty file
@@ -422,6 +466,7 @@ def run_benchmarks_concurrency_range(
             verbose,
             hf_model,
             local_port,
+            response_token_count_distribution_file,
         )

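As a side note on the "sample without replacement" comment in the new helper: a small, self-contained sketch (all numbers made up) of why cycling through shuffled copies of the recorded distribution keeps each run's total requested output tokens fixed, whereas independent draws with replacement let the total drift from run to run:

import random

distribution = [128, 256, 512, 731, 900]  # made-up recorded output token counts
num = 10  # two full passes over the 5-element distribution

def cycled_counts(dist, n):
    # Mirrors the scheme used by generate_output_token_counts_from_existing
    # (copies instead of shuffling the caller's list in place).
    out = []
    for _ in range(n // len(dist)):
        shuffled = dist[:]
        random.shuffle(shuffled)
        out.extend(shuffled)
    shuffled = dist[:]
    random.shuffle(shuffled)
    out.extend(shuffled[: n % len(dist)])
    return out

# Every cycled run requests exactly the same total number of output tokens...
totals_cycled = {sum(cycled_counts(distribution, num)) for _ in range(5)}
# ...while sampling with replacement gives a different total almost every run.
totals_iid = {sum(random.choices(distribution, k=num)) for _ in range(5)}
print(totals_cycled)  # always {5054} for this distribution
print(totals_iid)     # typically several different values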