1
1
import asyncio
2
2
import codecs
3
- import json
4
3
from pathlib import Path
5
4
from typing import get_args
6
5
7
6
import click
7
+ from pydantic import ValidationError
8
8
9
9
from guidellm .backend import BackendType
10
- from guidellm .benchmark import ProfileType , benchmark_generative_text
10
+ from guidellm .benchmark import ProfileType
11
+ from guidellm .benchmark .entrypoints import benchmark_with_scenario
12
+ from guidellm .benchmark .scenario import GenerativeTextScenario , get_builtin_scenarios
11
13
from guidellm .config import print_config
12
14
from guidellm .preprocess .dataset import ShortPromptStrategy , process_dataset
13
15
from guidellm .scheduler import StrategyType
16
+ from guidellm .utils import cli as cli_tools
14
17
15
18
STRATEGY_PROFILE_CHOICES = set (
16
19
list (get_args (ProfileType )) + list (get_args (StrategyType ))
17
20
)
18
21
19
22
20
- def parse_json (ctx , param , value ): # noqa: ARG001
21
- if value is None :
22
- return None
23
- try :
24
- return json .loads (value )
25
- except json .JSONDecodeError as err :
26
- raise click .BadParameter (f"{ param .name } must be a valid JSON string." ) from err
27
-
28
-
29
- def parse_number_str (ctx , param , value ): # noqa: ARG001
30
- if value is None :
31
- return None
32
-
33
- values = value .split ("," ) if "," in value else [value ]
34
-
35
- try :
36
- return [float (val ) for val in values ]
37
- except ValueError as err :
38
- raise click .BadParameter (
39
- f"{ param .name } must be a number or comma-separated list of numbers."
40
- ) from err
41
-
42
-
43
23
@click .group ()
44
24
def cli ():
45
25
pass
46
26
47
27
48
28
@cli .command (
49
- help = "Run a benchmark against a generative model using the specified arguments."
29
+ help = "Run a benchmark against a generative model using the specified arguments." ,
30
+ context_settings = {"auto_envvar_prefix" : "GUIDELLM" },
31
+ )
32
+ @click .option (
33
+ "--scenario" ,
34
+ type = cli_tools .Union (
35
+ click .Path (
36
+ exists = True ,
37
+ readable = True ,
38
+ file_okay = True ,
39
+ dir_okay = False ,
40
+ path_type = Path , # type: ignore[type-var]
41
+ ),
42
+ click .Choice (get_builtin_scenarios ()),
43
+ ),
44
+ default = None ,
45
+ help = (
46
+ "The name of a builtin scenario or path to a config file. "
47
+ "Missing values from the config will use defaults. "
48
+ "Options specified on the commandline will override the scenario."
49
+ ),
50
50
)
51
51
@click .option (
52
52
"--target" ,
53
- required = True ,
54
53
type = str ,
55
54
help = "The target path for the backend to run benchmarks against. For example, http://localhost:8000" ,
56
55
)
@@ -61,20 +60,20 @@ def cli():
61
60
"The type of backend to use to run requests against. Defaults to 'openai_http'."
62
61
f" Supported types: { ', ' .join (get_args (BackendType ))} "
63
62
),
64
- default = "openai_http" ,
63
+ default = GenerativeTextScenario . get_default ( "backend_type" ) ,
65
64
)
66
65
@click .option (
67
66
"--backend-args" ,
68
- callback = parse_json ,
69
- default = None ,
67
+ callback = cli_tools . parse_json ,
68
+ default = GenerativeTextScenario . get_default ( "backend_args" ) ,
70
69
help = (
71
70
"A JSON string containing any arguments to pass to the backend as a "
72
71
"dict with **kwargs."
73
72
),
74
73
)
75
74
@click .option (
76
75
"--model" ,
77
- default = None ,
76
+ default = GenerativeTextScenario . get_default ( "model" ) ,
78
77
type = str ,
79
78
help = (
80
79
"The ID of the model to benchmark within the backend. "
@@ -83,7 +82,7 @@ def cli():
83
82
)
84
83
@click .option (
85
84
"--processor" ,
86
- default = None ,
85
+ default = GenerativeTextScenario . get_default ( "processor" ) ,
87
86
type = str ,
88
87
help = (
89
88
"The processor or tokenizer to use to calculate token counts for statistics "
@@ -93,16 +92,15 @@ def cli():
93
92
)
94
93
@click .option (
95
94
"--processor-args" ,
96
- default = None ,
97
- callback = parse_json ,
95
+ default = GenerativeTextScenario . get_default ( "processor_args" ) ,
96
+ callback = cli_tools . parse_json ,
98
97
help = (
99
98
"A JSON string containing any arguments to pass to the processor constructor "
100
99
"as a dict with **kwargs."
101
100
),
102
101
)
103
102
@click .option (
104
103
"--data" ,
105
- required = True ,
106
104
type = str ,
107
105
help = (
108
106
"The HuggingFace dataset ID, a path to a HuggingFace dataset, "
@@ -112,15 +110,16 @@ def cli():
112
110
)
113
111
@click .option (
114
112
"--data-args" ,
115
- callback = parse_json ,
113
+ default = GenerativeTextScenario .get_default ("data_args" ),
114
+ callback = cli_tools .parse_json ,
116
115
help = (
117
116
"A JSON string containing any arguments to pass to the dataset creation "
118
117
"as a dict with **kwargs."
119
118
),
120
119
)
121
120
@click .option (
122
121
"--data-sampler" ,
123
- default = None ,
122
+ default = GenerativeTextScenario . get_default ( "data_sampler" ) ,
124
123
type = click .Choice (["random" ]),
125
124
help = (
126
125
"The data sampler type to use. 'random' will add a random shuffle on the data. "
@@ -129,7 +128,6 @@ def cli():
129
128
)
130
129
@click .option (
131
130
"--rate-type" ,
132
- required = True ,
133
131
type = click .Choice (STRATEGY_PROFILE_CHOICES ),
134
132
help = (
135
133
"The type of benchmark to run. "
@@ -138,8 +136,7 @@ def cli():
138
136
)
139
137
@click .option (
140
138
"--rate" ,
141
- default = None ,
142
- callback = parse_number_str ,
139
+ default = GenerativeTextScenario .get_default ("rate" ),
143
140
help = (
144
141
"The rates to run the benchmark at. "
145
142
"Can be a single number or a comma-separated list of numbers. "
@@ -152,6 +149,7 @@ def cli():
152
149
@click .option (
153
150
"--max-seconds" ,
154
151
type = float ,
152
+ default = GenerativeTextScenario .get_default ("max_seconds" ),
155
153
help = (
156
154
"The maximum number of seconds each benchmark can run for. "
157
155
"If None, will run until max_requests or the data is exhausted."
@@ -160,6 +158,7 @@ def cli():
160
158
@click .option (
161
159
"--max-requests" ,
162
160
type = int ,
161
+ default = GenerativeTextScenario .get_default ("max_requests" ),
163
162
help = (
164
163
"The maximum number of requests each benchmark can run for. "
165
164
"If None, will run until max_seconds or the data is exhausted."
@@ -168,7 +167,7 @@ def cli():
168
167
@click .option (
169
168
"--warmup-percent" ,
170
169
type = float ,
171
- default = None ,
170
+ default = GenerativeTextScenario . get_default ( "warmup_percent" ) ,
172
171
help = (
173
172
"The percent of the benchmark (based on max-seconds, max-requets, "
174
173
"or lenth of dataset) to run as a warmup and not include in the final results. "
@@ -178,6 +177,7 @@ def cli():
178
177
@click .option (
179
178
"--cooldown-percent" ,
180
179
type = float ,
180
+ default = GenerativeTextScenario .get_default ("cooldown_percent" ),
181
181
help = (
182
182
"The percent of the benchmark (based on max-seconds, max-requets, or lenth "
183
183
"of dataset) to run as a cooldown and not include in the final results. "
@@ -212,7 +212,7 @@ def cli():
212
212
)
213
213
@click .option (
214
214
"--output-extras" ,
215
- callback = parse_json ,
215
+ callback = cli_tools . parse_json ,
216
216
help = "A JSON string of extra data to save with the output benchmarks" ,
217
217
)
218
218
@click .option (
@@ -222,15 +222,16 @@ def cli():
222
222
"The number of samples to save in the output file. "
223
223
"If None (default), will save all samples."
224
224
),
225
- default = None ,
225
+ default = GenerativeTextScenario . get_default ( "output_sampling" ) ,
226
226
)
227
227
@click .option (
228
228
"--random-seed" ,
229
- default = 42 ,
229
+ default = GenerativeTextScenario . get_default ( "random_seed" ) ,
230
230
type = int ,
231
231
help = "The random seed to use for benchmarking to ensure reproducibility." ,
232
232
)
233
233
def benchmark (
234
+ scenario ,
234
235
target ,
235
236
backend_type ,
236
237
backend_args ,
@@ -254,30 +255,53 @@ def benchmark(
254
255
output_sampling ,
255
256
random_seed ,
256
257
):
258
+ click_ctx = click .get_current_context ()
259
+
260
+ overrides = cli_tools .set_if_not_default (
261
+ click_ctx ,
262
+ target = target ,
263
+ backend_type = backend_type ,
264
+ backend_args = backend_args ,
265
+ model = model ,
266
+ processor = processor ,
267
+ processor_args = processor_args ,
268
+ data = data ,
269
+ data_args = data_args ,
270
+ data_sampler = data_sampler ,
271
+ rate_type = rate_type ,
272
+ rate = rate ,
273
+ max_seconds = max_seconds ,
274
+ max_requests = max_requests ,
275
+ warmup_percent = warmup_percent ,
276
+ cooldown_percent = cooldown_percent ,
277
+ output_sampling = output_sampling ,
278
+ random_seed = random_seed ,
279
+ )
280
+
281
+ try :
282
+ # If a scenario file was specified read from it
283
+ if scenario is None :
284
+ _scenario = GenerativeTextScenario .model_validate (overrides )
285
+ elif isinstance (scenario , Path ):
286
+ _scenario = GenerativeTextScenario .from_file (scenario , overrides )
287
+ else : # Only builtins can make it here; click will catch anything else
288
+ _scenario = GenerativeTextScenario .from_builtin (scenario , overrides )
289
+ except ValidationError as e :
290
+ # Translate pydantic valdation error to click argument error
291
+ errs = e .errors (include_url = False , include_context = True , include_input = True )
292
+ param_name = "--" + str (errs [0 ]["loc" ][0 ]).replace ("_" , "-" )
293
+ raise click .BadParameter (
294
+ errs [0 ]["msg" ], ctx = click_ctx , param_hint = param_name
295
+ ) from e
296
+
257
297
asyncio .run (
258
- benchmark_generative_text (
259
- target = target ,
260
- backend_type = backend_type ,
261
- backend_args = backend_args ,
262
- model = model ,
263
- processor = processor ,
264
- processor_args = processor_args ,
265
- data = data ,
266
- data_args = data_args ,
267
- data_sampler = data_sampler ,
268
- rate_type = rate_type ,
269
- rate = rate ,
270
- max_seconds = max_seconds ,
271
- max_requests = max_requests ,
272
- warmup_percent = warmup_percent ,
273
- cooldown_percent = cooldown_percent ,
298
+ benchmark_with_scenario (
299
+ scenario = _scenario ,
274
300
show_progress = not disable_progress ,
275
301
show_progress_scheduler_stats = display_scheduler_stats ,
276
302
output_console = not disable_console_outputs ,
277
303
output_path = output_path ,
278
304
output_extras = output_extras ,
279
- output_sampling = output_sampling ,
280
- random_seed = random_seed ,
281
305
)
282
306
)
283
307
@@ -316,7 +340,8 @@ def preprocess():
316
340
"Convert a dataset to have specific prompt and output token sizes.\n "
317
341
"DATA: Path to the input dataset or dataset ID.\n "
318
342
"OUTPUT_PATH: Path to save the converted dataset, including file suffix."
319
- )
343
+ ),
344
+ context_settings = {"auto_envvar_prefix" : "GUIDELLM" },
320
345
)
321
346
@click .argument (
322
347
"data" ,
@@ -340,15 +365,15 @@ def preprocess():
340
365
@click .option (
341
366
"--processor-args" ,
342
367
default = None ,
343
- callback = parse_json ,
368
+ callback = cli_tools . parse_json ,
344
369
help = (
345
370
"A JSON string containing any arguments to pass to the processor constructor "
346
371
"as a dict with **kwargs."
347
372
),
348
373
)
349
374
@click .option (
350
375
"--data-args" ,
351
- callback = parse_json ,
376
+ callback = cli_tools . parse_json ,
352
377
help = (
353
378
"A JSON string containing any arguments to pass to the dataset creation "
354
379
"as a dict with **kwargs."
0 commit comments