
Commit ff4c672

yunfeng-scale and Ubuntu authored

Throughput benchmark script (#411)

* Move throughput benchmark script
* support local
* update
* output csv
* append
* add concurrency
* fix yaml

Co-authored-by: Ubuntu <ubuntu@ip-10-192-103-234.us-west-2.compute.internal>

1 parent 5de7187 · commit ff4c672

File tree: 2 files changed (+394 −0 lines)


scripts/requirements.txt

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
numpy==1.24.4
typer==0.9.0
lorem-text==2.1
transformers==4.36.0
chardet==5.2.0

scripts/throughput_benchmarks.py

Lines changed: 389 additions & 0 deletions

@@ -0,0 +1,389 @@
import csv
import json
import os
import queue
import random
import threading
import time
import traceback
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

import numpy as np
import requests
import typer
from lorem_text import lorem
from transformers import AutoTokenizer

AUTH_USER_ID = os.getenv("AUTH_USER_ID")
GATEWAY_URL = os.getenv("GATEWAY_URL")
app = typer.Typer(name="throughput-benchmarks", add_completion=False)

MAX_CONTEXT_WINDOW = 4096


@dataclass
class BenchmarkConfig:
    def __init__(self, input_token_count, output_token_count_mean):
        self.input_token_count = input_token_count
        self.output_token_count_mean = output_token_count_mean
        # Here we assume 3x standard deviation is enough to cover the range of output token counts.
        # Also assume 3x stddev is roughly half of the mean.
        self.output_token_count_std = output_token_count_mean / 6.0

    def __repr__(self) -> str:
        return f"BenchmarkConfig(input_token_count={self.input_token_count}, output_token_count_mean={self.output_token_count_mean}, output_token_count_std={self.output_token_count_std})"


HF_MODEL_MAPPING = {
    "llama-2-7b": "meta-llama/Llama-2-7b-hf",
    "llama-2-13b": "meta-llama/Llama-2-13b-hf",
}


class InferenceFramework(Enum):
    TEXT_GENERATION_INFERENCE = "tgi"
    VLLM = "vllm"
    LIGHTLLM = "lightllm"
    TENSORRT_LLM = "tensorrt-llm"

    @classmethod
    def from_value(cls, value):
        for member in cls:
            if member.value == value:
                return member
        raise ValueError(f"No member with value {value} in {cls.__name__}")


def send_request(url, request, user=None):
    start = time.time()
    response = requests.post(
        url,
        json=request,
        auth=(user, ""),
        stream=True,
    )
    first_line = True
    for byte_payload in response.iter_lines():
        if first_line:
            time_to_first_token = time.time() - start
            first_line = False

        # Skip line
        if byte_payload == b"\n":
            continue

        payload = byte_payload.decode("utf-8")

        # Event data
        if payload.startswith("data:"):
            payload_data = payload.lstrip("data:").rstrip("/n")
            payload_json = json.loads(payload_data)

    return {
        "payload": payload_json,
        "time_to_first_token": time_to_first_token,
        "total_time": time.time() - start,
    }


def pull_and_send_request_from_queue(
    model: str,
    request_queue: queue.Queue,
    result_queue: queue.Queue,
    use_localhost: bool,
    framework: InferenceFramework,
    local_port: int = 5005,
):
    while not request_queue.empty():
        request = request_queue.get()
        if use_localhost:
            if framework == InferenceFramework.VLLM:
                response = send_request(f"http://localhost:{local_port}/stream", request)
                response["num_completion_tokens"] = response["payload"]["count_output_tokens"]
            else:
                raise NotImplementedError()
        else:
            response = send_request(
                f"{GATEWAY_URL}/v1/llm/completions-stream?model_endpoint_name={model}",
                request,
                AUTH_USER_ID,
            )
            response["num_completion_tokens"] = response["payload"]["output"][
                "num_completion_tokens"
            ]

        result_queue.put(response)


def generate_request(
    framework: InferenceFramework, prompt: str, output_token_count: int, localhost: bool
):
    if not localhost:
        return {"prompt": prompt, "max_new_tokens": output_token_count, "temperature": 0.0}

    if framework == InferenceFramework.TEXT_GENERATION_INFERENCE:
        return {
            "parameters": {
                "do_sample": False,
                "max_new_tokens": output_token_count,
                "details": False,
            },
            "inputs": prompt,
        }
    elif framework == InferenceFramework.VLLM:
        return {
            "prompt": prompt,
            "max_tokens": output_token_count,
            "temperature": 0,
            "stream": True,
        }
    elif framework == InferenceFramework.LIGHTLLM:
        return {
            "parameters": {
                "do_sample": False,
                "max_new_tokens": output_token_count,
            },
            "inputs": prompt,
        }
    elif framework == InferenceFramework.TENSORRT_LLM:
        return {
            "max_tokens": output_token_count,
            "text_input": prompt,
            "bad_words": "",
            "stop_words": "",
        }
    else:
        raise NotImplementedError()


def send_requests(
    model: str,
    prompt: str,
    output_token_counts: List[int],
    use_localhost: bool,
    concurrency: int,
    framework: InferenceFramework,
    local_port: int = 5005,
):
    thread_results: queue.Queue = queue.Queue()
    requests_queue: queue.Queue = queue.Queue()
    for output_token_count in output_token_counts:
        request = generate_request(framework, prompt, output_token_count, use_localhost)
        requests_queue.put(request)
    threads = []
    for i in range(concurrency):
        thread = threading.Thread(
            target=pull_and_send_request_from_queue,
            args=(
                model,
                requests_queue,
                thread_results,
                use_localhost,
                framework,
                local_port,
            ),
        )
        thread.start()
        threads.append(thread)

    for thread in threads:
        thread.join()

    results = []
    while not thread_results.empty():
        results.append(thread_results.get())

    return results


def generate_prompt(num, hf_model):
    random.seed(1)
    text = lorem.words(num // 2)  # Roughly 2 tokens per lorem word
    tokenizer = AutoTokenizer.from_pretrained(hf_model)
    return tokenizer.decode(tokenizer.encode(text)[: num - 2])


def generate_output_token_counts(mean, std, num, input_token_count):
    output = np.random.normal(mean, std, num).astype(int).tolist()

    for i in range(len(output)):
        output[i] = min(output[i], MAX_CONTEXT_WINDOW - input_token_count)
    return output
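For intuition, the configuration and sampling pieces above combine into a simple output-length model. The snippet below is illustrative only and is not part of the commit; the token counts are made up.

# Illustrative only: exercises BenchmarkConfig and generate_output_token_counts as defined above.
# A mean of 128 implies std = 128 / 6 ≈ 21.3 (the "3x stddev is roughly half the mean" assumption).
config = BenchmarkConfig(input_token_count=512, output_token_count_mean=128)
lengths = generate_output_token_counts(
    config.output_token_count_mean,
    config.output_token_count_std,
    num=50,
    input_token_count=config.input_token_count,
)
# Samples land roughly in 128 ± 3 * 21; the cap at MAX_CONTEXT_WINDOW - 512 = 3584
# only binds for much longer requested outputs.
print(min(lengths), max(lengths))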
def run_benchmark(
    model: str,
    framework: InferenceFramework,
    hf_model: str,
    config: BenchmarkConfig,
    num_trials: int,
    use_localhost: bool,
    concurrency: int,
    verbose: bool,
    local_port: int,
):
    prompt = generate_prompt(config.input_token_count, hf_model)

    prompt_num_tokens = config.input_token_count

    output_token_counts = generate_output_token_counts(
        config.output_token_count_mean,
        config.output_token_count_std,
        num_trials,
        config.input_token_count,
    )

    start = time.time()
    results = send_requests(
        model,
        prompt,
        output_token_counts,
        use_localhost,
        concurrency,
        framework,
        local_port=local_port,
    )
    end = time.time()
    elapsed = end - start
    results = [result for result in results if result is not None]

    num_sampled_tokens = sum([result["num_completion_tokens"] for result in results])
    num_prompt_tokens = prompt_num_tokens * len(results)
    n = len(results)
    time_to_process_prompt = []
    time_per_completion = []
    time_to_first_token = []
    inter_token_latency = []
    for result in results:
        avg_time_per_token = (result["total_time"] - result["time_to_first_token"]) / (
            result["num_completion_tokens"] - 1
        )
        time_to_first_token.append(result["time_to_first_token"])
        time_to_process_prompt.append(result["time_to_first_token"] - avg_time_per_token)
        time_per_completion.append(result["total_time"] - time_to_process_prompt[-1])
        inter_token_latency.append(avg_time_per_token)

    total_num_tokens = num_sampled_tokens + num_prompt_tokens
    avg_prefill_time = sum(time_to_process_prompt) / n
    avg_completion_time = sum(time_per_completion) / n

    statistics = {
        "concurrency": concurrency,
        "avg_prompt_throughput": num_prompt_tokens
        / (elapsed * avg_prefill_time / (avg_prefill_time + avg_completion_time)),
        "avg_time_to_first_token": sum(time_to_first_token) / n,
        "avg_sampling_throughput": num_sampled_tokens
        / (elapsed * avg_completion_time / (avg_prefill_time + avg_completion_time)),
        "avg_total_throughput": total_num_tokens / elapsed,
        "avg_per_session_sampling_throughput": num_sampled_tokens
        / (elapsed * avg_completion_time / (avg_prefill_time + avg_completion_time))
        / concurrency,
        "avg_inter_token_latency": sum(inter_token_latency) / n,
        "num_prompt_tokens": prompt_num_tokens,
        "avg_num_sampled_tokens": num_sampled_tokens / n,
        "elapsed_time": elapsed,
        "avg_prefill_time": avg_prefill_time,
        "avg_completion_time": avg_completion_time,
        "num_requests": num_trials,
        "num_successful_requests": n,
        "total_num_tokens": total_num_tokens,
        "total_num_sampled_tokens": num_sampled_tokens,
    }
    if verbose:
        print(f"Statistics: {statistics}")

    # Sleep for 1 second between each benchmark.
    time.sleep(1)

    return statistics
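To make the throughput bookkeeping in run_benchmark concrete, here is a toy calculation with made-up numbers (it is not part of the committed script): the wall-clock elapsed time is apportioned between prefill and decode in proportion to their average per-request durations, and each token count is divided by its share of the elapsed time.

# Toy numbers, illustrative only.
elapsed = 10.0                # wall-clock seconds for the whole batch
avg_prefill_time = 0.5        # mean estimated prompt-processing time per request
avg_completion_time = 4.5     # mean decode time per request
num_prompt_tokens = 51_200    # e.g. 100 successful requests x 512 prompt tokens
num_sampled_tokens = 12_800   # total completion tokens across all requests

prefill_share = avg_prefill_time / (avg_prefill_time + avg_completion_time)    # 0.1
decode_share = avg_completion_time / (avg_prefill_time + avg_completion_time)  # 0.9

avg_prompt_throughput = num_prompt_tokens / (elapsed * prefill_share)      # 51_200 tokens/s
avg_sampling_throughput = num_sampled_tokens / (elapsed * decode_share)    # ~1_422 tokens/s
avg_total_throughput = (num_prompt_tokens + num_sampled_tokens) / elapsed  # 6_400 tokens/s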
@app.command()
def run_benchmarks(
    model: str,
    framework: str,
    input_token_count: int,
    output_token_count_mean: int,
    num_trials: int = 50,
    output_file: Optional[str] = None,
    use_localhost: bool = False,
    concurrency: int = 1,
    verbose: bool = False,
    hf_model: Optional[str] = None,
    local_port: int = 5005,
):
    """Run benchmarks."""
    all_statistics = []
    config = BenchmarkConfig(input_token_count, output_token_count_mean)
    try:
        if verbose:
            print(f"Running benchmark for config {config}")
        if hf_model is None:
            if model not in HF_MODEL_MAPPING:
                raise ValueError(
                    f"--hf-model must be specified for model {model} since it's not in default mapping."
                )
            hf_model = HF_MODEL_MAPPING[model]
        statistics = run_benchmark(
            model,
            InferenceFramework.from_value(framework),
            hf_model,
            config,
            num_trials,
            use_localhost,
            concurrency,
            verbose,
            local_port,
        )
        all_statistics.append(statistics)
    except Exception:
        traceback.print_exc()

    if output_file is not None:
        header = all_statistics[0].keys()

        with open(output_file, "a") as csvfile:
            csv_writer = csv.DictWriter(csvfile, fieldnames=header)
            csv_writer.writeheader()
            csv_writer.writerows(all_statistics)


@app.command()
def run_benchmarks_concurrency_range(
    model: str,
    framework: str,
    input_token_count: int,
    output_token_count_mean: int,
    num_trials_per_concurrency: int = 5,
    output_file: Optional[str] = None,
    use_localhost: bool = False,
    concurrency_min: int = 1,
    concurrency_max: int = 1,
    verbose: bool = False,
    hf_model: Optional[str] = None,
    local_port: int = 5005,
):
    if output_file is not None:
        # Create empty file
        with open(output_file, "w"):
            pass
    for concurrency in range(concurrency_min, concurrency_max + 1):
        run_benchmarks(
            model,
            framework,
            input_token_count,
            output_token_count_mean,
            num_trials_per_concurrency * concurrency,
            output_file,
            use_localhost,
            concurrency,
            verbose,
            hf_model,
            local_port,
        )


if __name__ == "__main__":
    app()
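A possible invocation sketch, not part of the commit: the model name, CSV paths, trial counts, and the assumption of a vLLM-compatible server listening on localhost:5005 are all illustrative. Because the typer commands are plain functions, they can be called directly from Python (run from the scripts/ directory so the import resolves); a CLI equivalent is shown in the comments. For remote (non-localhost) runs, GATEWAY_URL and AUTH_USER_ID must be set in the environment.

# Illustrative usage sketch; values are assumptions, not defaults from the commit.
# CLI equivalent:
#   python throughput_benchmarks.py run-benchmarks llama-2-7b vllm 512 128 \
#       --use-localhost --concurrency 4 --output-file benchmark_results.csv --verbose
from throughput_benchmarks import run_benchmarks, run_benchmarks_concurrency_range

# Single benchmark against a local vLLM server, appending statistics to a CSV.
run_benchmarks(
    model="llama-2-7b",  # present in HF_MODEL_MAPPING, so hf_model can be omitted
    framework="vllm",
    input_token_count=512,
    output_token_count_mean=128,
    num_trials=20,
    output_file="benchmark_results.csv",
    use_localhost=True,
    concurrency=4,
    verbose=True,
    local_port=5005,
)

# Sweep concurrency 1..8 with 5 trials per unit of concurrency.
run_benchmarks_concurrency_range(
    model="llama-2-7b",
    framework="vllm",
    input_token_count=512,
    output_token_count_mean=128,
    num_trials_per_concurrency=5,
    output_file="concurrency_sweep.csv",
    use_localhost=True,
    concurrency_min=1,
    concurrency_max=8,
    verbose=True,
)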
