33
33
from collections .abc import AsyncGenerator , Iterable
34
34
from dataclasses import dataclass
35
35
from datetime import datetime
36
- from typing import Any , Optional
36
+ from typing import Any , Literal , Optional
37
37
38
38
import numpy as np
39
39
from tqdm .asyncio import tqdm
@@ -107,14 +107,42 @@ class BenchmarkMetrics:
107
107
percentiles_e2el_ms : list [tuple [float , float ]]
108
108
109
109
110
+ def _get_current_request_rate (
111
+ ramp_up_strategy : Optional [Literal ["linear" , "exponential" ]],
112
+ ramp_up_start_rps : Optional [int ],
113
+ ramp_up_end_rps : Optional [int ],
114
+ request_index : int ,
115
+ total_requests : int ,
116
+ request_rate : float ,
117
+ ) -> float :
118
+ if (
119
+ ramp_up_strategy
120
+ and ramp_up_start_rps is not None
121
+ and ramp_up_end_rps is not None
122
+ ):
123
+ progress = request_index / max (total_requests - 1 , 1 )
124
+ if ramp_up_strategy == "linear" :
125
+ increase = (ramp_up_end_rps - ramp_up_start_rps ) * progress
126
+ return ramp_up_start_rps + increase
127
+ elif ramp_up_strategy == "exponential" :
128
+ ratio = ramp_up_end_rps / ramp_up_start_rps
129
+ return ramp_up_start_rps * (ratio ** progress )
130
+ else :
131
+ raise ValueError (f"Unknown ramp-up strategy: { ramp_up_strategy } " )
132
+ return request_rate
133
+
134
+
110
135
async def get_request (
111
136
input_requests : list [SampleRequest ],
112
137
request_rate : float ,
113
138
burstiness : float = 1.0 ,
114
- ) -> AsyncGenerator [SampleRequest , None ]:
139
+ ramp_up_strategy : Optional [Literal ["linear" , "exponential" ]] = None ,
140
+ ramp_up_start_rps : Optional [int ] = None ,
141
+ ramp_up_end_rps : Optional [int ] = None ,
142
+ ) -> AsyncGenerator [tuple [SampleRequest , float ], None ]:
115
143
"""
116
144
Asynchronously generates requests at a specified rate
117
- with OPTIONAL burstiness.
145
+ with OPTIONAL burstiness and OPTIONAL ramp-up strategy .
118
146
119
147
Args:
120
148
input_requests:
@@ -129,22 +157,44 @@ async def get_request(
129
157
A lower burstiness value (0 < burstiness < 1) results
130
158
in more bursty requests, while a higher burstiness value
131
159
(burstiness > 1) results in a more uniform arrival of requests.
160
+ ramp_up_strategy (optional):
161
+ The ramp-up strategy. Can be "linear" or "exponential".
162
+ If None, uses constant request rate (specified by request_rate).
163
+ ramp_up_start_rps (optional):
164
+ The starting request rate for ramp-up.
165
+ ramp_up_end_rps (optional):
166
+ The ending request rate for ramp-up.
132
167
"""
133
- input_requests : Iterable [SampleRequest ] = iter (input_requests )
134
-
135
- # Calculate scale parameter theta to maintain the desired request_rate.
136
168
assert burstiness > 0 , (
137
169
f"A positive burstiness factor is expected, but given { burstiness } ."
138
170
)
139
- theta = 1.0 / (request_rate * burstiness )
171
+ # Convert to list to get length for ramp-up calculations
172
+ if isinstance (input_requests , Iterable ) and not isinstance (input_requests , list ):
173
+ input_requests = list (input_requests )
174
+
175
+ total_requests = len (input_requests )
176
+ request_index = 0
140
177
141
178
for request in input_requests :
142
- yield request
179
+ current_request_rate = _get_current_request_rate (
180
+ ramp_up_strategy ,
181
+ ramp_up_start_rps ,
182
+ ramp_up_end_rps ,
183
+ request_index ,
184
+ total_requests ,
185
+ request_rate ,
186
+ )
187
+
188
+ yield request , current_request_rate
143
189
144
- if request_rate == float ("inf" ):
190
+ request_index += 1
191
+
192
+ if current_request_rate == float ("inf" ):
145
193
# If the request rate is infinity, then we don't need to wait.
146
194
continue
147
195
196
+ theta = 1.0 / (current_request_rate * burstiness )
197
+
148
198
# Sample the request interval from the gamma distribution.
149
199
# If burstiness is 1, it follows exponential distribution.
150
200
interval = np .random .gamma (shape = burstiness , scale = theta )
@@ -290,6 +340,9 @@ async def benchmark(
290
340
max_concurrency : Optional [int ],
291
341
lora_modules : Optional [Iterable [str ]],
292
342
extra_body : Optional [dict ],
343
+ ramp_up_strategy : Optional [Literal ["linear" , "exponential" ]] = None ,
344
+ ramp_up_start_rps : Optional [int ] = None ,
345
+ ramp_up_end_rps : Optional [int ] = None ,
293
346
):
294
347
if backend in ASYNC_REQUEST_FUNCS :
295
348
request_func = ASYNC_REQUEST_FUNCS [backend ]
@@ -353,7 +406,15 @@ async def benchmark(
353
406
354
407
distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
355
408
356
- print (f"Traffic request rate: { request_rate } " )
409
+ if ramp_up_strategy is not None :
410
+ print (
411
+ f"Traffic ramp-up strategy: { ramp_up_strategy } . Will increase "
412
+ f"RPS from { ramp_up_start_rps } to { ramp_up_end_rps } RPS over "
413
+ "the duration of the benchmark."
414
+ )
415
+ else :
416
+ print (f"Traffic request rate: { request_rate } RPS." )
417
+
357
418
print (f"Burstiness factor: { burstiness } ({ distribution } )" )
358
419
print (f"Maximum request concurrency: { max_concurrency } " )
359
420
@@ -373,7 +434,34 @@ async def limited_request_func(request_func_input, pbar):
373
434
374
435
benchmark_start_time = time .perf_counter ()
375
436
tasks : list [asyncio .Task ] = []
376
- async for request in get_request (input_requests , request_rate , burstiness ):
437
+
438
+ rps_change_events = []
439
+ last_int_rps = - 1
440
+ if ramp_up_strategy is not None and ramp_up_start_rps is not None :
441
+ last_int_rps = ramp_up_start_rps
442
+ rps_change_events .append (
443
+ {
444
+ "rps" : last_int_rps ,
445
+ "timestamp" : datetime .now ().isoformat (),
446
+ }
447
+ )
448
+
449
+ async for request , current_request_rate in get_request (
450
+ input_requests ,
451
+ request_rate ,
452
+ burstiness ,
453
+ ramp_up_strategy ,
454
+ ramp_up_start_rps ,
455
+ ramp_up_end_rps ,
456
+ ):
457
+ if ramp_up_strategy is not None :
458
+ current_int_rps = int (current_request_rate )
459
+ if current_int_rps > last_int_rps :
460
+ timestamp = datetime .now ().isoformat ()
461
+ for rps_val in range (last_int_rps + 1 , current_int_rps + 1 ):
462
+ rps_change_events .append ({"rps" : rps_val , "timestamp" : timestamp })
463
+ last_int_rps = current_int_rps
464
+
377
465
prompt , prompt_len , output_len , mm_content = (
378
466
request .prompt ,
379
467
request .prompt_len ,
@@ -397,11 +485,8 @@ async def limited_request_func(request_func_input, pbar):
397
485
ignore_eos = ignore_eos ,
398
486
extra_body = extra_body ,
399
487
)
400
- tasks .append (
401
- asyncio .create_task (
402
- limited_request_func (request_func_input = request_func_input , pbar = pbar )
403
- )
404
- )
488
+ task = limited_request_func (request_func_input = request_func_input , pbar = pbar )
489
+ tasks .append (asyncio .create_task (task ))
405
490
outputs : list [RequestFuncOutput ] = await asyncio .gather (* tasks )
406
491
407
492
if profile :
@@ -477,6 +562,9 @@ async def limited_request_func(request_func_input, pbar):
477
562
"errors" : [output .error for output in outputs ],
478
563
}
479
564
565
+ if rps_change_events :
566
+ result ["rps_change_events" ] = rps_change_events
567
+
480
568
def process_one_metric (
481
569
# E.g., "ttft"
482
570
metric_attribute_name : str ,
@@ -610,6 +698,26 @@ def main(args: argparse.Namespace):
610
698
tokenizer_id = args .tokenizer if args .tokenizer is not None else args .model
611
699
tokenizer_mode = args .tokenizer_mode
612
700
701
+ # Validate ramp-up arguments
702
+ if args .ramp_up_strategy is not None :
703
+ if args .request_rate != float ("inf" ):
704
+ raise ValueError (
705
+ "When using ramp-up, do not specify --request-rate. "
706
+ "The request rate will be controlled by ramp-up parameters. "
707
+ "Please remove the --request-rate argument."
708
+ )
709
+ if args .ramp_up_start_rps is None or args .ramp_up_end_rps is None :
710
+ raise ValueError (
711
+ "When using --ramp-up-strategy, both --ramp-up-start-rps and "
712
+ "--ramp-up-end-rps must be specified"
713
+ )
714
+ if args .ramp_up_start_rps < 0 or args .ramp_up_end_rps < 0 :
715
+ raise ValueError ("Ramp-up start and end RPS must be non-negative" )
716
+ if args .ramp_up_start_rps > args .ramp_up_end_rps :
717
+ raise ValueError ("Ramp-up start RPS must be less than end RPS" )
718
+ if args .ramp_up_strategy == "exponential" and args .ramp_up_start_rps == 0 :
719
+ raise ValueError ("For exponential ramp-up, the start RPS cannot be 0." )
720
+
613
721
if args .base_url is not None :
614
722
api_url = f"{ args .base_url } { args .endpoint } "
615
723
base_url = f"{ args .base_url } "
@@ -802,6 +910,9 @@ def main(args: argparse.Namespace):
802
910
max_concurrency = args .max_concurrency ,
803
911
lora_modules = args .lora_modules ,
804
912
extra_body = sampling_params ,
913
+ ramp_up_strategy = args .ramp_up_strategy ,
914
+ ramp_up_start_rps = args .ramp_up_start_rps ,
915
+ ramp_up_end_rps = args .ramp_up_end_rps ,
805
916
)
806
917
)
807
918
@@ -834,6 +945,11 @@ def main(args: argparse.Namespace):
834
945
result_json ["burstiness" ] = args .burstiness
835
946
result_json ["max_concurrency" ] = args .max_concurrency
836
947
948
+ if args .ramp_up_strategy is not None :
949
+ result_json ["ramp_up_strategy" ] = args .ramp_up_strategy
950
+ result_json ["ramp_up_start_rps" ] = args .ramp_up_start_rps
951
+ result_json ["ramp_up_end_rps" ] = args .ramp_up_end_rps
952
+
837
953
# Merge with benchmark result
838
954
result_json = {** result_json , ** benchmark_result }
839
955
@@ -859,7 +975,10 @@ def main(args: argparse.Namespace):
859
975
if args .max_concurrency is not None
860
976
else ""
861
977
)
862
- file_name = f"{ backend } -{ args .request_rate } qps{ max_concurrency_str } -{ base_model_id } -{ current_dt } .json" # noqa
978
+ if args .ramp_up_strategy is not None :
979
+ file_name = f"{ backend } -ramp-up-{ args .ramp_up_strategy } -{ args .ramp_up_start_rps } qps-{ args .ramp_up_end_rps } qps{ max_concurrency_str } -{ base_model_id } -{ current_dt } .json" # noqa
980
+ else :
981
+ file_name = f"{ backend } -{ args .request_rate } qps{ max_concurrency_str } -{ base_model_id } -{ current_dt } .json" # noqa
863
982
if args .result_filename :
864
983
file_name = args .result_filename
865
984
if args .result_dir :
@@ -1225,6 +1344,31 @@ def create_argument_parser():
1225
1344
"script chooses a LoRA module at random." ,
1226
1345
)
1227
1346
1347
+ parser .add_argument (
1348
+ "--ramp-up-strategy" ,
1349
+ type = str ,
1350
+ default = None ,
1351
+ choices = ["linear" , "exponential" ],
1352
+ help = "The ramp-up strategy. This would be used to "
1353
+ "ramp up the request rate from initial RPS to final "
1354
+ "RPS rate (specified by --ramp-up-start-rps and --ramp-up-end-rps). "
1355
+ "over the duration of the benchmark." ,
1356
+ )
1357
+ parser .add_argument (
1358
+ "--ramp-up-start-rps" ,
1359
+ type = int ,
1360
+ default = None ,
1361
+ help = "The starting request rate for ramp-up (RPS). "
1362
+ "Needs to be specified when --ramp-up-strategy is used." ,
1363
+ )
1364
+ parser .add_argument (
1365
+ "--ramp-up-end-rps" ,
1366
+ type = int ,
1367
+ default = None ,
1368
+ help = "The ending request rate for ramp-up (RPS). "
1369
+ "Needs to be specified when --ramp-up-strategy is used." ,
1370
+ )
1371
+
1228
1372
return parser
1229
1373
1230
1374
0 commit comments